  • Stars: 157
  • Rank: 230,985 (Top 5%)
  • Language: Python
  • License: MIT License
  • Created: about 4 years ago
  • Updated: almost 4 years ago

Repository Details

Parallax - Immutable Torch Modules for JAX

Parallax is a prototype for a pure module system for JAX implemented by Sabrina Mielke (@sjmielke) and Sasha Rush (@srush).

Main ideas:

  • Make param modules immutable trees.
  • Replace imperative-style coding and initialization.
  • Avoid tracking state for most applications by first distributing seeds / globals through the tree (sketched in plain JAX below).
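
As a quick plain-JAX sketch of the third idea (distribute_rng is a hypothetical helper, not part of the parallax API): a single root PRNGKey is split into independent per-child keys, so no module ever holds mutable RNG state.

import jax

def distribute_rng(key, n_children):
    # Split one key into independent per-child keys; the parent key is never reused.
    return list(jax.random.split(key, n_children))

root = jax.random.PRNGKey(0)
k_dense1, k_dense2, k_dropout = distribute_rng(root, 3)  # one seed per submodule

The full example below shows the module tree doing this distribution automatically.
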
import jax
from jax.nn import initializers as init  # provides xavier_normal, normal
from parallax import Module, Parameter, ParamInit

class Dense(Module):
    # All parameter-holders are explicitly declared.
    weight : Parameter
    bias : Parameter

    # __init__ creates shapes and binds lazy initializers.
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = ParamInit((out_size, in_size), init.xavier_normal())
        self.bias = ParamInit((out_size,), init.normal())


    # Forward is just like standard PyTorch.
    def forward(self, input):
        return self.weight @ input + self.bias

    # Hook for pretty printing
    def extra_repr(self):
        return "%d, %d" % (self.weight.shape[1], self.weight.shape[0])

class Dropout(Module):
    # Arbitrary constants allowed.
    rate : float
    def __init__(self, rate):
        super().__init__()
        self.rate = rate

    def forward(self, input):
        # RNG state is use-once or split. Attached to tree.
        state = self.rng

        if self.mode == "train":
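            # Note: rate acts as the keep probability (keep with prob rate, rescale by 1/rate).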
            keep = jax.random.bernoulli(state, self.rate, input.shape)
            return jax.numpy.where(keep, input / self.rate, 0)
        else:
            return input

class BinaryNetwork(Module):
    # No difference between modules and parameters
    dense1 : Dense
    dense2 : Dense
    dense3 : Dense
    dropout : Dropout

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.dense1 = Dense(input_size, hidden_size)
        self.dense2 = Dense(hidden_size, hidden_size)
        self.dense3 = Dense(hidden_size, 1)
        self.dropout = Dropout(0.2)

    def forward(self, input):

        # Standard usage works out of the box.
        x = jax.numpy.tanh(self.dense1(input))

        # Stochastic modules (have random seed already)
        x = self.dropout(x)

        # Shared params / recurrence only requires split to change RNG
        x = jax.numpy.tanh(self.dense2(x))
        x = jax.numpy.tanh(self.dense2(x))

        return jax.nn.sigmoid(self.dense3(jax.numpy.tanh(x)))[0]

# Setup param tree -> declarative, immutable
layer = BinaryNetwork(5, 10)
print(layer)
print(layer.dense1)

# Initialize parameters -> stateful, hidden
rng = jax.random.PRNGKey(0)
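# initialized() returns a new tree with parameters built from the lazy ParamInits.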
layer = layer.initialized(rng)
print(layer)
print(layer.dense1)

initial_loss = None
for i in range(10):
    # Thread state through parameters -> functor, hidden
    rng, iter_rng = jax.random.split(rng)
    layer = layer.new_state(iter_rng, mode="train")
    
    # Jax style grad compute -> tree-shaped immutable
    x = jax.numpy.zeros(5)
    loss = layer(x)
    if initial_loss is None:
        initial_loss = loss
    print(loss)
    grad = layer.grad(x)
    
    # Grad Update -> tree-shaped
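    # (tree_multimap is the older JAX name for the multi-tree tree_map)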
    layer = jax.tree_util.tree_multimap(lambda p, g: p - 0.3 * g, layer, grad)
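
The tree-shaped grad and update steps above rest on JAX's pytree machinery. Here is a minimal, self-contained sketch of just that mechanism, assuming nothing from parallax (TinyDense and loss_fn are illustrative names): any class registered as a pytree flows through jax.grad and jax.tree_util.tree_map exactly like layer does.

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node

class TinyDense:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

# Register TinyDense as a pytree whose leaves are (weight, bias).
register_pytree_node(
    TinyDense,
    lambda m: ((m.weight, m.bias), None),
    lambda aux, children: TinyDense(*children),
)

def loss_fn(m, x):
    return jnp.sum(m(x) ** 2)

layer = TinyDense(jnp.ones((3, 5)), jnp.zeros(3))
x = jnp.ones(5)
grads = jax.grad(loss_fn)(layer, x)  # grads is itself a TinyDense
layer = jax.tree_util.tree_map(lambda p, g: p - 0.3 * g, layer, grads)

parallax layers declared parameters, lazy shape initialization, and RNG threading on top of this same primitive.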

More Repositories

1. GPU-Puzzles - Solve puzzles. Learn CUDA. (Jupyter Notebook, 5,084 stars)
2. Tensor-Puzzles - Solve puzzles. Improve your pytorch. (Jupyter Notebook, 2,479 stars)
3. MiniChain - A tiny library for coding with large language models. (Python, 1,174 stars)
4. llama2.rs - A fast llama2 decoder in pure Rust. (Rust, 981 stars)
5. LLM-Training-Puzzles - What would you do with 1000 H100s... (Jupyter Notebook, 738 stars)
6. Triton-Puzzles - Puzzles for learning Triton (Jupyter Notebook, 678 stars)
7. annotated-s4 - Implementation of https://srush.github.io/annotated-s4 (Python, 438 stars)
8. annotated-mamba - Annotated version of the Mamba paper (Jupyter Notebook, 411 stars)
9. Autodiff-Puzzles (Jupyter Notebook, 295 stars)
10. streambook - Live Python Notebooks with any Editor (Jupyter Notebook, 275 stars)
11. Transformer-Puzzles - Puzzles for exploring transformers (Jupyter Notebook, 260 stars)
12. raspy - An interactive exploration of Transformer programming. (Jupyter Notebook, 229 stars)
13. do-we-need-attention (TeX, 159 stars)
14. awesome-ml-tracking (102 stars)
15. GPTWorld - A puzzle to learn about prompting (Jupyter Notebook, 98 stars)
16. triton-autodiff - Experiment of using Tangent to autodiff triton (Python, 66 stars)
17. torch-queue (Python, 63 stars)
18. LLM-Talk (45 stars)
19. torch-golf - Silly twitter torch implementations. (Python, 44 stars)
20. PyDecode - A dynamic programming toolkit. (C++, 39 stars)
21. VirtualTeaching - DIY setup for virtual teaching on ubuntu (39 stars)
22. learns-dex (33 stars)
23. mamba-primer (32 stars)
24. text2table (Python, 31 stars)
25. jax-lda (Python, 31 stars)
26. ProbTalk (HTML, 29 stars)
27. Hierarchical-Bayes-Compiler - Hal Daume's hbc (Haskell, 20 stars)
28. g9py (HTML, 18 stars)
29. drop7 (Jupyter Notebook, 18 stars)
30. mamba-scans - Blog post (16 stars)
31. Tensor-Puzzles-Penzai (HTML, 15 stars)
32. transformers-bet (HTML, 11 stars)
33. relax-decode (Java, 10 stars)
34. aima-arguments (7 stars)
35. torch-mechanics - Amateur experiments with autodiff mechanics simulators (7 stars)
36. minitorch-rust (7 stars)
37. cs5781 - Machine Learning Engineering (6 stars)
38. postgres-provanence (C, 6 stars)
39. SemiRings - Holder for a bunch of semirings used in ChartParsing (Haskell, 6 stars)
40. MRF-LM (Shell, 5 stars)
41. TextBook - Command-line Facebook (Haskell, 5 stars)
42. PowerEdit - A super-minimal Python-based video editor ⚡ (Python, 5 stars)
43. hsNLP- - Combined repo for nlp libs (Haskell, 4 stars)
44. provenance (4 stars)
45. icfp2009 - "when I was 4 years old I was maimed by a giant pig" (Haskell, 4 stars)
46. configure - some configuration file (Emacs Lisp, 3 stars)
47. clustering (C++, 3 stars)
48. BT-AI (Jupyter Notebook, 3 stars)
49. annotated-transformer.github.io - Annotated Transformer Blog Post (3 stars)
50. srush-blog (Haskell, 2 stars)
51. Eisner-Parser - An implementation of the Eisner Parser (described in "Bilexical Grammars and a Cubic-time parsing algorithm") in Haskell (Haskell, 2 stars)
52. FSM - Finite State Machine lib for haskell (Haskell, 2 stars)
53. hplay (2 stars)
54. opennmt-gen (Shell, 2 stars)
55. PhraseDep (C++, 2 stars)
56. srush-wiki (2 stars)
57. tf-fork (Python, 2 stars)
58. icfp2003 - Race Car (Haskell, 2 stars)
59. icfp2008 (2 stars)
60. hypergraph - Hypergraph specification (Python, 2 stars)
61. triton (2 stars)
62. learns-triton (2 stars)
63. bipartite-sampler - Implementation of Huber-Law rejection sampling for bipartite graphs (C, 1 star)
64. Training (Haskell, 1 star)
65. ezTVM (1 star)
66. test_grade (Python, 1 star)
67. Chart-Parsing- - haskell library for basic chart parsers (Haskell, 1 star)
68. blog-twitter (1 star)
69. sigmoidfit (Jupyter Notebook, 1 star)
70. evernote - Command line bindings for evernote (1 star)
71. transforest (Python, 1 star)
72. decoding-methods (1 star)
73. blog (Jupyter Notebook, 1 star)
74. nlp-course (Go, 1 star)
75. beamer-animation - Create animations for LaTeX Beamer presentations. (Python, 1 star)
76. Duel (Python, 1 star)
77. nlp (1 star)
78. twitter-simmons-sports (1 star)
79. monadnack-project - Art project for monadnack (1 star)
80. peoplesounds (Python, 1 star)
81. osgai (JavaScript, 1 star)
82. CutParse (C++, 1 star)
83. Penn-Treebank - Haskell library for the penn treebank management (1 star)
84. Lattice - lattice protobuffer (Python, 1 star)
85. icfp2020 (Python, 1 star)
86. twittersports (1 star)
87. ProbDist - Tools for probability distributions focusing on estimation, conditioning, and smoothing (Haskell, 1 star)