• Stars: 451 (rank 96,968, top 2%)
• Language: Python
• License: MIT License
• Created: almost 3 years ago
• Updated: almost 2 years ago


Repository Details

Implementation of The Annotated S4: https://srush.github.io/annotated-s4

Experiments

MNIST Sequence Modeling

python -m s4.train dataset=mnist layer=s4 train.epochs=100 train.bsz=128 model.d_model=128 model.layer.N=64
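For context, this task treats each MNIST digit as a one-dimensional sequence: the 28x28 image is flattened into 784 pixel values that the model predicts autoregressively. A minimal sketch of that reshaping (using a random stand-in image rather than the real dataset):

```python
import random

# Stand-in for one grayscale MNIST image: a 28x28 grid of pixel values.
image = [[random.randint(0, 255) for _ in range(28)] for _ in range(28)]

# Flatten row by row into the length-784 sequence the model consumes,
# one pixel per time step.
sequence = [px for row in image for px in row]
assert len(sequence) == 28 * 28  # 784 time steps per image
```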

The following command uses a larger model (5M params) and logs generated samples to wandb every epoch. It achieves 0.36 test NLL (0.52 bits per dimension), state-of-the-art on this task.

python -m s4.train dataset=mnist layer=s4 train.epochs=100 train.bsz=50 train.lr=5e-3 train.lr_schedule=true model.d_model=512 model.n_layers=6 model.dropout=0.0 train.weight_decay=0.05 model.prenorm=true model.embedding=true wandb.mode=online train.sample=308 
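The bits-per-dimension figure is just the NLL with a change of logarithm base: the NLL is measured in nats, and dividing by ln 2 converts nats to bits. A quick check of the numbers quoted above:

```python
import math

# Convert a per-pixel negative log-likelihood in nats to bits per
# dimension: bits = nats / ln(2).
nll_nats = 0.36
bits_per_dim = nll_nats / math.log(2)
print(round(bits_per_dim, 2))  # 0.52
```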

QuickDraw Sequence Modeling

# Default arguments
python -m s4.train dataset=quickdraw layer=s4 train.epochs=10 train.bsz=128 model.d_model=128 model.layer.N=64

# "Run in a day" variant
python -m s4.train dataset=quickdraw layer=s4 train.epochs=1 train.bsz=512 model.d_model=256 model.layer.N=64 model.dropout=0.05

MNIST Classification

python -m s4.train dataset=mnist-classification layer=s4 train.epochs=20 train.bsz=128 model.d_model=128 model.dropout=0.25 train.lr=5e-3 train.lr_schedule=true seed=1

Reaches a best accuracy of 99.55% after 20 epochs, at ~17s/epoch on an A100.

CIFAR-10 Classification

# Choose one of the three layer types: s4, dss, or s4d
python -m s4.train dataset=cifar-classification layer={s4,dss,s4d} train.epochs=100 train.bsz=50 model.n_layers=6 model.d_model=512 model.dropout=0.25 train.lr=5e-3 train.weight_decay=0.01 train.lr_schedule=true seed=1

S4 reaches a best accuracy of 91.23% after 100 epochs, at 2m16s/epoch on an A100.

DSS reaches a best accuracy of 89.31% after 100 epochs, at 1m41s/epoch on an A100.

S4D reaches a best accuracy of 89.76% after 100 epochs, at 1m32s/epoch on an A100.

The alternative S4D-Lin initialization performs slightly better, reaching 90.98% accuracy:

python -m s4.train dataset=cifar-classification layer=s4d train.epochs=100 train.bsz=50 model.n_layers=6 model.d_model=512 model.dropout=0.25 train.lr=5e-3 train.weight_decay=0.01 train.lr_schedule=true seed=1 +model.layer.scaling=linear
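The +model.layer.scaling=linear flag selects the S4D-Lin variant. As a sketch (following the initialization described in the S4D paper, not necessarily this repo's exact parameterization), S4D-Lin sets the n-th diagonal entry of the state matrix A to -1/2 + i*pi*n: a fixed decay rate with linearly spaced frequencies.

```python
import math

def s4d_lin_init(N):
    # S4D-Lin initialization sketch: each diagonal state has real part
    # -1/2 (uniform decay) and imaginary part pi*n (linearly spaced
    # oscillation frequencies), for n = 0..N-1.
    return [complex(-0.5, math.pi * n) for n in range(N)]

A = s4d_lin_init(64)  # matches model.layer.N=64 above
```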

Quickstart (Development)

The project provides two requirements files: requirements-cpu.txt for CPU-only setups and requirements-gpu.txt for GPU setups.

CPU-Only (macOS, Linux)

# Set up virtual/conda environment of your choosing & activate...
pip install -r requirements-cpu.txt

# Set up pre-commit
pre-commit install

GPU (CUDA > 11 & cuDNN > 8.2)

# Set up virtual/conda environment of your choosing & activate...
pip install -r requirements-gpu.txt

# Set up pre-commit
pre-commit install

Dependencies from Scratch

If the requirements files above don't work, here are the commands used to install the dependencies from scratch.

CPU-Only

# Set up virtual/conda environment of your choosing & activate... then install the following:
pip install --upgrade "jax[cpu]"
pip install flax
pip install torch torchvision torchaudio

# Defaults
pip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm

# Set up pre-commit
pre-commit install

GPU (CUDA > 11, cuDNN > 8.2)

Note: cuDNN > 8.2 is critical for compiling without warnings, and a GPU with at least the Turing architecture is needed for full efficiency.

# Set up virtual/conda environment of your choosing & activate... then install the following:
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install flax
pip install torch==1.10.1+cpu torchvision==0.11.2+cpu torchaudio==0.10.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html

# Defaults
pip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm

# Set up pre-commit
pre-commit install

More Repositories

1. GPU-Puzzles: Solve puzzles. Learn CUDA. (Jupyter Notebook, 5,464 stars)
2. Tensor-Puzzles: Solve puzzles. Improve your pytorch. (Jupyter Notebook, 2,967 stars)
3. MiniChain: A tiny library for coding with large language models. (Python, 1,197 stars)
4. llama2.rs: A fast llama2 decoder in pure Rust. (Rust, 998 stars)
5. Triton-Puzzles: Puzzles for learning Triton (Jupyter Notebook, 895 stars)
6. LLM-Training-Puzzles: What would you do with 1000 H100s... (Jupyter Notebook, 797 stars)
7. annotated-mamba: Annotated version of the Mamba paper (Jupyter Notebook, 438 stars)
8. Autodiff-Puzzles (Jupyter Notebook, 317 stars)
9. Transformer-Puzzles: Puzzles for exploring transformers (Jupyter Notebook, 281 stars)
10. streambook: Live Python Notebooks with any Editor (Jupyter Notebook, 277 stars)
11. raspy: An interactive exploration of Transformer programming. (Jupyter Notebook, 239 stars)
12. do-we-need-attention (TeX, 159 stars)
13. parallax (Python, 157 stars)
14. awesome-o1 (TeX, 119 stars)
15. GPTWorld: A puzzle to learn about prompting (Jupyter Notebook, 104 stars)
16. awesome-ml-tracking (102 stars)
17. triton-autodiff: Experiment of using Tangent to autodiff triton (Python, 66 stars)
18. torch-queue (Python, 64 stars)
19. LLM-Talk (45 stars)
20. torch-golf: Silly twitter torch implementations. (Python, 44 stars)
21. PyDecode: A dynamic programming toolkit. (C++, 39 stars)
22. VirtualTeaching: DIY setup for virtual teaching on ubuntu (39 stars)
23. mamba-primer (34 stars)
24. learns-dex (33 stars)
25. text2table (Python, 31 stars)
26. jax-lda (Python, 31 stars)
27. ProbTalk (HTML, 29 stars)
28. Hierarchical-Bayes-Compiler: Hal Daume's hbc (Haskell, 20 stars)
29. g9py (HTML, 18 stars)
30. drop7 (Jupyter Notebook, 18 stars)
31. Tensor-Puzzles-Penzai (HTML, 17 stars)
32. mamba-scans: Blog post (16 stars)
33. anynp: Proof-of-concept of global switching between numpy/jax/pytorch in a library. (Python, 16 stars)
34. transformers-bet (HTML, 12 stars)
35. relax-decode (Java, 10 stars)
36. aima-arguments (7 stars)
37. torch-mechanics: Amateur experiments with autodiff mechanics simulators (7 stars)
38. minitorch-rust (7 stars)
39. cs5781: Machine Learning Engineering (6 stars)
40. postgres-provanence (C, 6 stars)
41. PowerEdit: A super-minimal Python-based video editor ⚑ (Python, 6 stars)
42. SemiRings: Holder for a bunch of semirings used in ChartParsing (Haskell, 6 stars)
43. DiffRast (HTML, 6 stars)
44. MRF-LM (Shell, 5 stars)
45. TextBook: Command-line Facebook (Haskell, 5 stars)
46. hsNLP-: Combined repo for nlp libs (Haskell, 4 stars)
47. provenance (4 stars)
48. icfp2009: when I was 4 years old I was maimed by a giant pig (Haskell, 4 stars)
49. configure: some configuration file (Emacs Lisp, 3 stars)
50. clustering (C++, 3 stars)
51. annotated-transformer.github.io: Annotated Transformer Blog Post (3 stars)
52. BT-AI (Jupyter Notebook, 3 stars)
53. srush-blog (Haskell, 2 stars)
54. Eisner-Parser: An implementation of the Eisner parser (described in "Bilexical Grammars and a Cubic-time Parsing Algorithm") in Haskell (Haskell, 2 stars)
55. FSM: Finite State Machine lib for haskell (Haskell, 2 stars)
56. hplay (2 stars)
57. opennmt-gen (Shell, 2 stars)
58. PhraseDep (C++, 2 stars)
59. triton (2 stars)
60. tf-fork (Python, 2 stars)
61. srush-wiki (2 stars)
62. icfp2003: Race Car (Haskell, 2 stars)
63. icfp2008 (2 stars)
64. hypergraph: Hypergraph specification (Python, 2 stars)
65. learns-triton (2 stars)
66. Training (Haskell, 1 star)
67. bipartite-sampler: Implementation of Huber-Law rejection sampling for bipartite graphs (C, 1 star)
68. ezTVM (1 star)
69. test_grade (Python, 1 star)
70. Chart-Parsing-: haskell library for basic chart parsers (Haskell, 1 star)
71. blog-twitter (1 star)
72. sigmoidfit (Jupyter Notebook, 1 star)
73. evernote: Command line bindings for evernote (1 star)
74. prof8: Experimental paper writing linter. (TeX, 1 star)
75. transforest: transforest (Python, 1 star)
76. decoding-methods (1 star)
77. blog (Jupyter Notebook, 1 star)
78. nlp-course (Go, 1 star)
79. beamer-animation: Create animations for LaTeX Beamer presentations. (Python, 1 star)
80. Duel (Python, 1 star)
81. nlp (1 star)
82. twitter-simmons-sports (1 star)
83. monadnack-project: Art project for monadnack (1 star)
84. peoplesounds (Python, 1 star)
85. CutParse (C++, 1 star)
86. Lattice: lattice protobuffer (Python, 1 star)
87. Penn-Treebank: Haskell library for the penn treebank management (1 star)
88. icfp2020 (Python, 1 star)
89. twittersports (1 star)
90. ProbDist: Tools for probability distributions focusing on estimation, conditioning, and smoothing (Haskell, 1 star)
91. osgai (JavaScript, 1 star)