Jaxformer
A JAX library for training large language models on TPU-v3/v4 with data and model parallelism based on the pjit() operator.
Citation
Please cite:
@article{Jaxformer,
title={Jaxformer: A minimal library for training LLMs on TPU},
author={Nijkamp, Erik},
journal={arXiv},
year={2022}
}
Acknowledgments: Ben Wang, James Bradbury, Zak Stone, Bo Pang.
Models
CodeGen
350M
gs://sfr-codegen-research/checkpoints/codegen-350M-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-350M-multi/150000
gs://sfr-codegen-research/checkpoints/codegen-350M-mono/150000
2B
gs://sfr-codegen-research/checkpoints/codegen-2B-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-2B-multi/150000
gs://sfr-codegen-research/checkpoints/codegen-2B-mono/100000
6B
gs://sfr-codegen-research/checkpoints/codegen-6B-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-6B-multi/100000
gs://sfr-codegen-research/checkpoints/codegen-6B-mono/140000
Sanity TPU
import jax

# List the accelerator devices visible to this host.
print(jax.devices())
device_count = jax.device_count()
local_device_count = jax.local_device_count()
# One element per local device; psum adds them up across all devices.
xs = jax.numpy.ones(local_device_count)
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
print('global device count:', device_count)
print('local device count:', local_device_count)
print('pmap result:', r)
gcloud compute tpus tpu-vm ssh <username>@<tpu-name> --zone=us-east1-d --internal-ip --worker=all --command="pip install 'jax[tpu]==0.3.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
gcloud compute tpus tpu-vm scp test.py <username>@<tpu-name>:/home/erik.nijkamp/ --zone=us-east1-d --internal-ip --worker=all
gcloud compute tpus tpu-vm ssh <username>@<tpu-name> --zone=us-east1-d --internal-ip --worker=all --command="python3 /home/erik.nijkamp/test.py"
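The sanity script can also be written so that it fails loudly instead of relying on someone reading the printed output. The following is a minimal sketch, not part of Jaxformer: the function name `check_psum` is hypothetical, and the only assumption is that JAX is installed (it also runs on a single CPU device).

```python
import jax
import jax.numpy as jnp


def check_psum():
    """Sum a vector of ones across devices and verify the result."""
    n_local = jax.local_device_count()
    xs = jnp.ones(n_local)  # one element per local device
    # psum adds the per-device values over the mapped axis 'i', so every
    # device should end up holding the total number of devices.
    r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
    expected = float(jax.device_count())
    if not bool(jnp.all(r == expected)):
        raise RuntimeError(
            f'psum returned {r}, expected {expected} on every device')
    return r
```

Calling `check_psum()` at the end of `test.py` makes the remote run exit with an error if the collective does not span all devices.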
Training
Mode 1: CPU local
brew install python@3.9
apt install --yes python3.9 python3.9-venv
git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer
python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
python3 -m jaxformer.train --config config/debug_cpu.json
Mode 2: TPU local
gcloud compute tpus list --zone=europe-west4-a
gcloud compute tpus tpu-vm delete sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --quiet
gcloud compute tpus tpu-vm create sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --accelerator-type=v3-8 --version=v2-alpha
gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --project <project> --worker 0
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>
git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer
./jaxformer/env/env_tpu_v3.sh
source .venv/bin/activate
pip install -r requirements.txt
python3
import jax
jax.devices()
quit()
python3 -m jaxformer.train --config config/debug_tpu_v3_8.json
Mode 3: TPU remote
gcloud beta compute --project=<project> instances create sfr-<username>-cpu-small-us-east1-d-1 --zone=us-east1-d --machine-type=e2-standard-4 --network-tier=PREMIUM --maintenance-policy=MIGRATE --service-account=<account> --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --image=ubuntu-minimal-2004-focal-v20210720 --image-project=ubuntu-os-cloud --boot-disk-size=50GB --boot-disk-type=pd-balanced --boot-disk-device-name=sfr-cpu-small --no-shielded-secure-boot --shielded-vtpm --shielded-integrity-monitoring --reservation-affinity=any
gcloud beta compute ssh sfr-<username>-cpu-small-us-east1-d-1 --project=<project> --zone=us-east1-d
sudo apt update
sudo apt install --yes git screen python3.9 python3.9-venv
screen -S codegen_350M_nl
curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init
ssh-keygen -t rsa -f ~/.ssh/google_compute_engine -N ''
export WANDB_API_KEY=<secret>
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>
git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer
python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
python3 -m jaxformer.train --config config/codegen_350M_nl.json
gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-64-us-east1-d-1 --zone us-east1-d --internal-ip --worker=0
Fine-tuning
TPU fine-tune
gcloud beta compute --project=<project> instances create sfr-<username>-cpu-small-us-east1-d-1 --zone=us-east1-d --machine-type=e2-standard-4 --network-tier=PREMIUM --maintenance-policy=MIGRATE --service-account=<account> --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --image=ubuntu-minimal-2004-focal-v20210720 --image-project=ubuntu-os-cloud --boot-disk-size=50GB --boot-disk-type=pd-balanced --boot-disk-device-name=sfr-cpu-small --no-shielded-secure-boot --shielded-vtpm --shielded-integrity-monitoring --reservation-affinity=any
gcloud beta compute ssh sfr-<username>-cpu-small-us-east1-d-1 --project=<project> --zone=us-east1-d
sudo apt update
sudo apt install --yes git screen python3.9 python3.9-venv
screen -S codegen_350M_mono
curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init
ssh-keygen -t rsa -f ~/.ssh/google_compute_engine -N ''
export WANDB_API_KEY=<secret>
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>
git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer
python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
python3 -m jaxformer.train --config config/codegen_350M_multi.json
gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-64-us-east1-d-1 --zone us-east1-d --internal-ip --worker=0
A100 fine-tune
apt install python3.8 python3.8-venv python3.8-dev
curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>
python3.8 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
pip install transformers==4.21.1 datasets==1.16.1 deepspeed==0.7.0 tensorflow-cpu==2.5.0
pip install -e .
deepspeed --num_gpus=1 jaxformer/hf/train.py
Conversion
python3 -m jaxformer.hf.convert --config=config/codegen_1B_mono.json --step=150000
Features
v1
- Data
  - Stateful resumable data loading based on tfrecords without skip()
- TPU
  - Provisioning of TPU clusters and virtual environment
  - Code paths for both TPU-v3 and TPU-v4
  - ...
- Parallelism
  - Push-based single-port TCP/IP protocol for orchestration and data-parallelism
  - Megatron-style pjit()-based sharding pattern across TPU boards for LLMs with up to 6B parameters
  - xmap() emulation mode through pjit() sharding
  - Distributed checkpointing with full state recovery
  - scan() for time-efficient jit'ing
  - ...
- Debugging
  - Abstraction layer for local/remote workers
  - Local CPU debugging with TPU emulation
  - Mock data iterators
  - ...
- Training
  - Fully resumable state and checkpointing
  - WandB integration
  - ...
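
The scan() item in the list above refers to a common JAX pattern: a Python loop over layers is unrolled during tracing, so the compiled program grows with model depth, whereas `jax.lax.scan` traces the layer body once and applies it to each slice of stacked parameters. The following is a minimal sketch of that pattern, not Jaxformer's actual model code; the toy layer, the shapes, and the names `layer`, `forward_loop`, and `forward_scan` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    """Toy residual layer standing in for a transformer block."""
    return x + jnp.tanh(x @ w)


@jax.jit
def forward_loop(x, ws):
    # Python loop: the layer is traced once per layer, so the traced
    # program grows with the number of layers.
    for i in range(ws.shape[0]):
        x = layer(x, ws[i])
    return x


@jax.jit
def forward_scan(x, ws):
    # lax.scan: the body is traced once and reused for every slice of
    # the stacked weights `ws` (shape [n_layers, d, d]).
    def body(carry, w):
        return layer(carry, w), None

    out, _ = jax.lax.scan(body, x, ws)
    return out


n_layers, d = 8, 16
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, d))
ws = 0.1 * jax.random.normal(key_w, (n_layers, d, d))

# Both variants compute the same function; only the tracing cost differs.
assert jnp.allclose(forward_loop(x, ws), forward_scan(x, ws), atol=1e-4)
```

The design trade-off is that all layers must share one parameter structure so they can be stacked along a leading axis.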