RetNet
A huggingface transformers-compatible implementation of Retention Networks (https://arxiv.org/pdf/2307.08621.pdf). The implementation is on par with the official implementation in the torchscale repo.
Supports three implementations: parallel, recurrent, and chunkwise.
Check play.ipynb for minimal tests of the parallel, recurrent, and chunkwise forward passes.
- For now, chunkwise produces slightly different results from parallel and recurrent.
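The parallel and recurrent forms can agree because retention with a scalar decay gamma can be written either as a masked matrix product over the whole sequence or as a running state update. The numpy sketch below checks this for a single head. It leaves out xpos, normalization, gating, and multi-scale decay, and the names `decay_mask`, `parallel_retention`, and `recurrent_retention` are hypothetical illustrations, not functions from this repo.

```python
import numpy as np

def decay_mask(length, gamma):
    """Causal decay mask: D[n, m] = gamma ** (n - m) if n >= m, else 0."""
    n = np.arange(length)[:, None]
    m = np.arange(length)[None, :]
    return np.where(n >= m, gamma ** np.maximum(n - m, 0), 0.0)

def parallel_retention(q, k, v, gamma):
    """Training-style form (hypothetical helper): (Q K^T * D) V in one shot."""
    return (q @ k.T * decay_mask(len(q), gamma)) @ v

def recurrent_retention(q, k, v, gamma):
    """Inference-style form (hypothetical helper):
    S_n = gamma * S_{n-1} + k_n^T v_n,  o_n = q_n S_n."""
    state = np.zeros((k.shape[1], v.shape[1]))
    outputs = []
    for q_n, k_n, v_n in zip(q, k, v):
        state = gamma * state + np.outer(k_n, v_n)
        outputs.append(q_n @ state)
    return np.stack(outputs)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
out_par = parallel_retention(q, k, v, gamma=0.9)
out_rec = recurrent_retention(q, k, v, gamma=0.9)
```

The chunkwise form combines the two: parallel within each chunk, recurrent across chunks.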
Getting Started
This implementation uses PyTorch and huggingface transformers. We also need timm, which torchscale uses for drop path.
pip install torch transformers timm
# pip install apex (optional)
# pip install pytest (to run tests/)
# pip install fire (to run convert_weights.py)
You may want to use conda.
Quick Examples
Take a look at play.ipynb.
import torch
from retnet.modeling_retnet import RetNetModel
from retnet.configuration_retnet import RetNetConfig
config = RetNetConfig(decoder_layers=8,
                      decoder_embed_dim=512,
                      decoder_value_embed_dim=1024,
                      decoder_retention_heads=4,
                      decoder_ffn_embed_dim=1024)
model = RetNetModel(config)
input_ids = torch.LongTensor([[1,2,3,4,5,6,7,8]])
parallel_outputs = model(input_ids, forward_impl='parallel', use_cache=True)
parallel_state = parallel_outputs.last_hidden_state
parallel_cache = parallel_outputs.past_key_values
past_kv = None
rnn_state = []
for i in range(input_ids.shape[1]):
    rnn_out = model(input_ids[:, :i+1], forward_impl='recurrent', past_key_values=past_kv, use_cache=True)
    rnn_state.append(rnn_out.last_hidden_state)
    past_kv = rnn_out.past_key_values
rnn_state = torch.cat(rnn_state, dim=1)
rnn_cache = rnn_out.past_key_values
chunk_outputs = model(input_ids, forward_impl='chunkwise', use_cache=True, recurrent_chunk_size=4)
chunk_state = chunk_outputs.last_hidden_state
chunk_cache = chunk_outputs.past_key_values
Language Generation
import torch
from retnet.modeling_retnet import RetNetForCausalLM
from retnet.configuration_retnet import load_config_from_json
from transformers import AutoTokenizer
config = load_config_from_json('configs/retnet-base/config.json')
model = RetNetForCausalLM(config)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = 4096
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer("Retention refers to", return_tensors='pt')
# our custom generate function: with parallel_compute_prompt=True, the prompt
# is processed in one parallel forward before recurrent decoding
generated = model.custom_generate(**inputs, parallel_compute_prompt=True, max_new_tokens=20)
# huggingface's generate. Both should be equivalent
generated = model.generate(**inputs, max_new_tokens=20)
tokenizer.batch_decode(generated)
# NOTE: this should be gibberish, since the model is not trained.
parallel_compute_prompt (default: True): Because the parallel forward can also compute past_kv, we can run a single parallel forward over the prompt first and then feed its past_kv into the recurrent forward. On GPUs with enough memory, this saves forward passes (one for the whole prompt instead of one per prompt token).
Huggingface Integration
The model now supports full huggingface integration (except for anything I haven't noticed :)).
It can be trained with the huggingface Trainer, saved and loaded with save_pretrained and from_pretrained, and used for generation with .generate.
Minimal Training Example
You can train RetNet with the huggingface Trainer API. Refer to train.py.
export CUDA_VISIBLE_DEVICES=0
python train.py \
--model_size 300m \
--output_dir checkpoints \
--do_train --do_eval \
--prediction_loss_only \
--remove_unused_columns False \
--learning_rate 6e-4 \
--weight_decay 0.01 \
--max_steps 20000 \
--logging_steps 10 \
--eval_steps 1000 \
--save_steps 1000 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16
Some Useful Notes
xpos note
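The authors describe the position-dependent factors applied to queries and keys as $Q_n (\gamma e^{i\theta})^{n}$ and $K_m (\gamma e^{i\theta})^{-m}$, which rotate $Q_n$ and $K_m$ by $e^{in\theta}$ and $e^{-im\theta}$.
Since xpos (which builds on RoPE) performs exactly this kind of rotation, this is in fact xpos.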
I used the xpos implementation found in the torchscale repo with one small change: instead of a negative min_pos, I used min_pos=0 (lines 53, 54), so that it is recurrence-friendly: the scale applied at a given position no longer depends on the total sequence length.
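To illustrate why min_pos=0 matters, here is a minimal numpy sketch of an xpos-style encoding (a RoPE rotation plus a per-pair exponential scale). It is not the torchscale code: `xpos_sketch` and `centered_min_pos` are hypothetical names, and the frequency formula, the `zeta` formula, and `scale_base=512` are assumptions modeled on my reading of torchscale's XPOS. It shows that query-key scores depend only on the relative distance, and that with min_pos=0 the encoding of a prefix stays the same when the sequence grows, whereas a length-dependent min_pos changes it.

```python
import numpy as np

def xpos_sketch(x, min_pos=0, downscale=False, scale_base=512):
    """Hypothetical xpos-style encoding for x of shape (length, head_dim), head_dim even.

    Pair j at position n (0-based) is rotated by n * theta_j and multiplied by
    zeta_j ** ((n + min_pos) / scale_base), or by the inverse if downscale=True
    (used for keys, so that q_n . k_m depends only on n - m).
    """
    length, head_dim = x.shape
    half = head_dim // 2
    j = np.arange(half)
    theta = 1.0 / (10000 ** (j / half))                 # assumed rotation frequencies
    zeta = (2 * j + 0.4 * head_dim) / (1.4 * head_dim)  # assumed per-pair scale base in (0, 1)
    pos = np.arange(length)[:, None]                    # rotation positions always start at 0
    scale = zeta ** ((pos + min_pos) / scale_base)      # shape (length, half)
    if downscale:
        scale = 1.0 / scale
    cos, sin = np.cos(pos * theta), np.sin(pos * theta)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x, dtype=float)
    out[:, 0::2] = (x1 * cos - x2 * sin) * scale
    out[:, 1::2] = (x1 * sin + x2 * cos) * scale
    return out

def centered_min_pos(length):
    # hypothetical helper mimicking a length-dependent offset like -(length + offset) // 2 with offset=0
    return -length // 2

rng = np.random.default_rng(0)
q_vec, k_vec = rng.standard_normal(8), rng.standard_normal(8)
L = 6
q = xpos_sketch(np.tile(q_vec, (L, 1)))
k = xpos_sketch(np.tile(k_vec, (L, 1)), downscale=True)
scores = q @ k.T  # scores[n, m]: same q/k vectors placed at positions n and m
```

With min_pos=0, position n gets the same scale no matter how long the sequence is, so states cached during recurrent decoding stay consistent with later steps.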
Decay Note
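Equation 7 omits an important detail: when the cross-chunk state is updated, an extra decay should be applied to $K_{[i]}^\top V_{[i]}$, so that the $j$-th token (0-indexed) of a chunk of size $B$ is weighted by $\gamma^{B-1-j}$, i.e., by the last row of the inner-chunk decay mask $D$:
$$R_i = K_{[i]}^\top \left(V_{[i]} \odot \zeta\right) + \gamma^{B} R_{i-1}, \qquad \zeta_j = \gamma^{B-1-j}$$
where $\zeta$ is broadcast over the columns of $V_{[i]}$.
This is implemented in the chunkwise_retention function, where it is named intra_decay.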
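Here is a single-head numpy sketch of chunkwise retention with this extra decay. The function `chunkwise_retention`, its `apply_intra_decay` flag, and the simplifications (scalar gamma, no xpos, no normalization) are hypothetical illustrations; only the name intra_decay is borrowed from this repo. It compares the result against the parallel form and shows that dropping the extra decay breaks the match.

```python
import numpy as np

def decay_mask(length, gamma):
    """Causal decay mask: D[n, m] = gamma ** (n - m) if n >= m, else 0."""
    n = np.arange(length)[:, None]
    m = np.arange(length)[None, :]
    return np.where(n >= m, gamma ** np.maximum(n - m, 0), 0.0)

def parallel_retention(q, k, v, gamma):
    return (q @ k.T * decay_mask(len(q), gamma)) @ v

def chunkwise_retention(q, k, v, gamma, chunk_size, apply_intra_decay=True):
    """Hypothetical chunkwise retention: parallel inside chunks, recurrent across them."""
    state = np.zeros((k.shape[1], v.shape[1]))  # R_{i-1}: summary of all earlier chunks
    outputs = []
    for start in range(0, len(q), chunk_size):
        qc, kc, vc = (t[start:start + chunk_size] for t in (q, k, v))
        b = len(qc)                                # the last chunk may be shorter
        mask = decay_mask(b, gamma)
        inner = (qc @ kc.T * mask) @ vc            # within the chunk: parallel form
        cross_decay = gamma ** (np.arange(b) + 1)  # decay from R_{i-1} to row j
        cross = (qc @ state) * cross_decay[:, None]
        outputs.append(inner + cross)
        # intra_decay[j] = gamma ** (b - 1 - j), i.e. the last row of the mask
        intra_decay = mask[-1] if apply_intra_decay else np.ones(b)
        state = gamma ** b * state + kc.T @ (vc * intra_decay[:, None])
    return np.concatenate(outputs)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 4)) for _ in range(3))
reference = parallel_retention(q, k, v, gamma=0.9)
chunked = chunkwise_retention(q, k, v, gamma=0.9, chunk_size=4)
```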
The same idea can be applied to parallel_retention to obtain the correct past_kv, which can then be fed into recurrent or chunkwise retention for the following tokens.
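A minimal numpy sketch of that idea, assuming a single head with scalar gamma: weighting each prompt token by the last row of the decay mask turns a parallel pass into the state a recurrent step expects. `parallel_with_state` and `recurrent_step` are hypothetical names, not this repo's API. The sketch runs a prompt in parallel, continues token by token, and compares against a fully recurrent run.

```python
import numpy as np

def decay_mask(length, gamma):
    """Causal decay mask: D[n, m] = gamma ** (n - m) if n >= m, else 0."""
    n = np.arange(length)[:, None]
    m = np.arange(length)[None, :]
    return np.where(n >= m, gamma ** np.maximum(n - m, 0), 0.0)

def parallel_with_state(q, k, v, gamma):
    """Hypothetical helper: parallel retention over a prompt plus the resulting state."""
    mask = decay_mask(len(q), gamma)
    outputs = (q @ k.T * mask) @ v
    # state = sum_m gamma ** (P - 1 - m) k_m^T v_m, using the last row of the mask
    state = k.T @ (v * mask[-1][:, None])
    return outputs, state

def recurrent_step(q_n, k_n, v_n, state, gamma):
    """Hypothetical helper: one recurrent retention step."""
    state = gamma * state + np.outer(k_n, v_n)
    return q_n @ state, state

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((9, 4)) for _ in range(3))
gamma, prompt_len = 0.9, 5

# Prompt in one parallel pass, then the remaining tokens recurrently.
prompt_out, state = parallel_with_state(q[:prompt_len], k[:prompt_len], v[:prompt_len], gamma)
continued = [prompt_out]
for n in range(prompt_len, len(q)):
    o_n, state = recurrent_step(q[n], k[n], v[n], state, gamma)
    continued.append(o_n[None, :])
continued = np.concatenate(continued)

# Reference: every token processed recurrently from an empty state.
ref_state = np.zeros((4, 4))
reference = []
for n in range(len(q)):
    o_n, ref_state = recurrent_step(q[n], k[n], v[n], ref_state, gamma)
    reference.append(o_n)
reference = np.stack(reference)
```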
Configs
The configs/ folder includes example configurations for different model sizes, as listed in the paper and found in the torchscale repo. For simplicity, I used the GPT2 tokenizer, so the model has a default vocab size of 50257 (this may change when Microsoft releases the official weights).