Trax -- Deep Learning with Clear Code and Speed

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in Colab) shows how to use Trax and where you can find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! PRs with code for new models and layers are welcome, as are improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!

Here are a few example notebooks:

General Setup

Execute the following cell (once) before running any of the code samples.

import os
import numpy as np

# Install the latest Trax release before importing it.
!pip install -q -U trax
import trax

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.

You can use Trax either as a library from your own Python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.
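
You can launch such a shell run with a gin configuration file; a minimal sketch, assuming the trax.trainer module's flags (the config path is a placeholder):

python -m trax.trainer --config_file=path/to/my_config.gin --output_dir=~/train_dir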

3. Walkthrough

Here you can learn how Trax works, how to create new models and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors: multi-dimensional arrays, often called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide; Trax uses the numpy API for them as well.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends -- JAX and TensorFlow numpy.

from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix  = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix = 
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]

Gradients can be calculated using trax.fastmath.grad.

def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
grad(2x^2) at 1 = 4.0
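
The same package can also compile functions for repeated fast execution; a minimal sketch, assuming trax.fastmath.jit dispatches to the active backend's compiler (e.g., jax.jit):

# Compile f once; subsequent calls run the compiled version.
fast_f = trax.fastmath.jit(f)
print(f'jitted f(3.0) = {fast_f(3.0)}')  # 2 * 3^2 = 18.0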

Layers

Layers are the basic building blocks of Trax models. You will learn all about them in the layers intro, but for now just take a look at the implementation of one core Trax layer, Embedding:

# Imports this excerpt relies on, as named in the Trax source:
from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers import initializers as init

class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

from trax import layers as tl

# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')

# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y = (15, 32)
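
Layers without trainable weights need no initialization; you can even wrap a plain function as a layer with tl.Fn. A minimal sketch (the Double layer is hypothetical):

# A weightless layer built from a pure function; call it directly.
double = tl.Fn('Double', lambda x: x * 2.0)
print(double(np.arange(5, dtype=np.float32)))  # [0. 2. 4. 6. 8.]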

Models

Models in Trax are built from layers most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)
Serial[
  Embedding_8192_256
  Mean
  Dense_2
  LogSoftmax
]
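
Branch, the other combinator mentioned above, runs several layers in parallel on copies of its input, producing one output per branch. Below is a minimal sketch of a residual-style block (illustrative only; None denotes the identity branch):

block = tl.Serial(
    tl.Branch(tl.Dense(256), None),  # Outputs: (Dense(x), x).
    tl.Add(),                        # Sum the branches: Dense(x) + x.
)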

Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax makes it easy to use TensorFlow Datasets, and you can also build an iterator over your own text file with standard Python, e.g., open('my_file.txt'), as sketched after the example below.

train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream))  # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)

Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. Pipelines are created with trax.data.Serial; they are functions that you apply to streams to produce processed streams.

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[256,  64,  16,    4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
shapes = [(4, 1024), (4,), (4,)]

Supervised training

When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)
Step      1: Ran 1 train steps in 0.78 secs
Step      1: train WeightedCategoryCrossEntropy |  1.33800304
Step      1: eval  WeightedCategoryCrossEntropy |  0.71843582
Step      1: eval      WeightedCategoryAccuracy |  0.56562500

Step    500: Ran 499 train steps in 5.77 secs
Step    500: train WeightedCategoryCrossEntropy |  0.62914723
Step    500: eval  WeightedCategoryCrossEntropy |  0.49253047
Step    500: eval      WeightedCategoryAccuracy |  0.74062500

Step   1000: Ran 500 train steps in 5.03 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.42949259
Step   1000: eval  WeightedCategoryCrossEntropy |  0.35451687
Step   1000: eval      WeightedCategoryAccuracy |  0.83750000

Step   1500: Ran 500 train steps in 4.80 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.41843575
Step   1500: eval  WeightedCategoryCrossEntropy |  0.35207348
Step   1500: eval      WeightedCategoryAccuracy |  0.82109375

Step   2000: Ran 500 train steps in 5.35 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.38129005
Step   2000: eval  WeightedCategoryCrossEntropy |  0.33760912
Step   2000: eval      WeightedCategoryAccuracy |  0.85312500

After training the model, run it like any layer to get results.

example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: I first saw this when I was a teen in my last year of Junior High. I was riveted to it! I loved the special effects, the fantastic places and the trial-aspect and flashback method of telling the story.<br /><br />Several years later I read the book and while it was interesting and I could definitely see what Swift was trying to say, I think that while it's not as perfect as the book for social commentary, as a story the movie is better. It makes more sense to have it be one long adventure than having Gulliver return after each voyage and making a profit by selling the tiny Lilliput sheep or whatever.<br /><br />It's much more arresting when everyone thinks he's crazy and the sheep DO make a cameo anyway. As a side note, when I saw Laputa I was stunned. It looks very much like the Kingdom of Zeal from the Chrono Trigger video game (1995) that also made me like this mini-series even more.<br /><br />I saw it again about 4 years ago, and realized that I still enjoyed it just as much. Really high quality stuff and began an excellent run of Sweeps mini-series for NBC who followed it up with the solid Merlin and interesting Alice in Wonderland.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[3.984500e-04 9.996014e-01]]
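
To reuse the trained weights later, rebuild the model and load the checkpoint the loop saved; a sketch, assuming the checkpoint file is named model.pkl.gz inside output_dir:

# Re-create the architecture, then load weights saved by training.Loop.
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),
    tl.Dense(2),
    tl.LogSoftmax()
)
model.init_from_file(os.path.join(output_dir, 'model.pkl.gz'),
                     weights_only=True)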
