JAX Implementation of Llama 2

This project is the JAX implementation of Llama 2.

Acknowledgements

This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

Motivation

The objectives of this project are threefold:

  • Implement the Llama 2 model using JAX to enable efficient training and inference on Google Cloud TPU;
  • Develop a high-quality codebase that serves as an exemplary implementation of the Transformer model using JAX;
  • Use the high-quality codebase to facilitate the identification of common errors and inconsistencies across various transformer implementations, thereby providing valuable insights for the NLP community.

Features

Environment Setup

This project requires at least Python 3.11, JAX 0.4.14, PyTorch 2.1.0, and Transformers 4.32.0.dev0.

PyTorch and Transformers are needed for testing purposes. Additionally, the data loader depends on PyTorch DataLoader, while the profiling functionality requires TensorFlow.
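
After installation, you can quickly verify the library versions from Python (a minimal check; each package exposes a standard __version__ attribute):

import jax
import torch
import transformers

print('JAX:', jax.__version__)                    # expect >= 0.4.14
print('PyTorch:', torch.__version__)              # expect >= 2.1.0
print('Transformers:', transformers.__version__)  # expect >= 4.32.0.dev0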

Install Python 3.11

For Ubuntu users, you can follow How to install Python 3.11 on Ubuntu 22.04 to install Python 3.11. The tutorial applies to Ubuntu 20.04 as well.

Create venv

python3.11 -m venv venv
. venv/bin/activate
pip install -U pip
pip install -U wheel

Install the proper version of JAX

You need to follow the installation instructions on JAX's official GitHub page.

TPU:

pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

CUDA 12:

pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

CUDA 11.8:

pip install "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Install the proper version of PyTorch

Typically, you only need to install the CPU version of PyTorch, since most of the computation is performed in JAX. Note, however, that the generation process in the current codebase is not fully optimised yet. One effective way to speed up inference is to convert the model back to the Hugging Face format and run the inference in PyTorch.

To install PyTorch, you can follow the official installation guide.

CPU:

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

CUDA 12:

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121

CUDA 11.8:

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118

Install other dependencies

pip install git+https://github.com/huggingface/transformers.git
pip install -r requirements.txt

Download LLaMA weights

LLaMA 1:

If you cannot obtain the LLaMA 1 weights, you can download them with shawwn/llama-dl.

mkdir ../llama-weights-original && cd ../llama-weights-original
curl -o- https://raw.githubusercontent.com/shawwn/llama-dl/56f50b96072f42fb2520b1ad5a1d6ef30351f23c/llama.sh | bash
python ../llama-2-jax/venv/lib/python3.11/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py --input_dir ../llama-weights-original --model_size 7B --output_dir ../llama-weights/7B

Llama 2:

You can request access to the Llama 2 weights on the official website. After your request is approved, you will automatically be granted access to the Hugging Face Llama 2 models. You can verify that the models are accessible by trying to open the Llama 2 7B model page.
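
Alternatively, you can check access programmatically. A minimal sketch using Transformers' AutoConfig (it assumes you have already logged in with huggingface-cli login; loading the config of the gated meta-llama/Llama-2-7b-hf repository fails if access has not been granted):

from transformers import AutoConfig

# This raises an error if your account cannot access the gated repository
config = AutoConfig.from_pretrained('meta-llama/Llama-2-7b-hf')
print(config)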

Convert parameters

If you need to convert the Llama 2 models, you first need to log in with huggingface-cli login.

python scripts/convert_params_runner.py llama1-7B
python scripts/convert_params_runner.py llama2-7B
python scripts/convert_params_runner.py llama2-70B

Special configuration for TPU Pods

If you are running on TPU Pods or in other multi-host environments, you need to put the IP addresses of the other machines in external-ips.txt (one IP address per line). You should also make sure that one of the hosts can SSH into the other hosts.
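
For example, a hypothetical external-ips.txt for a pod with three worker hosts (the addresses below are placeholders):

10.130.0.2
10.130.0.3
10.130.0.4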

Generation

python generate.py

On TPU pods, the command is:

./startpod python generate.py

Training

I present a simple example of the training pipeline by fine-tuning the model on the GSM dataset.

Download GSM dataset

cd .. && git clone --depth=1 https://github.com/openai/grade-school-math.git

Run the training script

python train.py

On TPU pods, the command is:

./startpod python train.py

Model Configurations

Name         Parameters  vocab_size  n_layers  n_heads_kv  n_rep_kv  d_model  d_ff
LLaMA 1 7B   6738415616  32000       32        32          1         4096     11008
LLaMA 1 13B              32000       40        40          1         5120
LLaMA 1 33B              32000       60        52          1         6656
LLaMA 1 65B              32000       80        64          1         8192
Llama 2 7B   6738415616  32000       32        32          1         4096     11008
Llama 2 13B              32000
Llama 2 70B              32000       80        8           8         8192     28672
n_parameters = 2 * vocab_size * d_model
             + (2 * n_layers + 1) * d_model
             + 2 * n_layers * d_model * n_rep_kv * n_heads_kv * d_k
             + 2 * n_layers * d_model * n_heads_kv * d_k
             + 3 * n_layers * d_model * d_ff
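
Plugging in the LLaMA 1 7B row of the table (with d_k = 128) reproduces the parameter count; a minimal check in Python:

# LLaMA 1 7B configuration from the table above
vocab_size, n_layers, n_heads_kv, n_rep_kv, d_model, d_ff, d_k = \
    32000, 32, 32, 1, 4096, 11008, 128

n_parameters = (
      2 * vocab_size * d_model                              # embedding + lm_head
    + (2 * n_layers + 1) * d_model                          # RMSNorm weights
    + 2 * n_layers * d_model * n_rep_kv * n_heads_kv * d_k  # q_proj + out_proj
    + 2 * n_layers * d_model * n_heads_kv * d_k             # k_proj + v_proj
    + 3 * n_layers * d_model * d_ff                         # gate/up/down projections
)
assert n_parameters == 6738415616  # matches the table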

Model Architecture

LLaMA 1 (7B)

Hugging Face format:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

The format used in this project:

model
  embedding: (32000, 4096)
  decoder: decoder_block
    input_norm: (32, 4096)
    attention
      q_proj: (32, 4096, 1, 32, 128)
      k_proj: (32, 4096, 32, 128)
      v_proj: (32, 4096, 32, 128)
      out_proj: (32, 1, 32, 128, 4096)
    post_attn_norm: (32, 4096)
    gate_proj: (32, 4096, 11008)
    up_proj: (32, 4096, 11008)
    down_proj: (32, 11008, 4096)
  norm: (4096)
lm_head: (4096, 32000)
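
For intuition, the per-layer q_proj in this format is just the Hugging Face (4096, 4096) weight split into heads. A hypothetical sketch of the reshape (the exact transpose convention used by the project's conversion script is an assumption here; the leading 32 in the listing above stacks the decoder layers):

import jax.numpy as jnp

w_hf = jnp.zeros((4096, 4096))            # placeholder for an HF Linear weight, (out_features, in_features)
w_jax = w_hf.T.reshape(4096, 1, 32, 128)  # (d_model, n_rep_kv, n_heads_kv, d_k) for one layer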

Llama 2 (70B)

Hugging Face format:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 8192, padding_idx=0)
    (layers): ModuleList(
      (0-79): 80 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=8192, out_features=8192, bias=False)
          (k_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (v_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (o_proj): Linear(in_features=8192, out_features=8192, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=8192, out_features=28672, bias=False)
          (up_proj): Linear(in_features=8192, out_features=28672, bias=False)
          (down_proj): Linear(in_features=28672, out_features=8192, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=8192, out_features=32000, bias=False)
)

The format used in this project:

model
  embedding: (32000, 8192)
  decoder: decoder_block
    input_norm: (80, 8192)
    attention
      q_proj: (80, 8192, 8, 8, 128)
      k_proj: (80, 8192, 8, 128)
      v_proj: (80, 8192, 8, 128)
      out_proj: (80, 8, 8, 128, 8192)
    post_attn_norm: (80, 8192)
    gate_proj: (80, 8192, 28672)
    up_proj: (80, 8192, 28672)
    down_proj: (80, 28672, 8192)
  norm: (8192)
lm_head: (8192, 32000)

Findings

  • LLaMA utilises rotary positional embeddings.
  • There are no bias terms in the Q, K, V projections or in the FFN linear layers, which matches the original Transformer but differs from BERT and BART.
  • In Llama models, each FFN has 3 linear projections, while in BART there are only 2.
  • There is no dropout in the original LLaMA implementation.
  • Llama 2 70B utilises Grouped-Query Attention (GQA); see the sketch after this list.
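
To make the GQA point concrete, here is a minimal attention sketch in JAX. It is an illustration rather than the project's actual implementation; the shapes loosely follow the parameter format above, with n_rep_kv query heads sharing each KV head:

import jax
import jax.numpy as jnp

def gqa(q, k, v):
    # q: (len_q, n_heads_kv, n_rep_kv, d_k); k, v: (len_kv, n_heads_kv, d_k)
    scores = jnp.einsum('shrd,thd->hrst', q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)  # normalise over the KV sequence
    return jnp.einsum('hrst,thd->shrd', weights, v)

# Llama 2 70B: 8 KV heads, each shared by 8 query heads (64 query heads in total)
q = jnp.zeros((5, 8, 8, 128))
k = v = jnp.zeros((7, 8, 128))
print(gqa(q, k, v).shape)  # (5, 8, 8, 128)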
