• Stars: 2,656
• Rank: 17,214 (Top 0.4%)
• Language: Python
• License: Apache License 2.0
• Created: about 3 years ago
• Updated: about 1 month ago

Repository Details

T5X

Go to T5X ReadTheDocs Documentation Page.

T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.

It is essentially a new and improved implementation of the T5 codebase (based on Mesh TensorFlow) in JAX and Flax. To learn more, see the T5X Paper.

Below is a quick start guide for training models with TPUs on Google Cloud. For additional tutorials and background, see the complete documentation.

Quickstart (Recommended)

T5X can be run with XManager on Vertex AI. Vertex AI is a training platform that creates TPU instances, runs your code on them, and shuts the TPUs down when the jobs terminate. This is significantly easier than managing GCE VMs and TPU VM instances yourself.

  1. Follow the pre-requisites and directions to install XManager.

  2. Request TPU quota as required. GCP projects come with 8 cores by default, which is enough to run one training experiment on a single TPU host. If you want to run multi-host training or run multiple trials in parallel, you will need more quota. Navigate to Quotas.

The quota you want is:

  • Service: Vertex AI API
  • Dimensions (location): us-central1
  • If you want to run single-host experiments:
    • Custom model training TPU V2 cores per region
    • Custom model training TPU V3 cores per region
  • If you want to run multi-host experiments:
    • Custom model training TPU V2 pod cores per region
    • Custom model training TPU V3 pod cores per region

TIP: You won't be able to run single-host experiments with multi-host quota. (i.e. you can't run tpu_v2=8 using TPU V2 pod)

  3. Launch the XManager script located at t5x/scripts/xm_launch.py.

As a running example, we use the WMT14 En-De translation which is described in more detail in the Examples section below.

export GOOGLE_CLOUD_BUCKET_NAME=...
export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data
export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d)

# Pre-download dataset in multi-host experiments.
tfds build wmt_t2t_translate --data_dir=$TFDS_DATA_DIR

git clone https://github.com/google-research/t5x
cd ./t5x/

python3 ./t5x/scripts/xm_launch.py \
  --gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin \
  --model_dir=$MODEL_DIR \
  --tfds_data_dir=$TFDS_DATA_DIR

Check gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/ for the output artifacts, which can be read by TensorBoard.

GPU Usage

UPDATE: NVIDIA has released an updated version of this repository with H100 FP8 support and broad GPU performance improvements. See NVIDIA Rosetta.

T5X can be run easily on GPUs, either in single-node configurations or in multi-node configurations with a SLURM+pyxis cluster. Further instructions are at t5x/contrib/gpu/scripts_gpu. That folder contains example scripts for pretraining T5X on The Pile and for finetuning on SQuAD and MNLI. These scripts and the associated gin configurations also contain additional GPU optimizations for better throughput.

Installation

Note that all the commands in this document should be run in the command line of the TPU VM instance unless otherwise stated.

  1. Follow the instructions to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU API.

    Note: T5X also works with GPUs. Please follow the instructions in t5x/contrib/gpu/scripts_gpu if you'd like to use the GPU version.

  2. Create a Cloud TPU VM instance by following these instructions. We recommend that you develop your workflow on a single v3-8 TPU (i.e., --accelerator-type=v3-8) and scale up to pod slices once the pipeline is ready. In this README, we focus on using a single v3-8 TPU. See here to learn more about TPU architectures.

  3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM. You can install packages, run your code, etc. on the host machine. Once the TPU instance is created, ssh into it with

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}

    where TPU_NAME and ZONE are the name and the zone used in step 2.

  4. Install T5X and the dependencies.

    git clone --branch=main https://github.com/google-research/t5x
    cd t5x
    
    python3 -m pip install -e '.[tpu]' -f \
      https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. Create a Google Cloud Storage (GCS) bucket to store the dataset and model checkpoints. To create a GCS bucket, see these instructions.

  6. (optional) If you prefer working in a Jupyter/Colab-style environment, you can set up a custom Colab runtime by following the steps in t5x/notebooks.

Example: English to German translation

As a running example, we use the WMT14 En-De translation. The raw dataset is available in TensorFlow Datasets as "wmt_t2t_translate".

T5 casts a translation example such as the following

{'en': 'That is good.', 'de': 'Das ist gut.'}

to the form called "text-to-text":

{'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}

This formulation allows many different classes of language tasks to be expressed in a uniform manner and a single encoder-decoder architecture can handle them without any task-specific parameters. For more detail, refer to the T5 paper (Raffel et al. 2019).
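
The casting step above can be sketched in a few lines of plain Python. This is only an illustration of the input/output shapes, not the preprocessor that the T5 library actually uses; the function name to_text_to_text and the language-name table are hypothetical.

```python
def to_text_to_text(example, source_language="en", target_language="de"):
    """Casts a raw translation pair into the text-to-text form.

    Hypothetical helper for illustration; it mirrors the example in the
    text and is not part of T5, T5X, or SeqIO.
    """
    # Human-readable language names used to build the task prefix.
    names = {"en": "English", "de": "German"}
    prefix = f"translate {names[source_language]} to {names[target_language]}: "
    return {
        "inputs": prefix + example[source_language],
        "targets": example[target_language],
    }


example = {"en": "That is good.", "de": "Das ist gut."}
converted = to_text_to_text(example)
print(converted)
```

Because the task is carried entirely by the text prefix, swapping the two language arguments is enough to describe the opposite translation direction with the same function.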

For a scalable data pipeline and an evaluation framework, we use SeqIO, which was factored out of the T5 library. A seqio.Task packages together the raw dataset, the vocabulary, preprocessing such as tokenization, and evaluation metrics such as BLEU, and provides a tf.data instance.

The T5 library provides a number of seqio.Tasks that were used in the T5 paper. In this example, we use wmt_t2t_ende_v003.

Before training or fine-tuning, you need to download the "wmt_t2t_translate" dataset (https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate).

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."

# Make sure that dataset package is up-to-date.
python3 -m pip install --upgrade tfds-nightly

# Pre-download dataset.
tfds build wmt_t2t_translate --data_dir=${TFDS_DATA_DIR}

Training

To run a training job, we use the t5x/train.py script.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.
TFDS_DATA_DIR="..."

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

The configuration for this training run is defined in the Gin file base_wmt_from_scratch.gin. Gin-config is a library to handle configurations based on dependency injection. Among many benefits, Gin allows users to pass custom components such as a custom model to the T5X library without having to modify the core library. The custom components section shows how this is done.

While the core library is independent of Gin, it is central to the examples we provide. Therefore, we provide a short introduction to Gin in the context of T5X. All the configurations are written to a file "config.gin" in MODEL_DIR. This makes debugging as well as reproducing the experiment much easier.

In addition to config.gin, the model-info.txt file summarizes the model parameters (shapes, names of the axes, partitioning info) as well as the optimizer states.

TensorBoard

To monitor training in TensorBoard, it is much easier (due to authentication issues) to launch TensorBoard on your own machine rather than in the TPU VM. So, in a terminal on your own machine (not the one where you ssh'ed into the TPU VM), launch TensorBoard with the logdir pointing to the MODEL_DIR.

# NB: run this on your machine not TPU VM!
MODEL_DIR="..."  # Copy from the TPU VM.
tensorboard --logdir=${MODEL_DIR}

Or you can launch the TensorBoard inside a Colab. In a Colab cell, run

from google.colab import auth
auth.authenticate_user()

to authorize the Colab to access the GCS bucket and launch the TensorBoard.

%load_ext tensorboard
model_dir = "..."  # Copy from the TPU VM.
%tensorboard --logdir=$model_dir

Fine-tuning

We can leverage the benefits of self-supervised pre-training by initializing from one of our pre-trained models. Here we use the T5.1.1 Base checkpoint.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_finetune.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note: when supplying a string, dict, list, tuple value, or a bash variable via a flag, you must put it in quotes. In the case of strings, it requires escaped quotes (\"<string>\"). For example: --gin.utils.DatasetConfig.split=\"validation\" or --gin.MODEL_DIR=\"${MODEL_DIR}\".

Gin makes it easy to change a number of configurations. For example, you can change partitioning.PjitPartitioner.num_partitions (overriding the value in base_wmt_from_scratch.gin) to change the parallelism strategy by passing it as a command-line argument.

--gin.partitioning.PjitPartitioner.num_partitions=8
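
When the launch command is assembled from Python rather than typed into a shell, the same quoting rule can be captured in a small helper. The sketch below is an assumption-laden illustration: gin_override is a hypothetical name, not a T5X or Gin function, and the bucket path is a placeholder. It wraps strings in double quotes and leaves other Python literals as they are.

```python
import json


def gin_override(name, value):
    """Formats a single --gin.<name>=<value> command-line override.

    Hypothetical helper, not part of T5X or Gin. Strings are wrapped in
    double quotes so that Gin parses them as string literals; other values
    (ints, floats, dicts, lists, tuples, None) use their Python repr.
    """
    if isinstance(value, str):
        rendered = json.dumps(value)  # adds the surrounding double quotes
    else:
        rendered = repr(value)
    return f"--gin.{name}={rendered}"


flags = [
    gin_override("MODEL_DIR", "gs://my-bucket/t5x/model"),  # placeholder path
    gin_override("utils.DatasetConfig.split", "validation"),
    gin_override("partitioning.PjitPartitioner.num_partitions", 8),
]
for flag in flags:
    print(flag)
```

If such a list is passed directly to subprocess.run as separate arguments, no shell is involved, so the backslash escaping shown in the shell examples is not needed; the literal double quotes are already part of each argument.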

Evaluation

To run offline evaluation (i.e., without training), you can use the t5x/eval.py script.

EVAL_OUTPUT_DIR="..."  # directory to write eval output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/eval.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_eval.gin" \
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
  --gin.EVAL_OUTPUT_DIR=\"${EVAL_OUTPUT_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Inference

To run inference, you can use the t5x/infer.py script. Here we use the same seqio.Task, but for inference we do not use the targets features other than to log them alongside the predictions in a JSON file.

INFER_OUTPUT_DIR="..."  # directory to write infer output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/infer.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_infer.gin" \
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
  --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Exporting as TensorFlow Saved Model

A pretrained model can be exported as a TensorFlow Saved Model and deployed to the Vertex AI Prediction service using the Optimized TensorFlow Runtime (https://cloud.google.com/vertex-ai/docs/predictions/optimized-tensorflow-runtime). Please note that the exported model won't work on the OSS-based TensorFlow Model Server.

T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
CHECKPOINT_PATH="..."

BATCH_SIZE=None
BEAM_SIZE=1

# Use 'bfloat16' if you plan to run exported model on NVIDIA A100 or newer GPUs,
# for other GPUs use 'float32'.
ACTIVATION_DTYPE=bfloat16

# Version numbers must be numeric. We generate one based on datetime.
VERSION=$(date +%Y%m%d%H%M%S)

NAME=t5x_base_${ACTIVATION_DTYPE}  # Model name.

# Path to export model to. Note that export script is going to add _cpu suffix
# after model name.
OUTPUT=${CHECKPOINT_PATH}/saved_model.${NAME}/${VERSION}

declare -a ARGS=(
--gin_file=t5x/examples/t5/t5_1_1/base.gin
--gin_file=t5x/t5x/configs/runs/export.gin
--gin.TASK_FEATURE_LENGTHS="{'inputs': 256, 'targets': 256}"
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\"
--gin.MODEL_NAME=\"/ml/${USER}/t5x_base\"
--gin.MODEL_OUTPUT_DIR=\"${OUTPUT}\"
--gin.BEAM_SIZE=${BEAM_SIZE}
--gin.BATCH_SIZE=${BATCH_SIZE}
--gin.export_lib.save.partitioner=None
--gin.export_lib.save.warmup_examples="['hello world']"
--gin.export_lib.ExportableModule.use_batch_function=False
--gin.export_lib.ExportableModule.use_gpu=False
--gin.export_lib.ExportableModule.jit_compile=False
--gin.ACTIVATION_DTYPE=\"${ACTIVATION_DTYPE}\"
--gin.network.T5Config.dtype=\"${ACTIVATION_DTYPE}\"
--gin.utils.RestoreCheckpointConfig.dtype=\"${ACTIVATION_DTYPE}\"
--gin.DROPOUT_RATE=0.0
)

(python3 ${T5X_DIR}/t5x/export.py "${ARGS[@]}")

For detailed argument definitions, refer to export.gin (t5x/configs/runs/export.gin).
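
The numeric, datetime-based version and the output path built in the script above can be sketched in Python as follows. The function name export_output_path and the bucket path are hypothetical; the layout simply mirrors the OUTPUT variable in the shell script.

```python
from datetime import datetime


def export_output_path(checkpoint_path, activation_dtype, now=None):
    """Builds <checkpoint_path>/saved_model.<name>/<version>.

    Hypothetical helper mirroring the shell variables VERSION, NAME and
    OUTPUT above; it is not part of T5X.
    """
    now = now or datetime.now()
    # Same format as `date +%Y%m%d%H%M%S`: digits only, so it is numeric.
    version = now.strftime("%Y%m%d%H%M%S")
    name = f"t5x_base_{activation_dtype}"
    return f"{checkpoint_path}/saved_model.{name}/{version}"


# A fixed timestamp keeps the example reproducible.
path = export_output_path(
    "gs://my-bucket/ckpt", "bfloat16", now=datetime(2024, 1, 2, 3, 4, 5)
)
print(path)
```

Since the timestamp is ordered from year down to second, later exports get numerically larger version numbers.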

You can run XL and smaller models on NVIDIA A100 40GB, and XXL models on NVIDIA A100 80GB.

Custom components

The translation example uses the encoder-decoder model that T5X provides as well as the dataset from the T5 library. This section shows how you can use your own dataset and model and pass them via Gin.

Example: custom dataset in a user directory

For this example, we have the following directory structure with ${HOME}/dir1/user_dir representing a user directory with custom components.

${HOME}
└── dir1
    └── user_dir
        ├── t5_1_1_base_de_en.gin
        └── tasks.py

As an example, let's define a new dataset. Here we use the same translation dataset, but we define the translation task in the opposite direction, i.e., German to English instead of English to German. We define this task in tasks.py.

# ${HOME}/dir1/user_dir/tasks.py

import functools
import seqio
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors

vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model', extra_ids=100)
output_features = {
    'inputs': seqio.Feature(vocabulary=vocabulary),
    'targets': seqio.Feature(vocabulary=vocabulary)
}

seqio.TaskRegistry.add(
    'wmt_t2t_de_en_v003',
    source=seqio.TfdsDataSource(tfds_name='wmt_t2t_translate/de-en:1.0.0'),
    preprocessors=[
        functools.partial(
            preprocessors.translate,
            source_language='de', target_language='en'),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.bleu],
    output_features=output_features)

In the Gin file, most of the settings are equivalent to those used in the En->De example, so we include the Gin file from that example. To use the "wmt_t2t_de_en_v003" task we just defined, we need to import the task module "tasks.py". Note that we use a relative path defined with respect to the user directory. This will be specified as a flag.

# ${HOME}/dir1/user_dir/t5_1_1_base_de_en.gin
from __gin__ import dynamic_registration
import tasks  # This imports the task defined in dir1/user_dir/tasks.py.

include "t5x-tmp/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin"
MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003"

Finally, we launch training, passing the user directory via the gin_search_paths flag so that the Gin file and Python modules can be specified with relative paths.

PROJECT_DIR=${HOME}"/dir1/user_dir"
T5X_DIR="..."  # directory where the t5x is cloned.
TFDS_DATA_DIR="..."
MODEL_DIR="..."
export PYTHONPATH=${PROJECT_DIR}

python3 ${T5X_DIR}/t5x/train.py \
  --gin_search_paths=${PROJECT_DIR} \
  --gin_file="t5_1_1_base_de_en.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}
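
The idea behind resolving a relative Gin file against a list of search paths can be sketched as below. This is not Gin's or T5X's actual implementation; resolve_in_search_paths is a hypothetical helper, and the "first matching directory wins" order is an assumption made for the illustration.

```python
import os
import tempfile


def resolve_in_search_paths(relative_path, search_paths):
    """Returns the first existing file found under the given search paths.

    Hypothetical sketch of search-path resolution; the first-match order
    is an assumption, not documented Gin behavior.
    """
    for base in search_paths:
        candidate = os.path.join(base, relative_path)
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(f"{relative_path} not found in {search_paths}")


# Stand-in for ${HOME}/dir1/user_dir, created in a temporary directory.
with tempfile.TemporaryDirectory() as project_dir:
    gin_path = os.path.join(project_dir, "t5_1_1_base_de_en.gin")
    with open(gin_path, "w") as f:
        f.write('MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003"\n')

    missing_dir = os.path.join(project_dir, "does_not_exist")
    resolved = resolve_in_search_paths(
        "t5_1_1_base_de_en.gin", [missing_dir, project_dir]
    )
    found = resolved == gin_path

print(found)
```

Python modules such as tasks.py are found through a separate mechanism, which is why the launch command above also exports PYTHONPATH.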

Checkpoints

Native Checkpoints

We have released the checkpoints of many of the original T5 models and their variants in a native T5X format for maximal efficiency. See the complete list, including the matching Gin configuration files.

These are converted from the public Mesh TensorFlow checkpoints.

Compatibility with the Mesh TensorFlow checkpoints

The Mesh TensorFlow checkpoints trained using the T5 library can be directly loaded into T5X. For example, we can rerun the fine-tuning example initializing from the MTF checkpoint by changing the INIT_CHECKPOINT Gin macro.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt19_ende_train.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --gin.MIXTURE_OR_TASK_NAME=\"wmt_t2t_ende_v003\" \
  --gin.INIT_CHECKPOINT=\"gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000\" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note that restoring directly from the Mesh TensorFlow checkpoints can be inefficient if heavy model parallelism is used for large models. This is because each host first loads an entire copy of the model and then keeps only the relevant slices dictated by the model parallelism specification. If you have Mesh TensorFlow checkpoints that you use often, we recommend converting the checkpoints to T5X native format using the convert_tf_checkpoint script.
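
A rough back-of-the-envelope sketch of why this matters is shown below. It assumes, as a simplification, that a restore which reads slices directly holds only one partition's share of the parameters per host, while the load-then-slice path briefly holds the full copy. The function name and the model size are hypothetical.

```python
def peak_params_per_host(total_params, num_partitions, reads_slices_only):
    """Rough count of parameters a host holds in memory while restoring.

    Hypothetical arithmetic for illustration; real memory use depends on
    dtypes, optimizer state, and the partitioning rules.
    """
    if reads_slices_only:
        # Ceiling division: each host reads just its own slice.
        return -(-total_params // num_partitions)
    # The full copy is loaded first and only then sliced down.
    return total_params


total = 11_000_000_000  # hypothetical model size, in parameters
full_copy = peak_params_per_host(total, 8, reads_slices_only=False)
sliced = peak_params_per_host(total, 8, reads_slices_only=True)
print(full_copy, sliced, full_copy // sliced)
```

Under these assumptions the gap grows linearly with the number of model partitions, which is why converting frequently used checkpoints once is usually worth the effort.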

Citing T5X

Please use the following bibtex entry to cite T5X.

@article{roberts2022t5x,
  url = {https://arxiv.org/abs/2203.17189},
  author = {Roberts, Adam and Chung, Hyung Won and Levskaya, Anselm and Mishra, Gaurav and Bradbury, James and Andor, Daniel and Narang, Sharan and Lester, Brian and Gaffney, Colin and Mohiuddin, Afroz and Hawthorne, Curtis and Lewkowycz, Aitor and Salcianu, Alex and van Zee, Marc and Austin, Jacob and Goodman, Sebastian and Soares, Livio Baldini and Hu, Haitang and Tsvyashchenko, Sasha and Chowdhery, Aakanksha and Bastings, Jasmijn and Bulian, Jannis and Garcia, Xavier and Ni, Jianmo and Chen, Andrew and Kenealy, Kathleen and Clark, Jonathan H. and Lee, Stephan and Garrette, Dan and Lee-Thorp, James and Raffel, Colin and Shazeer, Noam and Ritter, Marvin and Bosma, Maarten and Passos, Alexandre and Maitin-Shepard, Jeremy and Fiedel, Noah and Omernick, Mark and Saeta, Brennan and Sepassi, Ryan and Spiridonov, Alexander and Newlan, Joshua and Gesmundo, Andrea},
  title = {Scaling Up Models and Data with $\texttt{t5x}$ and $\texttt{seqio}$},
  journal={arXiv preprint arXiv:2203.17189},
  year = {2022},
}

Note

This is not an officially supported Google product.
