  • Stars: 8,051
  • Rank: 4,627 (Top 0.1%)
  • Language: Python
  • License: Apache License 2.0
  • Created: about 5 years ago
  • Updated: 3 months ago

Repository Details

Trax — Deep Learning with Clear Code and Speed

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in colab) shows how to use Trax and where you can find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! PRs with code for new models and layers, as well as improvements to our code and documentation, are all appreciated. We especially love notebooks that explain how models work and show how to use them to solve problems!

Here are a few example notebooks:

General Setup

Execute the following cell (once) before running any of the code samples.

import os
import numpy as np

# Install the latest Trax release and import it.
!pip install -q -U trax
import trax

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.
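
As a quick illustration, the models in this zoo are plain constructors in trax.models; below is a minimal sketch instantiating a small Transformer language model (the hyperparameter values are illustrative only, not a recommended configuration):

import trax

# Build a small Transformer language model from the built-in model zoo.
# Hyperparameter values here are illustrative only.
lm = trax.models.TransformerLM(
    vocab_size=8192, d_model=256, d_ff=1024,
    n_layers=2, n_heads=4, mode='train')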

You can use Trax either as a library from your own python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.
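
As a sketch of the shell usage, the trainer binary can be pointed at a gin config file; the exact config path below is an assumption (gin files live under the Trax source tree), so adjust it to your checkout:

python -m trax.trainer \
  --config_file=$PWD/trax/supervised/configs/transformer_lm1b_8gb.gin \
  --output_dir=$HOME/output_dir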

3. Walkthrough

You can learn here how Trax works, how to create new models and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors: multi-dimensional arrays, often called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide; Trax uses the numpy API for them as well.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends -- JAX and TensorFlow numpy.

from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix  = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix = 
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]
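
The fastmath backends can also compile functions for speed; here is a minimal sketch, assuming trax.fastmath.jit forwards to the active backend's jit (dot_tanh is a name made up for this example):

def dot_tanh(v, m):
  return fastnp.tanh(fastnp.dot(v, m))

# Compile with the active backend (JAX here) for faster repeated calls.
fast_dot_tanh = trax.fastmath.jit(dot_tanh)
print(f'jitted result = {fast_dot_tanh(vector, matrix)}')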

Gradients can be calculated using trax.fastmath.grad.

def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
grad(2x^2) at 1 = 4.0

Layers

Layers are basic building blocks of Trax models. You will learn all about them in the layers intro but for now, just take a look at the implementation of one core Trax layer, Embedding:

# Context imports for this excerpt (Embedding lives in trax/layers/core.py):
from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers import initializers as init

class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

from trax import layers as tl

# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')

# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y = (15, 32)

Models

Models in Trax are built from layers, most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)
Serial[
  Embedding_8192_256
  Mean
  Dense_2
  LogSoftmax
]

Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax allows you to use TensorFlow Datasets easily and you can also get an iterator from your own text file using the standard open('my_file.txt').

train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream))  # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)

Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines using trax.data.Serial and they are functions that you apply to streams to create processed streams.

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[256,  64,  16,    4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
shapes = [(4, 1024), (4,), (4,)]

Supervised training

When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)
Step      1: Ran 1 train steps in 0.78 secs
Step      1: train WeightedCategoryCrossEntropy |  1.33800304
Step      1: eval  WeightedCategoryCrossEntropy |  0.71843582
Step      1: eval      WeightedCategoryAccuracy |  0.56562500

Step    500: Ran 499 train steps in 5.77 secs
Step    500: train WeightedCategoryCrossEntropy |  0.62914723
Step    500: eval  WeightedCategoryCrossEntropy |  0.49253047
Step    500: eval      WeightedCategoryAccuracy |  0.74062500

Step   1000: Ran 500 train steps in 5.03 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.42949259
Step   1000: eval  WeightedCategoryCrossEntropy |  0.35451687
Step   1000: eval      WeightedCategoryAccuracy |  0.83750000

Step   1500: Ran 500 train steps in 4.80 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.41843575
Step   1500: eval  WeightedCategoryCrossEntropy |  0.35207348
Step   1500: eval      WeightedCategoryAccuracy |  0.82109375

Step   2000: Ran 500 train steps in 5.35 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.38129005
Step   2000: eval  WeightedCategoryCrossEntropy |  0.33760912
Step   2000: eval      WeightedCategoryAccuracy |  0.85312500

After training the model, run it like any layer to get results.

example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: I first saw this when I was a teen in my last year of Junior High. I was riveted to it! I loved the special effects, the fantastic places and the trial-aspect and flashback method of telling the story.<br /><br />Several years later I read the book and while it was interesting and I could definitely see what Swift was trying to say, I think that while it's not as perfect as the book for social commentary, as a story the movie is better. It makes more sense to have it be one long adventure than having Gulliver return after each voyage and making a profit by selling the tiny Lilliput sheep or whatever.<br /><br />It's much more arresting when everyone thinks he's crazy and the sheep DO make a cameo anyway. As a side note, when I saw Laputa I was stunned. It looks very much like the Kingdom of Zeal from the Chrono Trigger video game (1995) that also made me like this mini-series even more.<br /><br />I saw it again about 4 years ago, and realized that I still enjoyed it just as much. Really high quality stuff and began an excellent run of Sweeps mini-series for NBC who followed it up with the solid Merlin and interesting Alice in Wonderland.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[3.984500e-04 9.996014e-01]]
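
Because the training loop checkpoints to output_dir, you can also restore the trained weights into a freshly constructed model later; here is a minimal sketch, assuming the default checkpoint file name model.pkl.gz written by training.Loop:

# Rebuild the same architecture and load the checkpointed weights.
# 'model.pkl.gz' is the default checkpoint name assumed here.
model2 = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),
    tl.Dense(2),
    tl.LogSoftmax()
)
model2.init_from_file(os.path.join(output_dir, 'model.pkl.gz'),
                      weights_only=True)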
