• Stars: 356
• Language: Python
• License: MIT License
• Created: about 6 years ago
• Updated: over 1 year ago

Repository Details

Transformer implemented in Keras

Keras Transformer

Implementation of the Transformer for seq2seq (sequence-to-sequence) tasks.

Install

pip install keras-transformer
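
A quick smoke test after installing (an optional check, not part of the original instructions):

from keras_transformer import get_model, decode  # the two entry points used below
print('keras-transformer is available')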

Usage

Train

import numpy as np
from keras_transformer import get_model

# Build a small toy token dictionary
tokens = 'all work and no play makes jack a dull boy'.split(' ')
token_dict = {
    '<PAD>': 0,
    '<START>': 1,
    '<END>': 2,
}
for token in tokens:
    if token not in token_dict:
        token_dict[token] = len(token_dict)

# Generate toy data
encoder_inputs_no_padding = []
encoder_inputs, decoder_inputs, decoder_outputs = [], [], []
for i in range(1, len(tokens) - 1):
    encode_tokens, decode_tokens = tokens[:i], tokens[i:]
    # Pad every sequence to the same fixed length: len(tokens) + 2
    encode_tokens = ['<START>'] + encode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(encode_tokens))
    output_tokens = decode_tokens + ['<END>', '<PAD>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
    decode_tokens = ['<START>'] + decode_tokens + ['<END>'] + ['<PAD>'] * (len(tokens) - len(decode_tokens))
    encode_tokens = list(map(lambda x: token_dict[x], encode_tokens))
    decode_tokens = list(map(lambda x: token_dict[x], decode_tokens))
    # Wrap each target id in a list so the labels have shape (batch, seq_len, 1),
    # which is what sparse_categorical_crossentropy expects here
    output_tokens = list(map(lambda x: [token_dict[x]], output_tokens))
    encoder_inputs_no_padding.append(encode_tokens[:i + 2])
    encoder_inputs.append(encode_tokens)
    decoder_inputs.append(decode_tokens)
    decoder_outputs.append(output_tokens)
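
# For reference (not part of the original example): with the ids assigned
# above (<PAD>=0, <START>=1, <END>=2, 'all'=3, ..., 'boy'=12), the first
# generated sample (i = 1) is:
#   encoder_inputs[0]  -> [1, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]
#   decoder_inputs[0]  -> [1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 2, 0]
#   decoder_outputs[0] -> the decoder input shifted one step left, each id
#                         wrapped in a list for the sparse loss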

# Build the model
model = get_model(
    token_num=len(token_dict),
    embed_dim=30,
    encoder_num=3,
    decoder_num=2,
    head_num=3,
    hidden_dim=120,
    attention_activation='relu',
    feed_forward_activation='relu',
    dropout_rate=0.05,
    embed_weights=np.random.random((13, 30)),  # shape: (token_num, embed_dim)
)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
)
model.summary()

# Train the model (the eight toy samples are repeated to enlarge the training set)
model.fit(
    x=[np.asarray(encoder_inputs * 1000), np.asarray(decoder_inputs * 1000)],
    y=np.asarray(decoder_outputs * 1000),
    epochs=5,
)

Predict

from keras_transformer import decode

decoded = decode(
    model,
    encoder_inputs_no_padding,
    start_token=token_dict['<START>'],
    end_token=token_dict['<END>'],
    pad_token=token_dict['<PAD>'],
    max_len=100,
)
token_dict_rev = {v: k for k, v in token_dict.items()}
for i in range(len(decoded)):
    print(' '.join(map(lambda x: token_dict_rev[x], decoded[i][1:-1])))
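
Under the hood, decode feeds each predicted token back into the decoder input and runs the model again, step by step. A hand-rolled greedy loop along these lines (a sketch inferred from the training setup above, not the library's actual implementation) reproduces the default behavior for a single sample:

import numpy as np

seq = [token_dict['<START>']]
enc = np.asarray([encoder_inputs_no_padding[0]])
for _ in range(100):  # same cap as max_len above
    probs = model.predict([enc, np.asarray([seq])])[0][-1]  # next-token distribution
    seq.append(int(np.argmax(probs)))                       # greedy: take the most probable token
    if seq[-1] == token_dict['<END>']:
        break
print(' '.join(token_dict_rev[x] for x in seq[1:-1]))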

Translation

import numpy as np
from keras_transformer import get_model, decode

source_tokens = [
    'i need more power'.split(' '),
    'eat jujube and pill'.split(' '),
]
target_tokens = [
    list('我要更多的抛瓦'),
    list('吃枣💊'),
]

# Generate dictionaries
def build_token_dict(token_list):
    token_dict = {
        '<PAD>': 0,
        '<START>': 1,
        '<END>': 2,
    }
    for tokens in token_list:
        for token in tokens:
            if token not in token_dict:
                token_dict[token] = len(token_dict)
    return token_dict

source_token_dict = build_token_dict(source_tokens)
target_token_dict = build_token_dict(target_tokens)
target_token_dict_inv = {v: k for k, v in target_token_dict.items()}

# Add special tokens
encode_tokens = [['<START>'] + tokens + ['<END>'] for tokens in source_tokens]
decode_tokens = [['<START>'] + tokens + ['<END>'] for tokens in target_tokens]
output_tokens = [tokens + ['<END>', '<PAD>'] for tokens in target_tokens]

# Padding
source_max_len = max(map(len, encode_tokens))
target_max_len = max(map(len, decode_tokens))

encode_tokens = [tokens + ['<PAD>'] * (source_max_len - len(tokens)) for tokens in encode_tokens]
decode_tokens = [tokens + ['<PAD>'] * (target_max_len - len(tokens)) for tokens in decode_tokens]
output_tokens = [tokens + ['<PAD>'] * (target_max_len - len(tokens)) for tokens in output_tokens]

encode_input = [list(map(lambda x: source_token_dict[x], tokens)) for tokens in encode_tokens]
decode_input = [list(map(lambda x: target_token_dict[x], tokens)) for tokens in decode_tokens]
decode_output = [list(map(lambda x: [target_token_dict[x]], tokens)) for tokens in output_tokens]

# Build & fit model
model = get_model(
    token_num=max(len(source_token_dict), len(target_token_dict)),
    embed_dim=32,
    encoder_num=2,
    decoder_num=2,
    head_num=4,
    hidden_dim=128,
    dropout_rate=0.05,
    use_same_embed=False,  # Use different embeddings for different languages
)
model.compile('adam', 'sparse_categorical_crossentropy')
model.summary()

model.fit(
    x=[np.array(encode_input * 1024), np.array(decode_input * 1024)],
    y=np.array(decode_output * 1024),
    epochs=10,
    batch_size=32,
)

# Predict
decoded = decode(
    model,
    encode_input,
    start_token=target_token_dict['<START>'],
    end_token=target_token_dict['<END>'],
    pad_token=target_token_dict['<PAD>'],
)
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))

Decode

By default, decode selects the token with the highest probability at each step (greedy decoding). You can add randomness by setting top_k and temperature:

decoded = decode(
    model,
    encode_input,
    start_token=target_token_dict['<START>'],
    end_token=target_token_dict['<END>'],
    pad_token=target_token_dict['<PAD>'],
    top_k=10,
    temperature=1.0,
)
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[0][1:-1])))
print(''.join(map(lambda x: target_token_dict_inv[x], decoded[1][1:-1])))
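
Conceptually, top_k keeps only the k most probable tokens at each step, and temperature rescales their distribution before sampling: higher temperatures flatten it, while values below 1.0 sharpen it towards greedy decoding. A minimal numpy sketch of this sampling rule (an illustration of the idea, not the library's internal code):

import numpy as np

def sample_top_k(probs, top_k=10, temperature=1.0):
    # Keep the ids of the k most probable tokens
    indices = np.argsort(probs)[-top_k:]
    # Rescale in log space, then renormalize over the top k
    logits = np.log(probs[indices]) / temperature
    scaled = np.exp(logits - logits.max())
    scaled /= scaled.sum()
    return int(np.random.choice(indices, p=scaled))

With top_k=1, this reduces to the greedy behavior described above regardless of temperature.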

More Repositories

1. keras-bert (Python, 2,411 stars): Implementation of BERT that could load official pre-trained models for feature extraction and prediction
2. toolbox (JavaScript, 842 stars): Encoding and parsing tools. https://cyberzhg.github.io/toolbox/
3. keras-self-attention (Python, 641 stars): Attention mechanism for processing sequential data that considers the context for each timestamp.
4. CLRS (Jupyter Notebook, 392 stars): Some exercises and problems in Introduction to Algorithms, 3rd edition.
5. keras-radam (Python, 326 stars): RAdam implemented in Keras & TensorFlow
6. keras-multi-head (Python, 222 stars): A wrapper layer for stacking layers horizontally
7. keras-xlnet (Python, 171 stars): Implementation of XLNet that can load pretrained checkpoints
8. keras-gpt-2 (Python, 127 stars): Load GPT-2 checkpoint and generate texts
9. torch-multi-head-attention (Python, 125 stars): Multi-head attention in PyTorch
10. keras-transformer-xl (Python, 68 stars): Transformer-XL with checkpoint loader
11. keras-pos-embd (Python, 62 stars): Position embedding layers in Keras
12. keras-gcn (Python, 61 stars): Graph convolutional layers
13. keras-layer-normalization (Python, 60 stars): Layer normalization implemented in Keras
14. keras-adabound (Python, 57 stars): AdaBound optimizer in Keras
15. keras-lookahead (Python, 50 stars): Lookahead mechanism for optimizers in Keras
16. keras-word-char-embd (Python, 46 stars): Concatenate word and character embeddings in Keras
17. keras-lr-multiplier (Python, 46 stars): Learning rate multiplier
18. keras-octave-conv (Python, 36 stars): Octave convolution
19. keras-gradient-accumulation (Python, 35 stars): Gradient accumulation for Keras
20. keras-ordered-neurons (Python, 30 stars): Ordered Neurons LSTM
21. keras-drop-block (Python, 25 stars): DropBlock implemented in Keras
22. wiki-dump-reader (Python, 21 stars): Extract corpora from Wikipedia dumps
23. torch-layer-normalization (Python, 18 stars): Layer normalization in PyTorch
24. keras-adaptive-softmax (Python, 17 stars): Adaptive embedding and softmax
25. tf-keras-kervolution-2d (Python, 16 stars): Kervolutional neural networks
26. keras-trans-mask (Python, 16 stars): Remove and restore masks for layers that do not support masking
27. keras-lamb (Python, 15 stars): Layer-wise Adaptive Moments optimizer for Batch training
28. torch-position-embedding (Python, 14 stars): Position embedding in PyTorch
29. keras-losses (Python, 10 stars): Some loss functions in Keras
30. keras-embed-sim (Python, 10 stars): Calculate similarity with embedding
31. keras-position-wise-feed-forward (Python, 8 stars): Feed forward layer implemented in Keras
32. keras-targeted-dropout (Python, 8 stars): Targeted dropout implemented in Keras
33. EmojiView (Java, 8 stars): 😼 EmojiView for Android.
34. github-action-python-lint (Dockerfile, 7 stars): GitHub action that runs pycodestyle
35. LaTeXGitHubMarkdown (JavaScript, 7 stars): Show LaTeX formulas for GitHub Markdown files.
36. keras-drop-connect (Python, 7 stars): Drop-connect wrapper
37. torch-gpt-2 (Python, 6 stars): Load GPT-2 checkpoint and generate texts in PyTorch
38. keras-conv-vis (Python, 6 stars): Convolution visualization
39. MIXAL (C++, 6 stars): MIX Assembly Language Simulator
40. Sketch-Based (MATLAB, 6 stars): Some implementations of sketch-based methods; no longer maintained.
41. MineForces (JavaScript, 5 stars): Codeforces problem filter.
42. github-action-cpp-lint (Dockerfile, 5 stars): GitHub action that runs cpplint
43. keras-bi-lm (Python, 5 stars): Train the Bi-LM model and use it as a feature extraction method
44. mxnet-octave-conv (Python, 3 stars): Octave convolution
45. torch-same-pad (Python, 3 stars): Paddings used for converting TensorFlow conv/pool layers to PyTorch.
46. gitbook-plugin-meta (HTML, 3 stars): Add meta data to <head> for your gitbook.
47. toy-auto-diff (Python, 3 stars): Toy implementation of automatic differentiation
48. keras-piecewise-pooling (Python, 2 stars): Piecewise pooling layer in Keras
49. keras-piecewise (Python, 2 stars): A wrapper layer for splitting and accumulating sequential data.
50. keras-perturbation (Python, 2 stars): A demonstration of perturbation of data
51. parse-toys (Python, 2 stars): Parsing toys
52. swift-6502-core (Swift, 2 stars): Emulation of the 6502 CPU
53. github-action-python-test (Dockerfile, 2 stars): GitHub action that runs nose tests
54. CrimsonTomato (Java, 2 stars): Pomodoro timer, sync to calendar. https://goo.gl/JpF6eP
55. keras-succ-reg-wrapper (Python, 1 star): A wrapper that slows down the updates of trainable weights.
56. torch-embed-sim (Python, 1 star): Embedding similarity in PyTorch
57. torch-transformer (Python, 1 star): Transformer in PyTorch
58. UChar (C++, 1 star): Basic unicode information about a character.
59. CppTesting (C++, 1 star): Personal C++ testing framework.