• Stars: 2,479
• Rank: 17,890 (Top 0.4 %)
• Language: Jupyter Notebook
• License: MIT License
• Created about 2 years ago
• Updated 2 months ago

Solve puzzles. Improve your PyTorch.

Tensor Puzzles

When learning a tensor programming language like PyTorch or NumPy, it is tempting to rely on the standard library (or, more honestly, Stack Overflow) to find a magic function for everything. But in practice, the tensor language is extremely expressive, and you can do most things from first principles and clever use of broadcasting.

This is a collection of 21 tensor puzzles. Like chess puzzles, these are not meant to simulate the complexity of a real program, but to give practice in a simplified environment. Each puzzle asks you to reimplement one function in the NumPy standard library without magic.
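
As a small illustration of that claim, here is a table of pairwise absolute differences computed twice: once with explicit loops and once with a single broadcast expression. This is a sketch in NumPy, whose broadcasting rules PyTorch follows. The inputs are made up, and it is not one of the puzzles.

```python
import numpy as np

# Made-up inputs for illustration.
x = np.array([1, 4, 9])
y = np.array([2, 3])

# Loop version: fill a table of |x[i] - y[j]| one entry at a time.
loop = np.zeros((len(x), len(y)), dtype=int)
for i in range(len(x)):
    for j in range(len(y)):
        loop[i, j] = abs(x[i] - y[j])

# Broadcast version: x[:, None] has shape (3, 1) and y has shape (2,),
# so the subtraction produces the whole (3, 2) table in one expression.
table = np.abs(x[:, None] - y)

assert (table == loop).all()
print(table.tolist())   # [[1, 2], [2, 1], [7, 6]]
```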

I recommend running in Colab. Open the notebook with the "Open In Colab" badge and save a copy to get started.

!pip install -qqq torchtyping hypothesis pytest git+https://github.com/danoneata/chalk@srush-patch-1
!wget -q https://github.com/srush/Tensor-Puzzles/raw/main/lib.py
from lib import draw_examples, make_test, run_test
import torch
import numpy as np
from torchtyping import TensorType as TT
tensor = torch.tensor

Rules

  1. These puzzles are about broadcasting. Know the broadcasting rule.

  2. Each puzzle needs to be solved in 1 line (<80 columns) of code.

  3. You are allowed @, arithmetic, comparison, shape, any indexing (e.g. a[:j], a[:, None], a[arange(10)]), and previous puzzle functions.

  4. You are not allowed anything else. No view, sum, take, squeeze, tensor.

  5. You can start with these two functions:

def arange(i: int):
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))

draw_examples("arange", [{"" : arange(i)} for i in [5, 3, 9]])

# Example of broadcasting.
examples = [(arange(4), arange(5)[:, None]) ,
            (arange(3)[:, None], arange(2))]
draw_examples("broadcast", [{"a": a, "b":b, "ret": a + b} for a, b in examples])
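
The shape of each result above follows the broadcasting rule. Shapes are aligned from the right, and any missing dimension is treated as size 1. A size-1 dimension then stretches to match the other side. The sketch below expresses that rule in pure Python. broadcast_shape is a hypothetical helper written for illustration; it is not part of lib.py or PyTorch.

```python
def broadcast_shape(s1, s2):
    """Hypothetical sketch of the broadcasting shape rule.

    Shapes are right-aligned; a missing leading dimension counts as 1.
    Two sizes are compatible if they are equal or one of them is 1,
    and a size-1 dimension takes the other side's size.
    """
    n = max(len(s1), len(s2))
    a = (1,) * (n - len(s1)) + tuple(s1)
    b = (1,) * (n - len(s2)) + tuple(s2)
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"incompatible shapes {s1} and {s2}")
    return tuple(out)

# The two examples drawn above.
print(broadcast_shape((4,), (5, 1)))   # (5, 4)
print(broadcast_shape((3, 1), (2,)))   # (3, 2)
```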

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

# In diagrams, orange is positive/True, white is zero/False, and blue is negative.

examples = [(tensor([False]), tensor([10]), tensor([0])),
            (tensor([False, True]), tensor([1, 1]), tensor([-10, 0])),
            (tensor([False, True]), tensor([1]), tensor([-10, 0])),
            (tensor([[False, True], [True, False]]), tensor([1]), tensor([-10, 0])),
            (tensor([[False, True], [True, False]]), tensor([[0], [10]]), tensor([-10, 0])),
           ]
draw_examples("where", [{"q": q, "a":a, "b":b, "ret": where(q, a, b)} for q, a, b in examples])
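
The arithmetic where above gives the same result as a library select whenever both branches hold finite values. The NumPy sketch below checks one of the broadcast examples, and uses np.where only for comparison (it is not allowed in the puzzles). It also shows the one caveat worth knowing. Both branches are multiplied rather than skipped, so a NaN in the unselected branch still leaks into the result. arith_where is a hypothetical name for a NumPy copy of the function above.

```python
import numpy as np

def arith_where(q, a, b):
    # Hypothetical NumPy copy of the where above: select by multiplying.
    return (q * a) + (~q) * b

# One of the broadcast examples drawn above.
q = np.array([[False, True], [True, False]])
a = np.array([[0], [10]])   # shape (2, 1): stretched across columns
b = np.array([-10, 0])      # shape (2,): stretched across rows
print(arith_where(q, a, b).tolist())   # [[-10, 0], [10, 0]]
assert (arith_where(q, a, b) == np.where(q, a, b)).all()

# Caveat: both branches are evaluated, and 0 * nan is still nan.
sel = np.array([True, False])
a_nan = np.array([1.0, np.nan])
b_ok = np.array([5.0, 6.0])
print(np.where(sel, a_nan, b_ok).tolist())      # [1.0, 6.0]
print(arith_where(sel, a_nan, b_ok).tolist())   # [1.0, nan]
```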

Puzzle 1 - ones

Compute ones - the vector of all ones.

def ones_spec(out):
    for i in range(len(out)):
        out[i] = 1
        
def ones(i: int) -> TT["i"]:
    raise NotImplementedError

test_ones = make_test("one", ones, ones_spec, add_sizes=["i"])

# run_test(test_ones)

Puzzle 2 - sum

Compute sum - the sum of a vector.

def sum_spec(a, out):
    out[0] = 0
    for i in range(len(a)):
        out[0] += a[i]
        
def sum(a: TT["i"]) -> TT[1]:
    raise NotImplementedError


test_sum = make_test("sum", sum, sum_spec)

# run_test(test_sum)

Puzzle 3 - outer

Compute outer - the outer product of two vectors.

def outer_spec(a, b, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            out[i][j] = a[i] * b[j]
            
def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
    raise NotImplementedError
    
test_outer = make_test("outer", outer, outer_spec)

# run_test(test_outer)

Puzzle 4 - diag

Compute diag - the diagonal vector of a square matrix.

def diag_spec(a, out):
    for i in range(len(a)):
        out[i] = a[i][i]
        
def diag(a: TT["i", "i"]) -> TT["i"]:
    raise NotImplementedError


test_diag = make_test("diag", diag, diag_spec)

# run_test(test_diag)

Puzzle 5 - eye

Compute eye - the identity matrix.

def eye_spec(out):
    for i in range(len(out)):
        out[i][i] = 1
        
def eye(j: int) -> TT["j", "j"]:
    raise NotImplementedError
    
test_eye = make_test("eye", eye, eye_spec, add_sizes=["j"])

# run_test(test_eye)

Puzzle 6 - triu

Compute triu - the upper triangular matrix of ones (1 on and above the diagonal, 0 below).

def triu_spec(out):
    for i in range(len(out)):
        for j in range(len(out)):
            if i <= j:
                out[i][j] = 1
            else:
                out[i][j] = 0
                
def triu(j: int) -> TT["j", "j"]:
    raise NotImplementedError


test_triu = make_test("triu", triu, triu_spec, add_sizes=["j"])

# run_test(test_triu)
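
Before writing triu, it can help to confirm what the spec produces. NumPy's np.triu applied to a matrix of ones gives the same pattern, diagonal included. It is off-limits inside the puzzle and is used here only as a reference.

```python
import numpy as np

# triu_spec writes 1 wherever i <= j, so the diagonal itself is included;
# np.triu keeps entries on and above the main diagonal (k=0 by default).
ref = np.triu(np.ones((3, 3), dtype=int))
print(ref.tolist())   # [[1, 1, 1], [0, 1, 1], [0, 0, 1]]
```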

Puzzle 7 - cumsum

Compute cumsum - the cumulative sum.

def cumsum_spec(a, out):
    total = 0
    for i in range(len(out)):
        out[i] = total + a[i]
        total += a[i]

def cumsum(a: TT["i"]) -> TT["i"]:
    raise NotImplementedError

test_cumsum = make_test("cumsum", cumsum, cumsum_spec)

# run_test(test_cumsum)

Puzzle 8 - diff

Compute diff - the running difference.

def diff_spec(a, out):
    out[0] = a[0]
    for i in range(1, len(out)):
        out[i] = a[i] - a[i - 1]

def diff(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError

test_diff = make_test("diff", diff, diff_spec, add_sizes=["i"])

# run_test(test_diff)
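
Note that diff_spec is not quite NumPy's np.diff. It keeps a[0] as the first output, so the result has the same length as the input. The quick NumPy comparison below uses made-up inputs and is a reference only, not a solution.

```python
import numpy as np

a = np.array([3, 5, 4, 10])

# Plain np.diff drops a[0], so its result is one element shorter than a.
plain = np.diff(a)
print(plain.tolist())    # [2, -1, 6]

# diff_spec keeps a[0] as the first output; prepending a 0 reproduces that.
padded = np.diff(a, prepend=0)
print(padded.tolist())   # [3, 2, -1, 6]
```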

Puzzle 9 - vstack

Compute vstack - the 2-row matrix formed by stacking two vectors.

def vstack_spec(a, b, out):
    for i in range(len(out[0])):
        out[0][i] = a[i]
        out[1][i] = b[i]

def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
    raise NotImplementedError


test_vstack = make_test("vstack", vstack, vstack_spec)

# run_test(test_vstack)

Puzzle 10 - roll

Compute roll - the vector rotated one position to the left, with the first entry wrapping around to the end.

def roll_spec(a, out):
    for i in range(len(out)):
        if i + 1 < len(out):
            out[i] = a[i + 1]
        else:
            out[i] = a[i + 1 - len(out)]
            
def roll(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError


test_roll = make_test("roll", roll, roll_spec, add_sizes=["i"])

# run_test(test_roll)
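
The spec sets out[i] = a[i + 1], so entries move one step to the left. In NumPy's sign convention that is a shift of -1. The comparison below uses made-up inputs and is a reference check, not a puzzle solution.

```python
import numpy as np

a = np.array([10, 20, 30, 40])

# Shift of -1: every entry moves left and a[0] wraps to the end, as in roll_spec.
left = np.roll(a, -1)
print(left.tolist())    # [20, 30, 40, 10]

# A positive shift goes the other way and does not match the spec.
right = np.roll(a, 1)
print(right.tolist())   # [40, 10, 20, 30]
```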

Puzzle 11 - flip

Compute flip - the reversed vector.

def flip_spec(a, out):
    for i in range(len(out)):
        out[i] = a[len(out) - i - 1]
        
def flip(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError


test_flip = make_test("flip", flip, flip_spec, add_sizes=["i"])

# run_test(test_flip)

Puzzle 12 - compress

Compute compress - keep only masked entries (left-aligned).

def compress_spec(g, v, out):
    j = 0
    for i in range(len(g)):
        if g[i]:
            out[j] = v[i]
            j += 1
            
def compress(g: TT["i", bool], v: TT["i"], i:int) -> TT["i"]:
    raise NotImplementedError


test_compress = make_test("compress", compress, compress_spec, add_sizes=["i"])

# run_test(test_compress)
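
compress_spec is close to boolean-mask indexing, except that the kept entries are written left-aligned into an output of the original length. The sketch below assumes the output buffer starts as zeros. The spec never writes the unused tail, so that part depends on the test harness. NumPy indexing is used here only as a reference.

```python
import numpy as np

g = np.array([False, True, False, True, True])
v = np.array([7, 8, 9, 10, 11])

# Boolean indexing keeps the masked entries in order but shortens the vector.
kept = v[g]
print(kept.tolist())   # [8, 10, 11]

# compress_spec writes them left-aligned into a length-i output instead;
# assuming a zero-initialized output, the tail stays 0.
out = np.zeros_like(v)
out[: len(kept)] = kept
print(out.tolist())    # [8, 10, 11, 0, 0]
```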

Puzzle 13 - pad_to

Compute pad_to - truncate the vector, or pad it with 0s, to length j.

def pad_to_spec(a, out):
    for i in range(min(len(out), len(a))):
        out[i] = a[i]


def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
    raise NotImplementedError


test_pad_to = make_test("pad_to", pad_to, pad_to_spec, add_sizes=["i", "j"])

# run_test(test_pad_to)
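
Read the same way, pad_to_spec keeps the first min(i, j) entries and leaves the rest as zeros. This again assumes a zero-initialized output. Below is a NumPy reference sketch; pad_to_reference is a hypothetical name, and its library calls are not allowed in the puzzle itself.

```python
import numpy as np

def pad_to_reference(a, j):
    # Hypothetical reference: keep at most j entries, then zero-pad up to j.
    head = a[:j]
    return np.pad(head, (0, j - len(head)))

a = np.array([1, 2, 3])
print(pad_to_reference(a, 5).tolist())   # [1, 2, 3, 0, 0]
print(pad_to_reference(a, 2).tolist())   # [1, 2]
```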

Puzzle 14 - sequence_mask

Compute sequence_mask - keep the first length[i] entries of each row i and zero out the rest.

def sequence_mask_spec(values, length, out):
    for i in range(len(out)):
        for j in range(len(out[0])):
            if j < length[i]:
                out[i][j] = values[i][j]
            else:
                out[i][j] = 0
    
def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
    raise NotImplementedError


def constraint_set_length(d):
    d["length"] = d["length"] % d["values"].shape[1]
    return d


test_sequence = make_test("sequence_mask",
    sequence_mask, sequence_mask_spec, constraint=constraint_set_length
)

# run_test(test_sequence)

Puzzle 15 - bincount

Compute bincount - count the number of times each value appears.

def bincount_spec(a, out):
    for i in range(len(a)):
        out[a[i]] += 1
        
def bincount(a: TT["i"], j: int) -> TT["j"]:
    raise NotImplementedError


def constraint_set_max(d):
    d["a"] = d["a"] % d["return"].shape[0]
    return d


test_bincount = make_test("bincount",
    bincount, bincount_spec, add_sizes=["j"], constraint=constraint_set_max
)

# run_test(test_bincount)
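
For comparison, NumPy's np.bincount computes the same counts. Its minlength argument forces at least j bins. The test constraint keeps every entry of a below j, so the result has exactly j entries. This is a reference only, with made-up inputs.

```python
import numpy as np

a = np.array([0, 2, 2, 3, 0, 2])
j = 6

# Count occurrences of each value 0..j-1; minlength pads unseen values with 0.
counts = np.bincount(a, minlength=j)
print(counts.tolist())   # [2, 0, 3, 1, 0, 0]
```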

Puzzle 16 - scatter_add

Compute scatter_add - add together values that link to the same location.

def scatter_add_spec(values, link, out):
    for j in range(len(values)):
        out[link[j]] += values[j]
        
def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
    raise NotImplementedError


def constraint_set_max(d):
    d["link"] = d["link"] % d["return"].shape[0]
    return d


test_scatter_add = make_test("scatter_add",
    scatter_add, scatter_add_spec, add_sizes=["j"], constraint=constraint_set_max
)

# run_test(test_scatter_add)
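
scatter_add is the general case of bincount. It has a classic pitfall when checked against a library: a buffered fancy-index update such as out[link] += values does not accumulate repeated indices. NumPy's np.add.at performs the unbuffered version that matches scatter_add_spec. PyTorch offers Tensor.index_add_ and Tensor.scatter_add_ for the same job. The reference sketch below uses made-up inputs.

```python
import numpy as np

values = np.array([1, 2, 3, 4])
link = np.array([0, 2, 0, 2])
j = 3

# Buffered fancy-index update: repeated indices are not accumulated,
# so each slot keeps only one of the writes.
naive = np.zeros(j, dtype=int)
naive[link] += values
print(naive.tolist())   # [3, 0, 4]

# Unbuffered add: every value is added to its linked slot.
out = np.zeros(j, dtype=int)
np.add.at(out, link, values)
print(out.tolist())     # [4, 0, 6]
```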

Puzzle 17 - flatten

Compute flatten - the matrix laid out row by row as a single vector.

def flatten_spec(a, out):
    k = 0
    for i in range(len(a)):
        for j in range(len(a[0])):
            out[k] = a[i][j]
            k += 1

def flatten(a: TT["i", "j"], i:int, j:int) -> TT["i * j"]:
    raise NotImplementedError

test_flatten = make_test("flatten", flatten, flatten_spec, add_sizes=["i", "j"])

# run_test(test_flatten)

Puzzle 18 - linspace

Compute linspace - n evenly spaced values from i to j, both endpoints included.

def linspace_spec(i, j, out):
    for k in range(len(out)):
        out[k] = float(i + (j - i) * k / max(1, len(out) - 1))

def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
    raise NotImplementedError

test_linspace = make_test("linspace", linspace, linspace_spec, add_sizes=["n"])

# run_test(test_linspace)
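
The spec divides by max(1, n - 1). As a result, both endpoints are included and n = 1 yields just i, which matches NumPy's default np.linspace behavior. This is a reference check only.

```python
import numpy as np

# Five points: both endpoints included.
five = np.linspace(0.0, 1.0, 5)
print(five.tolist())   # [0.0, 0.25, 0.5, 0.75, 1.0]

# A single point collapses to the start value, as in linspace_spec.
one = np.linspace(3.0, 7.0, 1)
print(one.tolist())    # [3.0]
```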

Puzzle 19 - heaviside

Compute heaviside - the step function of a, taking the value b where a is 0.

def heaviside_spec(a, b, out):
    for k in range(len(out)):
        if a[k] == 0:
            out[k] = b[k]
        else:
            out[k] = int(a[k] > 0)

def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
    raise NotImplementedError

test_heaviside = make_test("heaviside", heaviside, heaviside_spec)

# run_test(test_heaviside)
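
The case split in heaviside_spec is the standard Heaviside step function, with b supplying the value at zero. This is the same convention as NumPy's np.heaviside(x1, x2). The reference check below uses made-up inputs.

```python
import numpy as np

a = np.array([-2.0, 0.0, 3.0, 0.0])
b = np.array([9.0, 0.5, 9.0, 1.0])

# 0 where a < 0, b where a == 0, and 1 where a > 0.
h = np.heaviside(a, b)
print(h.tolist())   # [0.0, 0.5, 1.0, 1.0]
```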

Puzzle 20 - repeat (1d)

Compute repeat - the vector repeated as d identical rows.

def repeat_spec(a, d, out):
    for i in range(d[0]):
        for k in range(len(a)):
            out[i][k] = a[k]

def constraint_set(d):
    d["d"][0] = d["return"].shape[0]
    return d

            
def repeat(a: TT["i"], d: TT[1]) -> TT["d", "i"]:
    raise NotImplementedError

test_repeat = make_test("repeat", repeat, repeat_spec, constraint=constraint_set)

Puzzle 21 - bucketize

Compute bucketize (https://pytorch.org/docs/stable/generated/torch.bucketize.html) - the index of the bucket each value falls into, given sorted boundaries.

def bucketize_spec(v, boundaries, out):
    for i, val in enumerate(v):
        out[i] = 0
        for j in range(len(boundaries)-1):
            if val >= boundaries[j]:
                out[i] = j + 1
        if val >= boundaries[-1]:
            out[i] = len(boundaries)


def constraint_set(d):
    d["boundaries"] = np.abs(d["boundaries"]).cumsum()
    return d

            
def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
    raise NotImplementedError

test_bucketize = make_test("bucketize", bucketize, bucketize_spec,
                           constraint=constraint_set)
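
bucketize_spec counts how many boundaries are less than or equal to each value. In other words, it returns k with boundaries[k - 1] <= v < boundaries[k]. For sorted boundaries, that is np.searchsorted(..., side="right"). It should also correspond to torch.bucketize(v, boundaries, right=True). Below is a NumPy reference check with made-up inputs.

```python
import numpy as np

# Sorted (non-decreasing) boundaries, as the test constraint's cumsum ensures.
boundaries = np.array([1, 3, 3, 6])
v = np.array([0, 1, 2, 3, 5, 6, 9])

# side="right" returns, for each value, the number of boundaries <= value.
idx = np.searchsorted(boundaries, v, side="right")
print(idx.tolist())   # [0, 1, 1, 3, 3, 4, 4]
```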


Speed Run Mode!

What is the smallest you can make each of these?

import inspect
fns = (ones, sum, outer, diag, eye, triu, cumsum, diff, vstack, roll, flip,
       compress, pad_to, sequence_mask, bincount, scatter_add)

for fn in fns:
    lines = [l for l in inspect.getsource(fn).split("\n") if not l.strip().startswith("#")]
    
    if len(lines) > 3:
        print(fn.__name__, len(lines[2]), "(more than 1 line)")
    else:
        print(fn.__name__, len(lines[1]))
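
Output with the unsolved stubs above (each body is the 29-character line "    raise NotImplementedError"):

ones 29
sum 29
outer 29
diag 29
eye 29
triu 29
cumsum 29
diff 29
vstack 29
roll 29
flip 29
compress 29
pad_to 29
sequence_mask 29
bincount 29
scatter_add 29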
