• This repository was archived on 24 Feb 2024
• Stars: 460
• Rank: 95,202 (Top 2%)
• Language: Python
• License: MIT License
• Created over 2 years ago
• Updated 9 months ago

Repository Details

Open-source pre-training implementation of Google's LaMDA in PyTorch. Adding RLHF similar to ChatGPT.

Update (WIP)

I will be adding significant updates to this repository to include:

  • RLHF (Reinforcement Learning from Human Feedback)
  • Use decoder weights from the HuggingFace T5 checkpoint (big thanks to Jason Phang)
  • Add LoRA (see the sketch after this list)
  • Integration with web search APIs
  • External database integration
  • Chain-of-thought prompting
  • Integration with a calculator API
  • Remove ColossalAI for now and use pure PyTorch
  • Fix the dataloader and use OpenWebText instead
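
As a reference for the LoRA item above, here is a minimal sketch of a LoRA-augmented linear layer in plain PyTorch; the rank, scaling factor, and class name are illustrative and do not reflect this repository's planned implementation.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen linear layer plus a trainable low-rank update: W x + scale * (B A) x."""
    def __init__(self, in_features, out_features, rank = 8, alpha = 16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        for p in self.base.parameters():
            p.requires_grad_(False)                                   # freeze the pretrained weights
        self.lora_a = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))   # zero init: no change at start
        self.scale = alpha / rank

    def forward(self, x):
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scale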

LaMDA-pytorch

Open-source pre-training implementation of Google's LaMDA research paper in PyTorch. The totally not sentient AI. This repository covers the 2B-parameter version of the pre-training architecture, as that is likely what most people can afford to train. You can review Google's 2022 blog post detailing LaMDA here, and their earlier 2021 blog post on the model here.

Acknowledgement:

I have been greatly inspired by the work of Dr. Phil 'Lucid' Wang. Please check out his open-source implementations of multiple different transformer architectures and support his work.

Developer Updates

Developer updates can be found on:

Basic Usage - Pre-training

import torch

# LaMDA and AutoregressiveWrapper are defined in this repository; the exact import
# path below is an assumption and may differ depending on the package layout
from lamda_pytorch import LaMDA, AutoregressiveWrapper

lamda_base = LaMDA(
    num_tokens = 20000, # vocabulary size
    dim = 512,          # model / embedding dimension
    dim_head = 64,      # dimension per attention head
    depth = 12,         # number of decoder layers
    heads = 8           # attention heads per layer
)

# wrap the base model for autoregressive (next-token) training and generation
lamda = AutoregressiveWrapper(lamda_base, max_seq_len = 512)

tokens = torch.randint(0, 20000, (1, 512)) # mock token data

logits = lamda(tokens)

print(logits)
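
For context, here is a minimal training-step sketch in plain PyTorch. It assumes the base model maps token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, num_tokens); the shift-by-one target construction, optimizer, and learning rate are illustrative and not taken from this repository's training script.

import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(lamda_base.parameters(), lr = 3e-4)

# predict token t+1 from tokens up to t (standard next-token objective)
inputs, targets = tokens[:, :-1], tokens[:, 1:]

logits = lamda_base(inputs)                # (batch, seq_len - 1, num_tokens) under the assumption above
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),   # flatten batch and time dimensions
    targets.reshape(-1)
)

loss.backward()
optimizer.step()
optimizer.zero_grad()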

Notes on training at scale:

About LaMDA:

  • T5 relative positional bias in attention
  • Gated GELU (GEGLU) activation in the feed-forward layer (see the sketch after this list)
  • GPT-like decoder-only architecture
  • Autoregressive generation with top-k sampling
  • SentencePiece byte-pair encoding (BPE) tokenizer
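
To make the gated-GELU and top-k bullets concrete, below is a minimal sketch of a GEGLU feed-forward block and a top-k sampling helper in plain PyTorch; the class and function names, expansion factor, and default k are illustrative and are not taken from this repository's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLUFeedForward(nn.Module):
    """Feed-forward block with a gated GELU (GEGLU) activation."""
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.proj_in = nn.Linear(dim, dim * mult * 2)  # projects to a value half and a gate half
        self.proj_out = nn.Linear(dim * mult, dim)

    def forward(self, x):
        value, gate = self.proj_in(x).chunk(2, dim = -1)
        return self.proj_out(value * F.gelu(gate))

def top_k_sample(logits, k = 50, temperature = 1.0):
    """Sample one token id per row from the k highest-probability logits."""
    top_values, top_indices = logits.topk(k, dim = -1)
    probs = F.softmax(top_values / temperature, dim = -1)
    choice = torch.multinomial(probs, num_samples = 1)
    return top_indices.gather(-1, choice)

# e.g. next_token = top_k_sample(logits[:, -1, :]) for a batch of sequences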

TODO:

  • Finish building pre-training model architecture
  • Add pre-training script
  • Integrate HuggingFace datasets
  • Use The Pile from EleutherAI
  • Build the GODEL dataset and upload to HuggingFace datasets
  • Implement GPT-2 tokenizer
  • Add SentencePiece tokenizer training script and integration (see the sketch after this list)
  • Add detailed documentation
  • Add logging with Weights & Biases
  • Add scaling with ColossalAI
  • Add finetuning script
  • Add pip installer with PyPI
  • Implement a JAX / Flax version as well
  • Add inference only if someone wants to open-source LaMDA model weights
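
For the SentencePiece item above, here is a minimal sketch of training a BPE tokenizer with the sentencepiece library; the corpus path, model prefix, and vocabulary size are placeholders rather than values used by this repository.

import sentencepiece as spm

# train a byte-pair-encoding model from a plain-text corpus (one sentence per line)
spm.SentencePieceTrainer.train(
    input = 'corpus.txt',          # placeholder path to training text
    model_prefix = 'lamda_bpe',    # writes lamda_bpe.model and lamda_bpe.vocab
    vocab_size = 20000,            # matches num_tokens in the usage example above
    model_type = 'bpe'
)

tokenizer = spm.SentencePieceProcessor(model_file = 'lamda_bpe.model')
print(tokenizer.encode('LaMDA is a decoder-only language model.'))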

Author

  • Enrico Shippole

Citations

@article{DBLP:journals/corr/abs-2201-08239,
  author    = {Romal Thoppilan and
               Daniel De Freitas and
               Jamie Hall and
               Noam Shazeer and
               Apoorv Kulshreshtha and
               Heng{-}Tze Cheng and
               Alicia Jin and
               Taylor Bos and
               Leslie Baker and
               Yu Du and
               YaGuang Li and
               Hongrae Lee and
               Huaixiu Steven Zheng and
               Amin Ghafouri and
               Marcelo Menegali and
               Yanping Huang and
               Maxim Krikun and
               Dmitry Lepikhin and
               James Qin and
               Dehao Chen and
               Yuanzhong Xu and
               Zhifeng Chen and
               Adam Roberts and
               Maarten Bosma and
               Yanqi Zhou and
               Chung{-}Ching Chang and
               Igor Krivokon and
               Will Rusch and
               Marc Pickett and
               Kathleen S. Meier{-}Hellstern and
               Meredith Ringel Morris and
               Tulsee Doshi and
               Renelito Delos Santos and
               Toju Duke and
               Johnny Soraker and
               Ben Zevenbergen and
               Vinodkumar Prabhakaran and
               Mark Diaz and
               Ben Hutchinson and
               Kristen Olson and
               Alejandra Molina and
               Erin Hoffman{-}John and
               Josh Lee and
               Lora Aroyo and
               Ravi Rajakumar and
               Alena Butryna and
               Matthew Lamm and
               Viktoriya Kuzmina and
               Joe Fenton and
               Aaron Cohen and
               Rachel Bernstein and
               Ray Kurzweil and
               Blaise Aguera{-}Arcas and
               Claire Cui and
               Marian Croak and
               Ed H. Chi and
               Quoc Le},
  title     = {LaMDA: Language Models for Dialog Applications},
  journal   = {CoRR},
  volume    = {abs/2201.08239},
  year      = {2022},
  url       = {https://arxiv.org/abs/2201.08239},
  eprinttype = {arXiv},
  eprint    = {2201.08239},
  timestamp = {Fri, 22 Apr 2022 16:06:31 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2201-08239.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
@misc{https://doi.org/10.48550/arxiv.1706.03762,
  doi       = {10.48550/ARXIV.1706.03762},
  url       = {https://arxiv.org/abs/1706.03762},
  author    = {Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N. and Kaiser, Lukasz and Polosukhin, Illia},
  keywords  = {Computation and Language (cs.CL), Machine Learning (cs.LG), FOS: Computer and information sciences},
  title     = {Attention Is All You Need},
  publisher = {arXiv},
  year      = {2017},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
@misc{https://doi.org/10.48550/arxiv.1910.10683,
  doi       = {10.48550/ARXIV.1910.10683},
  url       = {https://arxiv.org/abs/1910.10683},
  author    = {Raffel, Colin and Shazeer, Noam and Roberts, Adam and Lee, Katherine and Narang, Sharan and Matena, Michael and Zhou, Yanqi and Li, Wei and Liu, Peter J.},
  keywords  = {Machine Learning (cs.LG), Computation and Language (cs.CL), Machine Learning (stat.ML), FOS: Computer and information sciences},
  title     = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
  publisher = {arXiv},
  year      = {2019},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
@misc{https://doi.org/10.48550/arxiv.2002.05202,
  doi       = {10.48550/ARXIV.2002.05202},
  url       = {https://arxiv.org/abs/2002.05202},
  author    = {Shazeer, Noam},
  keywords  = {Machine Learning (cs.LG), Neural and Evolutionary Computing (cs.NE), Machine Learning (stat.ML), FOS: Computer and information sciences},
  title     = {GLU Variants Improve Transformer},
  publisher = {arXiv},
  year      = {2020},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
@article{DBLP:journals/corr/abs-2101-00027,
  author    = {Leo Gao and
               Stella Biderman and
               Sid Black and
               Laurence Golding and
               Travis Hoppe and
               Charles Foster and
               Jason Phang and
               Horace He and
               Anish Thite and
               Noa Nabeshima and
               Shawn Presser and
               Connor Leahy},
  title     = {The Pile: An 800GB Dataset of Diverse Text for Language Modeling},
  journal   = {CoRR},
  volume    = {abs/2101.00027},
  year      = {2021},
  url       = {https://arxiv.org/abs/2101.00027},
  eprinttype = {arXiv},
  eprint    = {2101.00027},
  timestamp = {Thu, 14 Oct 2021 09:16:12 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2101-00027.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{DBLP:journals/corr/abs-1808-06226,
  author    = {Taku Kudo and
               John Richardson},
  title     = {SentencePiece: {A} simple and language independent subword tokenizer
               and detokenizer for Neural Text Processing},
  journal   = {CoRR},
  volume    = {abs/1808.06226},
  year      = {2018},
  url       = {http://arxiv.org/abs/1808.06226},
  eprinttype = {arXiv},
  eprint    = {1808.06226},
  timestamp = {Sun, 02 Sep 2018 15:01:56 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1808-06226.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
@inproceedings{sennrich-etal-2016-neural,
    title = "Neural Machine Translation of Rare Words with Subword Units",
    author = "Sennrich, Rico  and
      Haddow, Barry  and
      Birch, Alexandra",
    booktitle = "Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = aug,
    year = "2016",
    address = "Berlin, Germany",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/P16-1162",
    doi = "10.18653/v1/P16-1162",
    pages = "1715--1725",
}

More Repositories

1. PaLM: An open-source implementation of Google's PaLM models (Python, 805 stars)
2. toolformer (Python, 336 stars)
3. t5-pytorch: Implementation of Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer in PyTorch (Python, 46 stars)
4. flash-gpt: Add Flash-Attention to Huggingface Models (Python, 33 stars)
5. vit-flax: Implementation of numerous Vision Transformers in Google's JAX and Flax (Python, 19 stars)
6. PaLM-flax: Implementation of the SOTA Transformer architecture from PaLM - Scaling Language Modeling with Pathways in JAX/Flax (Python, 14 stars)
7. PaLM-rlhf-jax: Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM. Built in collaboration with Lucidrains (Python, 9 stars)
8. ViT-Patch-Merger (Python, 7 stars)
9. SemDeDup: An unofficial implementation of SemDeDup: Data-efficient learning at web-scale through semantic deduplication (Python, 6 stars)
10. Huggingface-deduplicate (Python, 5 stars)
11. hf_fsdp: A very basic fsdp wrapper for HF (Python, 5 stars)
12. Simple-ViT-flax (Python, 4 stars)
13. flan-Llama: Code for training Llama on the Flan Collection (Python, 4 stars)
14. dsp-langchain (Python, 4 stars)
15. Sparrow-rlhf-pytorch: An open-source implementation of DeepMind's Sparrow with RLHF (Python, 4 stars)
16. Adaptive-Token-Sampling-Flax (Python, 4 stars)
17. Llama-rlhf-pytorch (3 stars)
18. Perpetrator: An API for Redteaming large language models (3 stars)
19. CaiT-Flax (Python, 3 stars)
20. oig-Llama: Open Instruct fine-tuned Llama (3 stars)
21. jrlx (Python, 2 stars)
22. Twins-SVT-Flax: An open-source implementation of the Twins: Revisiting the Design of Spatial Attention in Vision Transformers research paper in Google's JAX and Flax (Python, 2 stars)
23. t8t-pytorch: The official repository for Towards 8-bit Transformers. Stable training of LLMs at mixed 8-bit precision (2 stars)
24. WebGPT: Combining LangChain with LLMs to build WebGPT (2 stars)
25. BloomCoder (Python, 2 stars)
26. podcasts-are-all-you-need (Python, 2 stars)
27. ViT (Python, 1 star)
28. Crossformer-flax (Python, 1 star)
29. DeepViT-flax: Implementation of Deep Vision Transformer in Flax (Python, 1 star)
30. ViT-haiku (Python, 1 star)
31. LeViT-flax (Python, 1 star)
32. OPTCode (Python, 1 star)
33. bpt-pytorch (Python, 1 star)
34. deep-reinforcement-learning-go: Creation and implementation of an Alpha Go engine from scratch (Jupyter Notebook, 1 star)
35. Everything-Machine-Learning (1 star)
36. Token-to-Token-ViT-flax (Python, 1 star)
37. ViT-Small-Datasets-flax (Python, 1 star)
38. job-posting-analysis: Data Science job posting analysis and resume comparison using NLP and Scikit Learn (Jupyter Notebook, 1 star)
39. charcuterie: A library for sampling Huggingface datasets (1 star)
40. tokenize_hf (Python, 1 star)
41. annotated-transformer (Python, 1 star)
42. Orchestrator: An API for different data cleaning tools (Python, 1 star)
43. conceptofmind: Config files for my GitHub profile (1 star)
44. unsloth-bert (Python, 1 star)