  • Stars: 222
  • Rank: 178,063 (Top 4 %)
  • Language: Python
  • License: MIT License
  • Created about 6 years ago
  • Updated about 1 year ago


Repository Details


Keras Multi-Head


A wrapper layer for stacking layers horizontally.

Install

pip install keras-multi-head

Usage

Duplicate Layers

If only a single layer is provided, it is duplicated. The layer_num argument controls how many copies are created.

from tensorflow import keras
from keras_multi_head import MultiHead


model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=100, output_dim=20, name='Embedding'))
model.add(MultiHead(keras.layers.LSTM(units=32), layer_num=5, name='Multi-LSTMs'))
model.add(keras.layers.Flatten(name='Flatten'))
model.add(keras.layers.Dense(units=4, activation='softmax', name='Dense'))
model.build(input_shape=(None, 10))  # (batch size, sequence length); 10 is an arbitrary example length
model.summary()
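To make the shape bookkeeping of the example above concrete, here is a NumPy sketch of what "stacking layers horizontally" means. It assumes (our reading, not checked against the library source) that each duplicate gets its own independently initialized weights and that the per-layer outputs are stacked along a new trailing axis. `toy_head` is a hypothetical stand-in for the LSTM: a single dense projection of the last time step.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, seq_len, embed_dim = 2, 7, 20   # embedding output: (batch, seq_len, 20)
units, layer_num = 32, 5               # LSTM(units=32), layer_num=5

x = rng.normal(size=(batch, seq_len, embed_dim))


def toy_head(inputs, weights):
    """Hypothetical stand-in for LSTM(units=32): project the last time step."""
    return np.tanh(inputs[:, -1, :] @ weights)  # (batch, units)


# Each duplicate gets its own, independently initialized weights (assumption).
head_weights = [rng.normal(size=(embed_dim, units)) for _ in range(layer_num)]
head_outputs = [toy_head(x, w) for w in head_weights]

# Stack the per-head outputs along a new trailing axis: (batch, units, layer_num).
stacked = np.stack(head_outputs, axis=-1)
# Flatten keeps the batch axis and merges the rest: (batch, units * layer_num).
flat = stacked.reshape(batch, -1)
print(stacked.shape, flat.shape)  # (2, 32, 5) (2, 160)
```

Under this reading, the Dense layer in the example sees 32 × 5 = 160 features per sample.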

Use Multiple Layers

The first argument can also be a list of layers with different configurations; however, they must all produce the same output shape.

from tensorflow import keras
from keras_multi_head import MultiHead


model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=100, output_dim=20, name='Embedding'))
model.add(MultiHead([
    keras.layers.Conv1D(filters=32, kernel_size=3, padding='same'),
    keras.layers.Conv1D(filters=32, kernel_size=5, padding='same'),
    keras.layers.Conv1D(filters=32, kernel_size=7, padding='same'),
], name='Multi-CNNs'))
model.build(input_shape=(None, 10))  # (batch size, sequence length); 10 is an arbitrary example length
model.summary()
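The list example satisfies the same-shape requirement because padding='same' keeps the sequence length unchanged whatever the kernel size, so all three convolutions emit (batch, seq_len, 32). The NumPy sketch below illustrates this with single-channel 1D convolutions via np.convolve(mode='same'), a simplification of Keras Conv1D (which is a multi-channel cross-correlation), and shows that stacking fails as soon as the output shapes differ.

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.normal(size=10)  # one channel of a length-10 sequence

outputs = []
for kernel_size in (3, 5, 7):
    kernel = rng.normal(size=kernel_size)
    # mode='same' returns an output as long as the longer input (here: the signal).
    outputs.append(np.convolve(signal, kernel, mode='same'))

print([o.shape for o in outputs])  # [(10,), (10,), (10,)]
stacked = np.stack(outputs, axis=-1)
print(stacked.shape)               # (10, 3)

# With 'valid' padding the lengths differ (10 - k + 1), so stacking fails.
valid = [np.convolve(signal, rng.normal(size=k), mode='valid') for k in (3, 5, 7)]
try:
    np.stack(valid, axis=-1)
except ValueError as err:
    print('cannot stack:', err)
```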

Linear Transformation

When hidden_dim is given, the input is first passed through a separate linear transformation for each layer, so every layer receives a different projection of the input, all with the same shape.
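A minimal NumPy sketch of such a per-layer mapping, assuming each layer i owns a weight matrix W[i] of shape (input_dim, hidden_dim) and a bias b[i], and receives x @ W[i] + b[i] instead of x. The names `W` and `b` and all sizes are illustrative; we have not checked the library's exact initialization or whether a bias is always used.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, seq_len, input_dim = 2, 7, 20
hidden_dim, layer_num = 16, 3

x = rng.normal(size=(batch, seq_len, input_dim))

# One independent affine map per parallel layer (illustrative parameters).
W = rng.normal(size=(layer_num, input_dim, hidden_dim))
b = np.zeros((layer_num, hidden_dim))

# projected[i] is what layer i would see: (batch, seq_len, hidden_dim).
projected = [x @ W[i] + b[i] for i in range(layer_num)]

print([p.shape for p in projected])  # [(2, 7, 16), (2, 7, 16), (2, 7, 16)]
```

Because the maps differ, the parallel layers start from different views of the same input even if the layers themselves are identical copies.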

Regularization

Regularization is useful when you want the parallel layers to extract different features. You can choose which weights of the layers are regularized (by index), which parts of those weights are used (by slice), and the strength of the regularization (the factor).

For example, a bidirectional LSTM layer has six weights by default, and the first three belong to the forward layer. The second weight of the forward layer (index 1, the recurrent kernel) computes the gates from the recurrent connections, and the part that computes the cell state lies in columns units × 2 to units × 3 of that kernel (32 to 48 for units=16). Index 4 is the corresponding recurrent kernel of the backward layer. We can apply the regularization to these kernels:

from tensorflow import keras
from keras_multi_head import MultiHead


model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=5, output_dim=3, name='Embed'))
model.add(MultiHead(
    layer=keras.layers.Bidirectional(keras.layers.LSTM(units=16), name='LSTM'),
    layer_num=5,
    reg_index=[1, 4],
    reg_slice=(slice(None, None), slice(32, 48)),
    reg_factor=0.1,
    name='Multi-BiLSTMs',
))
model.add(keras.layers.Flatten(name='Flatten'))
model.add(keras.layers.Dense(units=2, activation='softmax', name='Dense'))
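model.build(input_shape=(None, 10))  # (batch size, sequence length); 10 is an arbitrary example length
model.summary()
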
  • reg_index: the indices into layer.get_weights(), a single integer or a list of integers.
  • reg_slice: a slice, a tuple of slices, or a list of either. If reg_index contains multiple indices and reg_slice is not a list, the same slice is used for every index. If left as None, the whole array is used.
  • reg_factor: the regularization factor, a float or a list of floats.
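The slice arithmetic in the example can be checked by hand: a Keras LSTM's recurrent kernel has shape (units, 4 * units) with the gate blocks ordered input, forget, cell, output, so with units=16 the cell block is columns 32 to 48. The sketch below extracts that block from a hypothetical recurrent kernel for each of five heads and scores how similar the heads are. The penalty used here, the squared Frobenius norm of G - I where G is the Gram matrix of the normalized, flattened blocks, is one common orthogonality regularizer chosen for illustration; it is not necessarily the exact term keras-multi-head adds.

```python
import numpy as np

rng = np.random.default_rng(0)

units, layer_num, reg_factor = 16, 5, 0.1
gates = ['input', 'forget', 'cell', 'output']  # Keras LSTM gate order

# Column range of the cell-state block inside the (units, 4 * units) kernel.
start = gates.index('cell') * units
stop = start + units
reg_slice = (slice(None, None), slice(start, stop))
print(start, stop)  # 32 48

# Hypothetical forward recurrent kernels (weight index 1), one per head.
kernels = [rng.normal(size=(units, 4 * units)) for _ in range(layer_num)]
blocks = np.stack([k[reg_slice].ravel() for k in kernels])  # (layer_num, units * units)

# Normalize each row so the Gram matrix holds cosine similarities between heads.
blocks /= np.linalg.norm(blocks, axis=1, keepdims=True)
gram = blocks @ blocks.T  # (layer_num, layer_num)

# Illustrative penalty: zero only when the selected blocks are mutually orthogonal.
penalty = reg_factor * np.sum((gram - np.eye(layer_num)) ** 2)
print(blocks.shape, gram.shape)  # (5, 256) (5, 5)
```

Minimizing a term like this pushes the selected blocks of different heads apart, which matches the stated goal of extracting different features.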

Multi-Head Attention

A more specific multi-head layer is also provided, since the general one is harder to use. It uses scaled dot-product attention layers as its sub-layers, and only head_num is required:

from tensorflow import keras
from keras_multi_head import MultiHeadAttention

input_layer = keras.layers.Input(
    shape=(2, 3),
    name='Input',
)
att_layer = MultiHeadAttention(
    head_num=3,
    name='Multi-Head',
)(input_layer)
model = keras.models.Model(inputs=input_layer, outputs=att_layer)
model.compile(
    optimizer='adam',
    loss='mse',
    metrics={},
)
model.summary()
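For intuition, here is the standard multi-head scaled dot-product attention written out in NumPy for a single (2, 3) example with head_num=3. The function `multi_head_attention` and the weight names are hypothetical, biases are omitted, and keras-multi-head may parameterize the layer differently. The sketch shows that with one input the output keeps the input's shape, and that the projected feature size must be divisible by head_num.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(q, k, v, wq, wk, wv, wo, head_num):
    """Textbook multi-head attention for one example (no batch axis).

    q: (len_q, d_q), k: (len_kv, d_k), v: (len_kv, d_v); every projection maps
    to feature_dim columns, which must be divisible by head_num.
    """
    feature_dim = wq.shape[1]
    assert feature_dim % head_num == 0, 'feature_dim must be divisible by head_num'
    head_dim = feature_dim // head_num

    def split(x):
        # (length, feature_dim) -> (head_num, length, head_dim)
        return x.reshape(x.shape[0], head_num, head_dim).transpose(1, 0, 2)

    qh, kh, vh = split(q @ wq), split(k @ wk), split(v @ wv)
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(head_dim)  # (head_num, len_q, len_kv)
    heads = softmax(scores) @ vh                              # (head_num, len_q, head_dim)
    merged = heads.transpose(1, 0, 2).reshape(q.shape[0], feature_dim)
    return merged @ wo                                        # (len_q, feature_dim)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))                      # one sample of Input(shape=(2, 3))
w = [rng.normal(size=(3, 3)) for _ in range(4)]  # Wq, Wk, Wv, Wo
out = multi_head_attention(x, x, x, *w, head_num=3)
print(out.shape)  # (2, 3)
```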

If a single tensor is given as input, the output has the same shape as the input. If a list is given, its three tensors are treated as the query, key, and value, in that order:

from tensorflow import keras
from keras_multi_head import MultiHeadAttention

input_query = keras.layers.Input(
    shape=(2, 3),
    name='Input-Q',
)
input_key = keras.layers.Input(
    shape=(4, 5),
    name='Input-K',
)
input_value = keras.layers.Input(
    shape=(4, 6),
    name='Input-V',
)
att_layer = MultiHeadAttention(
    head_num=3,
    name='Multi-Head',
)([input_query, input_key, input_value])
model = keras.models.Model(inputs=[input_query, input_key, input_value], outputs=att_layer)
model.compile(
    optimizer='adam',
    loss='mse',
    metrics={},
)
model.summary()
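For the query/key/value case, the shapes work out as follows under one plausible reading of the layer, an assumption we have not verified against its source: the query, key, and value are each projected to the value's feature size, 6, which head_num=3 splits into three heads of size 2. In attention generally, the key and value must share a sequence length (4 here), and the result keeps the query's length, giving (2, 6) per sample. The sketch uses np.einsum for the per-head products; all weights are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
head_num = 3
q = rng.normal(size=(2, 3))  # one sample of Input-Q: (2, 3)
k = rng.normal(size=(4, 5))  # one sample of Input-K: (4, 5)
v = rng.normal(size=(4, 6))  # one sample of Input-V: (4, 6)

feature_dim = v.shape[-1]           # assumed projection size: 6
head_dim = feature_dim // head_num  # 2 features per head
assert k.shape[0] == v.shape[0], 'key and value need the same length'

# Random placeholder projections into the shared feature space.
wq, wk, wv = (rng.normal(size=(t.shape[-1], feature_dim)) for t in (q, k, v))
wo = rng.normal(size=(feature_dim, feature_dim))


def heads(x, w):
    # (length, feature_dim) -> (length, head_num, head_dim)
    return (x @ w).reshape(x.shape[0], head_num, head_dim)


qh, kh, vh = heads(q, wq), heads(k, wk), heads(v, wv)
scores = np.einsum('qhd,khd->hqk', qh, kh) / np.sqrt(head_dim)  # (3, 2, 4)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)                  # softmax over keys
context = np.einsum('hqk,khd->qhd', weights, vh)                # (2, 3, 2)
out = context.reshape(q.shape[0], feature_dim) @ wo
print(scores.shape, out.shape)  # (3, 2, 4) (2, 6)
```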
