
Tensor2Tensor


Tensor2Tensor, or T2T for short, is a library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.

T2T was developed by researchers and engineers in the Google Brain team and a community of users. It is now deprecated; we keep it running and welcome bug-fixes, but encourage users to use the successor library Trax.

Quick Start

This IPython notebook explains T2T and runs in your browser using a free VM from Google; no installation is needed. Alternatively, here is a one-command version that installs T2T, downloads MNIST, trains a model, and evaluates it:

pip install tensor2tensor && t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/mnist \
  --problem=image_mnist \
  --model=shake_shake \
  --hparams_set=shake_shake_quick \
  --train_steps=1000 \
  --eval_steps=100


Suggested Datasets and Models

Below we list a number of tasks that can be solved with T2T when you train the appropriate model on the appropriate problem. For each task we give the problem and model and suggest a hyperparameter setting that we know works well in our setup. We usually run either on Cloud TPUs or on 8-GPU machines; you might need to modify the hyperparameters if you run on a different setup.

Mathematical Language Understanding

For evaluating mathematical expressions at the character level, involving addition, subtraction, and multiplication of positive and negative decimal numbers with a variable number of digits assigned to symbolic variables, use

  • the MLU data-set: --problem=algorithmic_math_two_variables

You can try solving the problem with different Transformer variants and hyperparameters, as described in the paper; a full training command follows the list:

  • Standard transformer: --model=transformer --hparams_set=transformer_tiny
  • Universal transformer: --model=universal_transformer --hparams_set=universal_transformer_tiny
  • Adaptive universal transformer: --model=universal_transformer --hparams_set=adaptive_universal_transformer_tiny
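
Concretely, combining the MLU problem with the standard Transformer, a full run might look as follows (directory names and step counts are illustrative, not tuned recommendations):

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/mlu \
  --problem=algorithmic_math_two_variables \
  --model=transformer \
  --hparams_set=transformer_tiny \
  --train_steps=100000 \
  --eval_steps=100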

Story, Question and Answer

For answering questions based on a story, use

  • the bAbi data-set: --problem=babi_qa_concat_task1_1k

You can choose the bAbi task from the range [1,20] and the subset from 1k or 10k. To combine test data from all tasks into a single test set, use --problem=babi_qa_concat_all_tasks_10k.
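
For example, a run on the first 1k task with the standard Transformer (the model choice, directory names, and step counts here are illustrative, not a tuned recommendation):

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/babi \
  --problem=babi_qa_concat_task1_1k \
  --model=transformer \
  --hparams_set=transformer_tiny \
  --train_steps=10000 \
  --eval_steps=100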

Image Classification

For image classification, we have a number of standard data-sets:

  • ImageNet (a large data-set): --problem=image_imagenet, or one of the re-scaled versions (image_imagenet224, image_imagenet64, image_imagenet32)
  • CIFAR-10: --problem=image_cifar10 (or --problem=image_cifar10_plain to turn off data augmentation)
  • CIFAR-100: --problem=image_cifar100
  • MNIST: --problem=image_mnist

For ImageNet, we suggest using ResNet or Xception, i.e., use --model=resnet --hparams_set=resnet_50 or --model=xception --hparams_set=xception_base. ResNet should get above 76% top-1 accuracy on ImageNet.

For CIFAR and MNIST, we suggest trying the shake-shake model: --model=shake_shake --hparams_set=shakeshake_big. This setting, trained for --train_steps=700000, should yield close to 97% accuracy on CIFAR-10; a full command is shown below.
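
Putting the suggested CIFAR-10 settings together (directory names are illustrative):

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/cifar10 \
  --problem=image_cifar10 \
  --model=shake_shake \
  --hparams_set=shakeshake_big \
  --train_steps=700000 \
  --eval_steps=100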

Image Generation

For (un)conditional image generation, we have a number of standard data-sets:

  • CelebA: --problem=img2img_celeba for image-to-image translation, namely, superresolution from 8x8 to 32x32.
  • CelebA-HQ: --problem=image_celeba256_rev for a downsampled 256x256 version.
  • CIFAR-10: --problem=image_cifar10_plain_gen_rev for class-conditional 32x32 generation.
  • LSUN Bedrooms: --problem=image_lsun_bedrooms_rev
  • MS-COCO: --problem=image_text_ms_coco_rev for text-to-image generation.
  • Small ImageNet (a large data-set): --problem=image_imagenet32_gen_rev for 32x32 or --problem=image_imagenet64_gen_rev for 64x64.

We suggest using the Image Transformer, i.e., --model=imagetransformer, the Image Transformer Plus, i.e., --model=imagetransformerpp, which uses a discretized mixture of logistics, or the variational auto-encoder, i.e., --model=transformer_ae. For CIFAR-10, using --hparams_set=imagetransformer_cifar10_base or --hparams_set=imagetransformer_cifar10_base_dmol yields 2.90 bits per dimension. For ImageNet-32, using --hparams_set=imagetransformer_imagenet32_base yields 3.77 bits per dimension.
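
For instance, class-conditional CIFAR-10 generation with the Image Transformer (directory names and step count are illustrative):

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/cifar10_gen \
  --problem=image_cifar10_plain_gen_rev \
  --model=imagetransformer \
  --hparams_set=imagetransformer_cifar10_base \
  --train_steps=100000 \
  --eval_steps=100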

Language Modeling

For language modeling, we have these data-sets in T2T:

  • PTB (a small data-set): --problem=languagemodel_ptb10k for word-level modeling and --problem=languagemodel_ptb_characters for character-level modeling.
  • LM1B (a billion-word corpus): --problem=languagemodel_lm1b32k for subword-level modeling and --problem=languagemodel_lm1b_characters for character-level modeling.

We suggest starting with --model=transformer on this task, using --hparams_set=transformer_small for PTB and --hparams_set=transformer_base for LM1B.
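
For example, a word-level PTB model with the suggested settings (directory names and step count are illustrative):

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/lm_ptb \
  --problem=languagemodel_ptb10k \
  --model=transformer \
  --hparams_set=transformer_small \
  --train_steps=100000 \
  --eval_steps=100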

Sentiment Analysis

For the task of recognizing the sentiment of a sentence, use

  • the IMDB data-set: --problem=sentiment_imdb

We suggest using --model=transformer_encoder here; since it is a small data-set, try --hparams_set=transformer_tiny and train for a few steps (e.g., --train_steps=2000).
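
Put together (directory names are illustrative):

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/imdb \
  --problem=sentiment_imdb \
  --model=transformer_encoder \
  --hparams_set=transformer_tiny \
  --train_steps=2000 \
  --eval_steps=100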

Speech Recognition

For speech-to-text, we have these data-sets in T2T:

  • Librispeech (US English): --problem=librispeech for the whole set and --problem=librispeech_clean for a smaller but nicely filtered part.

  • Mozilla Common Voice (US English): --problem=common_voice for the whole set and --problem=common_voice_clean for a quality-checked subset.
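
Speech data-sets are large, so you may want to generate the data separately before training; for example (directory names are illustrative):

t2t-datagen \
  --data_dir=~/t2t_data \
  --tmp_dir=/tmp/t2t_datagen \
  --problem=librispeech_clean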

Summarization

For summarizing longer text into a shorter one, we have these data-sets:

  • CNN/DailyMail articles summarized into a few sentences: --problem=summarize_cnn_dailymail32k

We suggest using --model=transformer and --hparams_set=transformer_prepend for this task. This yields good ROUGE scores.
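
Combining these suggestions (directory names and step count are illustrative):

t2t-trainer \
  --generate_data \
  --data_dir=~/t2t_data \
  --output_dir=~/t2t_train/cnndm \
  --problem=summarize_cnn_dailymail32k \
  --model=transformer \
  --hparams_set=transformer_prepend \
  --train_steps=250000 \
  --eval_steps=100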

Translation

There are a number of translation data-sets in T2T:

  • English-German: --problem=translate_ende_wmt32k
  • English-French: --problem=translate_enfr_wmt32k
  • English-Czech: --problem=translate_encs_wmt32k
  • English-Chinese: --problem=translate_enzh_wmt32k
  • English-Vietnamese: --problem=translate_envi_iwslt32k
  • English-Spanish: --problem=translate_enes_wmt32k

You can get translations in the other direction by appending _rev to the problem name, e.g., for German-English use --problem=translate_ende_wmt32k_rev (note that you still need to download the original data with t2t-datagen --problem=translate_ende_wmt32k).

For all translation problems, we suggest trying the Transformer model: --model=transformer. At first it is best to try the base setting, --hparams_set=transformer_base. When trained on 8 GPUs for 300K steps, this should reach a BLEU score of about 28 on the English-German data-set, which is close to state of the art. If training on a single GPU, try the --hparams_set=transformer_base_single_gpu setting. For very good results or larger data-sets (e.g., for English-French), try the big model with --hparams_set=transformer_big.

See this example for how translation works end to end.

Basics

Walkthrough

Here's a walkthrough training a good English-to-German translation model using the Transformer model from Attention Is All You Need on WMT data.

pip install tensor2tensor

# See what problems, models, and hyperparameter sets are available.
# You can easily swap between them (and add new ones).
t2t-trainer --registry_help

PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base_single_gpu

DATA_DIR=$HOME/t2t_data
TMP_DIR=/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS

mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR

# Generate data
t2t-datagen \
  --data_dir=$DATA_DIR \
  --tmp_dir=$TMP_DIR \
  --problem=$PROBLEM

# Train
# *  If you run out of memory, add --hparams='batch_size=1024'.
t2t-trainer \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR

# Decode

DECODE_FILE=$DATA_DIR/decode_this.txt
echo "Hello world" >> $DECODE_FILE
echo "Goodbye world" >> $DECODE_FILE
echo -e 'Hallo Welt\nAuf Wiedersehen Welt' > ref-translation.de

BEAM_SIZE=4
ALPHA=0.6

t2t-decoder \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR \
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
  --decode_from_file=$DECODE_FILE \
  --decode_to_file=translation.en

# See the translations
cat translation.en

# Evaluate the BLEU score
# Note: Report this BLEU score in papers, not the internal approx_bleu metric.
t2t-bleu --translation=translation.en --reference=ref-translation.de

Installation

# Assumes tensorflow or tensorflow-gpu installed
pip install tensor2tensor

# Installs with tensorflow-gpu requirement
pip install tensor2tensor[tensorflow_gpu]

# Installs with tensorflow (cpu) requirement
pip install tensor2tensor[tensorflow]

Binaries:

# Data generator
t2t-datagen

# Trainer
t2t-trainer --registry_help

Library usage:

python -c "from tensor2tensor.models.transformer import Transformer"

Features

  • Many state-of-the-art and baseline models are built in, and new models can be added easily (open an issue or pull request!).
  • Many datasets across modalities (text, audio, image) are available for generation and use, and new ones can be added easily (open an issue or pull request for public datasets!).
  • Models can be used with any dataset and input mode (or even multiple); all modality-specific processing (e.g. embedding lookups for text tokens) is done with bottom and top transformations, which are specified per-feature in the model.
  • Support for multi-GPU machines and synchronous (1 master, many workers) and asynchronous (independent workers synchronizing through a parameter server) distributed training.
  • Easily swap amongst datasets and models by command-line flag with the data generation script t2t-datagen and the training script t2t-trainer.
  • Train on Google Cloud ML and Cloud TPUs.

T2T overview

Problems

Problems consist of features such as inputs and targets, and metadata such as each feature's modality (e.g. symbol, image, audio) and vocabularies. Problem features are given by a dataset, which is stored as a TFRecord file with tensorflow.Example protocol buffers. All problems are imported in all_problems.py or are registered with @registry.register_problem. Run t2t-datagen to see the list of available problems and download them.

Models

T2TModels define the core tensor-to-tensor computation. They apply a default transformation to each input and output so that models may deal with modality-independent tensors (e.g. embeddings at the input; and a linear transform at the output to produce logits for a softmax over classes). All models are imported in the models subpackage, inherit from T2TModel, and are registered with @registry.register_model.
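
As a minimal sketch of how the registry is used from Python (assuming only the registration mechanism described here):

# Look up a registered model class by its registered name.
from tensor2tensor import models  # importing this package registers the built-in models
from tensor2tensor.utils import registry

transformer_cls = registry.model("transformer")
print(transformer_cls)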

Hyperparameter Sets

Hyperparameter sets are encoded in HParams objects and registered with @registry.register_hparams. Every model and problem has an HParams. A basic set of hyperparameters is defined in common_hparams.py, and hyperparameter set functions can compose other hyperparameter set functions.
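
For example, a new hyperparameter set can start from an existing one and override a few fields; the set name and overridden value below are hypothetical:

# A minimal sketch of composing hyperparameter sets.
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry

@registry.register_hparams
def transformer_base_shallow():  # hypothetical set name
  hparams = transformer.transformer_base()  # start from the base set
  hparams.num_hidden_layers = 4             # override a single field
  return hparams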

Trainer

The trainer binary is the entrypoint for training, evaluation, and inference. Users can easily switch between problems, models, and hyperparameter sets by using the --model, --problem, and --hparams_set flags. Specific hyperparameters can be overridden with the --hparams flag. --schedule and related flags control local and distributed training/evaluation (distributed training documentation).
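
For example, to train with the base Transformer while overriding two hyperparameters from the command line (the override values are illustrative):

t2t-trainer \
  --data_dir=~/t2t_data \
  --problem=translate_ende_wmt32k \
  --model=transformer \
  --hparams_set=transformer_base \
  --hparams='batch_size=1024,learning_rate=0.1' \
  --output_dir=~/t2t_train/ende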

Adding your own components

T2T's components are registered using a central registration mechanism that enables easily adding new ones and easily swapping amongst them by command-line flag. You can add your own components without editing the T2T codebase by specifying the --t2t_usr_dir flag in t2t-trainer.

You can do so for models, hyperparameter sets, modalities, and problems. Please do submit a pull request if your component might be useful to others.

See the example_usr_dir for an example user directory.
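
For instance, with a hypothetical directory ~/t2t_usr whose __init__.py imports the modules that register your components, you can verify that they were picked up with:

t2t-trainer \
  --t2t_usr_dir=~/t2t_usr \
  --registry_help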

Adding a dataset

To add a new dataset, subclass Problem and register it with @registry.register_problem. See TranslateEndeWmt8k for an example. Also see the data generators README.
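
Below is a minimal sketch of a new text-to-text problem; the class name and sample data are hypothetical, and the shown properties are only a subset of what Text2TextProblem supports:

# A minimal sketch of registering a new dataset as a Problem.
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry

@registry.register_problem
class MyTinyTranslate(text_problems.Text2TextProblem):  # hypothetical name
  @property
  def approx_vocab_size(self):
    return 2**13  # ~8k subwords

  @property
  def is_generate_per_split(self):
    return False  # generate once; T2T splits train/eval itself

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    # Replace this with code that reads your real data from tmp_dir.
    yield {"inputs": "hello world", "targets": "hallo Welt"}

The registered problem name is the snake_case version of the class name, i.e., --problem=my_tiny_translate here.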

Run on FloydHub


Click this button to open a Workspace on FloydHub. You can use the workspace to develop and test your code on a fully configured cloud GPU machine.

Tensor2Tensor comes preinstalled in the environment; you can simply open a Terminal and run your code.

# Test the quick-start on a Workspace's Terminal with this command
t2t-trainer \
  --generate_data \
  --data_dir=./t2t_data \
  --output_dir=./t2t_train/mnist \
  --problem=image_mnist \
  --model=shake_shake \
  --hparams_set=shake_shake_quick \
  --train_steps=1000 \
  --eval_steps=100

Note: Ensure compliance with the FloydHub Terms of Service.

Papers

When referencing Tensor2Tensor, please cite this paper.

@article{tensor2tensor,
  author    = {Ashish Vaswani and Samy Bengio and Eugene Brevdo and
    Francois Chollet and Aidan N. Gomez and Stephan Gouws and Llion Jones and
    \L{}ukasz Kaiser and Nal Kalchbrenner and Niki Parmar and Ryan Sepassi and
    Noam Shazeer and Jakob Uszkoreit},
  title     = {Tensor2Tensor for Neural Machine Translation},
  journal   = {CoRR},
  volume    = {abs/1803.07416},
  year      = {2018},
  url       = {http://arxiv.org/abs/1803.07416},
}

Tensor2Tensor was used to develop a number of state-of-the-art models and deep learning methods. Here we list some papers that were based on T2T from the start and benefited from its features and architecture in ways described in the Google Research Blog post introducing T2T.

NOTE: This is not an official Google product.
