• Stars: 641
• Rank: 69,716 (top 2%)
• Language: Python
• License: MIT License
• Created: about 6 years ago
• Updated: over 1 year ago


Keras Self-Attention


[δΈ­ζ–‡|English]

Attention mechanism for processing sequential data that considers the context for each timestamp.

Install

pip install keras-self-attention

Usage

Basic

By default, the attention layer uses additive attention and considers the whole context while calculating the relevance. The following code creates an attention layer that follows the equations sketched below (attention_activation is the activation function applied to the alignment scores e_{t, t'}):
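
For reference, additive self-attention scores each pair of timesteps with a small feed-forward network. A rough sketch of the usual formulation (the notation here is ours, and the library's exact parameterization may differ slightly):

e_{t, t'} = sigma(W_a tanh(W_1 h_t + W_2 h_{t'} + b_h) + b_a)
a_{t, t'} = exp(e_{t, t'}) / sum_{t''} exp(e_{t, t''})
c_t = sum_{t'} a_{t, t'} h_{t'}

where sigma is attention_activation, h_t is the input vector at timestep t, and c_t is the context vector emitted for timestep t.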

from tensorflow import keras
from keras_self_attention import SeqSelfAttention


model = keras.models.Sequential()
# Embed token ids; mask_zero=True treats id 0 as padding so attention ignores it.
model.add(keras.layers.Embedding(input_dim=10000,
                                 output_dim=300,
                                 mask_zero=True))
# return_sequences=True is required: the attention layer needs one vector per timestep.
model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128,
                                                       return_sequences=True)))
model.add(SeqSelfAttention(attention_activation='sigmoid'))
# softmax yields the per-timestep probabilities that categorical_crossentropy expects
model.add(keras.layers.Dense(units=5, activation='softmax'))
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['categorical_accuracy'],
)
model.summary()
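
A quick smoke test with random data (the batch size, sequence length, and class count below are illustrative assumptions, not values from the library):

import numpy as np

# 8 sequences of length 12; token id 0 is reserved for padding/masking
x = np.random.randint(1, 10000, size=(8, 12))
# one label per timestep, one-hot encoded to match categorical_crossentropy
y = keras.utils.to_categorical(np.random.randint(0, 5, size=(8, 12)), num_classes=5)
model.fit(x, y, epochs=1)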

Local Attention

The global context may be too broad for some tasks. The attention_width parameter restricts attention to a local window around each timestep:

from keras_self_attention import SeqSelfAttention

SeqSelfAttention(
    attention_width=15,
    attention_activation='sigmoid',
    name='Attention',
)
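
Intuitively, local attention only scores timesteps inside a window of the given width around the current position; everything outside the window receives zero attention. As a sketch (the exact window boundaries are an assumption here, so check the implementation for the precise bounds):

e_{t, t'} is computed only for |t - t'| < attention_width / 2
a_{t, t'} = 0 for t' outside the window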

Multiplicative Attention

You can use multiplicative attention by setting attention_type:

from tensorflow import keras
from keras_self_attention import SeqSelfAttention

SeqSelfAttention(
    attention_width=15,
    attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL,
    attention_activation=None,
    kernel_regularizer=keras.regularizers.l2(1e-6),
    use_attention_bias=False,
    name='Attention',
)
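
Instead of a feed-forward scoring network, multiplicative (Luong-style) attention scores each pair of timesteps with a single bilinear product, which is cheaper to compute. Roughly (our notation, not the library's):

e_{t, t'} = h_t^T W_a h_{t'}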

Regularizer

To use the regularizer, set attention_regularizer_weight to a positive number:

from tensorflow import keras
from keras_self_attention import SeqSelfAttention

inputs = keras.layers.Input(shape=(None,))
embd = keras.layers.Embedding(input_dim=32,
                              output_dim=16,
                              mask_zero=True)(inputs)
lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=16,
                                                    return_sequences=True))(embd)
att = SeqSelfAttention(attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL,
                       kernel_regularizer=keras.regularizers.l2(1e-4),
                       bias_regularizer=keras.regularizers.l1(1e-4),
                       attention_regularizer_weight=1e-4,
                       name='Attention')(lstm)
# softmax gives per-timestep class probabilities
dense = keras.layers.Dense(units=5, activation='softmax', name='Dense')(att)
model = keras.models.Model(inputs=inputs, outputs=[dense])
model.compile(
    optimizer='adam',
    loss={'Dense': 'sparse_categorical_crossentropy'},
    metrics={'Dense': 'sparse_categorical_accuracy'},  # sparse labels need the sparse metric
)
model.summary(line_length=100)
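
The attention regularizer discourages redundant attention maps. A common form of such a penalty, in the spirit of Lin et al.'s structured self-attention (the exact expression used by the library is an assumption on our part):

loss += attention_regularizer_weight * ||A A^T - I||_F^2

where A is the matrix of attention weights a_{t, t'} and I is the identity matrix, pushing different timesteps to attend to different positions.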

Load the Model

Make sure to add SeqSelfAttention to custom objects:

from tensorflow import keras
from keras_self_attention import SeqSelfAttention

keras.models.load_model(model_path, custom_objects=SeqSelfAttention.get_custom_objects())
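
For example, a full save/load round trip (the 'model.h5' path is an arbitrary choice):

model.save('model.h5')
model = keras.models.load_model('model.h5',
                                custom_objects=SeqSelfAttention.get_custom_objects())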

History Only

Set history_only to True when each timestep may only attend to itself and earlier timesteps (e.g. for autoregressive decoding):

SeqSelfAttention(
    attention_width=3,
    history_only=True,
    name='Attention',
)

Multi-Head

Please refer to keras-multi-head.

More Repositories

1. keras-bert - Implementation of BERT that could load official pre-trained models for feature extraction and prediction (Python, 2,411 stars)
2. toolbox - https://cyberzhg.github.io/toolbox/ Encoding and parsing tools. (JavaScript, 842 stars)
3. CLRS - Some exercises and problems in Introduction to Algorithms 3rd edition. (Jupyter Notebook, 392 stars)
4. keras-transformer - Transformer implemented in Keras (Python, 356 stars)
5. keras-radam - RAdam implemented in Keras & TensorFlow (Python, 326 stars)
6. keras-multi-head - A wrapper layer for stacking layers horizontally (Python, 222 stars)
7. keras-xlnet - Implementation of XLNet that can load pretrained checkpoints (Python, 171 stars)
8. keras-gpt-2 - Load GPT-2 checkpoint and generate texts (Python, 127 stars)
9. torch-multi-head-attention - Multi-head attention in PyTorch (Python, 125 stars)
10. keras-transformer-xl - Transformer-XL with checkpoint loader (Python, 68 stars)
11. keras-pos-embd - Position embedding layers in Keras (Python, 62 stars)
12. keras-gcn - Graph convolutional layers (Python, 61 stars)
13. keras-layer-normalization - Layer normalization implemented in Keras (Python, 60 stars)
14. keras-adabound - AdaBound optimizer in Keras (Python, 57 stars)
15. keras-lookahead - Lookahead mechanism for optimizers in Keras (Python, 50 stars)
16. keras-word-char-embd - Concatenate word and character embeddings in Keras (Python, 46 stars)
17. keras-lr-multiplier - Learning rate multiplier (Python, 46 stars)
18. keras-octave-conv - Octave convolution (Python, 36 stars)
19. keras-gradient-accumulation - Gradient accumulation for Keras (Python, 35 stars)
20. keras-ordered-neurons - Ordered Neurons LSTM (Python, 30 stars)
21. keras-drop-block - DropBlock implemented in Keras (Python, 25 stars)
22. wiki-dump-reader - Extract corpora from Wikipedia dumps (Python, 21 stars)
23. torch-layer-normalization - Layer normalization in PyTorch (Python, 18 stars)
24. keras-adaptive-softmax - Adaptive embedding and softmax (Python, 17 stars)
25. tf-keras-kervolution-2d - Kervolutional neural networks (Python, 16 stars)
26. keras-trans-mask - Remove and restore masks for layers that do not support masking (Python, 16 stars)
27. keras-lamb - Layer-wise Adaptive Moments optimizer for Batch training (Python, 15 stars)
28. torch-position-embedding - Position embedding in PyTorch (Python, 14 stars)
29. keras-losses - Some loss functions in Keras (Python, 10 stars)
30. keras-embed-sim - Calculate similarity with embedding (Python, 10 stars)
31. keras-position-wise-feed-forward - Feed forward layer implemented in Keras (Python, 8 stars)
32. keras-targeted-dropout - Targeted dropout implemented in Keras (Python, 8 stars)
33. EmojiView - 😼 EmojiView for Android. (Java, 8 stars)
34. github-action-python-lint - GitHub action that runs pycodestyle (Dockerfile, 7 stars)
35. LaTeXGitHubMarkdown - Show LaTeX formulas for GitHub Markdown files. (JavaScript, 7 stars)
36. keras-drop-connect - Drop-connect wrapper (Python, 7 stars)
37. torch-gpt-2 - Load GPT-2 checkpoint and generate texts in PyTorch (Python, 6 stars)
38. keras-conv-vis - Convolution visualization (Python, 6 stars)
39. MIXAL - MIX Assembly Language Simulator (C++, 6 stars)
40. Sketch-Based - Some implementations of sketch-based methods; no longer maintained. (MATLAB, 6 stars)
41. MineForces - Codeforces problem filter. (JavaScript, 5 stars)
42. github-action-cpp-lint - GitHub action that runs cpplint (Dockerfile, 5 stars)
43. keras-bi-lm - Train the Bi-LM model and use it as a feature extraction method (Python, 5 stars)
44. mxnet-octave-conv - Octave convolution (Python, 3 stars)
45. torch-same-pad - Paddings used for converting TensorFlow conv/pool layers to PyTorch. (Python, 3 stars)
46. gitbook-plugin-meta - Add meta data to <head> for your gitbook. (HTML, 3 stars)
47. toy-auto-diff - Toy implementation of automatic differentiation (Python, 3 stars)
48. keras-piecewise-pooling - Piecewise pooling layer in Keras (Python, 2 stars)
49. keras-piecewise - A wrapper layer for splitting and accumulating sequential data. (Python, 2 stars)
50. keras-perturbation - A demonstration of perturbation of data (Python, 2 stars)
51. parse-toys - Parsing toys (Python, 2 stars)
52. swift-6502-core - Emulation of the 6502 CPU (Swift, 2 stars)
53. github-action-python-test - GitHub action that runs nose tests (Dockerfile, 2 stars)
54. CrimsonTomato - https://goo.gl/JpF6eP Pomodoro timer, sync to calendar. (Java, 2 stars)
55. keras-succ-reg-wrapper - A wrapper that slows down the updates of trainable weights. (Python, 1 star)
56. torch-embed-sim - Embedding similarity in PyTorch (Python, 1 star)
57. torch-transformer - Transformer in PyTorch (Python, 1 star)
58. UChar - Basic unicode information about a character. (C++, 1 star)
59. CppTesting - Personal C++ testing framework. (C++, 1 star)