functorch is JAX-like composable function transforms for PyTorch.

functorch

Why functorch? | Install guide | Transformations | Documentation | Future Plans

This library is currently under heavy development. If you have suggestions on the API or use cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.

In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance.

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

  • computing per-sample-gradients (or other per-sample quantities)
  • running ensembles of models on a single machine
  • efficiently batching together tasks in the inner-loop of MAML
  • efficiently computing Jacobians and Hessians
  • efficiently computing batched Jacobians and Hessians

Composing vmap, grad, vjp, and jvp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.

Install

There are two ways to install functorch:

  1. functorch from source
  2. functorch beta (compatible with recent PyTorch releases)

We recommend trying out the functorch beta first.

Installing functorch from source

Using Colab

Follow the instructions in this Colab notebook

Locally

As of 9/21/2022, functorch comes installed alongside a nightly PyTorch binary. Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/ for instructions.

Once you've done that, run a quick sanity check in Python:

import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())

functorch development setup

As of 9/21/2022, functorch comes installed alongside PyTorch and is in the PyTorch source tree. Please install PyTorch from source, then, you will be able to import functorch.

Try to run some tests to make sure all is OK:

pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v

AOTAutograd has some additional optional requirements. You can install them via:

pip install networkx

To run functorch tests, please install our test dependencies (expecttest, pyyaml).

Installing functorch beta (compatible with recent PyTorch releases)

Using Colab

Follow the instructions here

pip

Prerequisite: Install PyTorch

pip install functorch

Finally, run a quick sanity check in Python:

import torch
from functorch import vmap
x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())

What are the transforms?

Right now, we support the following transforms:

  • grad, vjp, jvp,
  • jacrev, jacfwd, hessian
  • vmap

Furthermore, we have some utilities for working with PyTorch modules.

  • make_functional(model)
  • make_functional_with_buffers(model)

vmap

Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.

vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.

vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:

import torch
from functorch import vmap
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)
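
The mapped dimension does not have to be dimension 0 for every input. The following is a minimal sketch, assuming the in_dims argument behaves as it does in the per-sample-gradient example later in this README: one entry per input, with None meaning "do not map over this input". The function scaled_dot and the tensors xs and ws are hypothetical names chosen for illustration.

```python
import torch
from functorch import vmap

def scaled_dot(x, w, scale):
    # Written for a single example: x and w are 1-D, scale is a plain float.
    return scale * x.dot(w)

xs = torch.randn(4, 5)  # batch of 4 examples along dim 0
ws = torch.randn(5, 4)  # batch of 4 weight vectors along dim 1

# Map over dim 0 of xs and dim 1 of ws; scale is shared by every example.
out = vmap(scaled_dot, in_dims=(0, 1, None))(xs, ws, 2.0)

# Reference computed with explicit batching.
expected = 2.0 * (xs * ws.t()).sum(dim=1)
assert out.shape == (4,)
assert torch.allclose(out, expected, atol=1e-5)
```

Inside scaled_dot the code only ever sees 1-D tensors, which is what "hiding batch dimensions" means in practice.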

grad

grad(func)(*inputs) assumes func returns a single-element Tensor. It computes the gradient of the output of func with respect to inputs[0].

import torch
from functorch import grad
x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())

When composed with vmap, grad can be used to compute per-sample-gradients:

import torch
from functorch import grad, vmap
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
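
One way to convince yourself that vmap(grad(...)) really produces per-sample gradients is to compare it against an explicit Python loop over the batch. The sketch below reuses the model and loss from the example above; the names per_sample and looped are introduced here for illustration only.

```python
import torch
from functorch import grad, vmap

def model(weights, feature_vec):
    # Same single-example linear model as above.
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

torch.manual_seed(0)
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Vectorized: one gradient w.r.t. weights per (example, target) pair.
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: call grad once per example and stack the results.
looped = torch.stack([
    grad(compute_loss)(weights, examples[i], targets[i])
    for i in range(batch_size)
])

assert per_sample.shape == (batch_size, feature_size)
assert torch.allclose(per_sample, looped, atol=1e-6)
```

The loop computes the same values; the vmap version avoids the Python-level iteration.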

vjp

The vjp transform applies func to inputs and returns a new function that computes vjps given some cotangents Tensors.

from functorch import vjp
outputs, vjp_fn = vjp(func, *inputs)
vjps = vjp_fn(*cotangents)
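
As a concrete instance of the pattern above, here is a small sketch that uses torch.sin as func. It assumes vjp_fn returns a tuple with one entry per input, which is why the result is unpacked; the names cotangent and grad_x are illustrative.

```python
import torch
from functorch import vjp

x = torch.randn(5)

# vjp runs the function once and hands back a function for the backward pass.
outputs, vjp_fn = vjp(torch.sin, x)

# A cotangent of ones has the same shape as the output.
cotangent = torch.ones_like(outputs)
(grad_x,) = vjp_fn(cotangent)

assert torch.allclose(outputs, x.sin())
# The Jacobian of sin is diag(cos(x)), so ones @ J is cos(x).
assert torch.allclose(grad_x, x.cos())
```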

jvp

The jvp transform computes Jacobian-vector products and is also known as "forward-mode AD". Unlike most other transforms, it is not a higher-order function: it directly returns the outputs of func(inputs) as well as the jvps.

import torch
from functorch import jvp
x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)

jacrev, jacfwd, and hessian

jacrev(func) returns a new function that takes in x and returns the Jacobian of func with respect to x, computed using reverse-mode AD. In the example below, func is torch.sin.

import torch
from functorch import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)

jacrev can be composed with vmap to produce batched Jacobians:

from functorch import vmap
x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)

jacfwd is a drop-in replacement for jacrev that computes Jacobians using forward-mode AD:

from functorch import jacfwd
x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)

Composing jacrev with itself or with jacfwd can produce Hessians:

def f(x):
  return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)

hessian is a convenience function that combines jacfwd and jacrev:

from functorch import hessian

def f(x):
  return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
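
For this particular f, the Hessian is known in closed form: the function is a sum of sin(x_i), so the Hessian is diagonal with entries -sin(x_i). The sketch below checks that all three ways of building a Hessian shown above agree with that result; the name results is introduced here for illustration.

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)

# d^2/dx_i dx_j of sum(sin(x)) is -sin(x_i) when i == j and 0 otherwise.
expected = torch.diag(-x.sin())

results = [
    jacrev(jacrev(f))(x),  # reverse-over-reverse
    jacfwd(jacrev(f))(x),  # forward-over-reverse
    hessian(f)(x),         # convenience wrapper
]
for h in results:
    assert h.shape == (5, 5)
    assert torch.allclose(h, expected, atol=1e-6)
```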

Tracing through the transformations

We can also trace through these transformations in order to capture the results as new code using make_fx. There is also experimental integration with the NNC compiler (only works on CPU for now!).

import torch
from functorch import make_fx, grad
def f(x):
    return torch.sin(x).sum()
x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)

def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul

Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. This can happen for example in:

  • model ensembling, where all of your weights and buffers have an additional dimension
  • per-sample-gradient computation where you want to compute per-sample-grads of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function.

  • make_functional(model) returns a functional version of model and the model.parameters()
  • make_functional_with_buffers(model) returns a functional version of model and the model.parameters() and model.buffers().

Here's an example where we compute per-sample-gradients using an nn.Linear layer:

import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)

If you're making an ensemble of models, you may find combine_state_for_ensemble useful.
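
The following is a minimal sketch of model ensembling, assuming combine_state_for_ensemble takes a list of identically structured modules and returns a functional model together with parameters and buffers stacked along a new leading dimension. The layer sizes and the names models, minibatch, and ensemble_out are chosen for illustration.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 4
models = [torch.nn.Linear(3, 2) for _ in range(num_models)]
minibatch = torch.randn(8, 3)

# Stack the state of all models so each parameter gains a leading dim of size 4.
fmodel, params, buffers = combine_state_for_ensemble(models)

# Map over the stacked params and buffers; every model sees the same minibatch.
ensemble_out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

# Reference: run each model separately and stack the outputs.
expected = torch.stack([m(minibatch) for m in models])
assert ensemble_out.shape == (num_models, 8, 2)
assert torch.allclose(ensemble_out, expected, atol=1e-6)
```

Passing in_dims=(0, 0, 0) instead would give each model its own slice of the data, provided the data has a leading dimension of size num_models.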

Documentation

For more documentation, see our docs website.

Debugging

  • torch._C._functorch.dump_tensor: dumps the dispatch keys on the stack.
  • torch._C._functorch._set_vmap_fallback_warning_enabled(False): call this if the vmap warning spam bothers you.

Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help -- please send us your use cases by starting a conversation in the issue tracker or trying our project out.

License

functorch has a BSD-style license, as found in the LICENSE file.

Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

@Misc{functorch2021,
  author =       {Horace He, Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year =         {2021}
}
