

TensorDict is a tensor container dedicated to PyTorch.


TensorDict

Installation | General features | Tensor-like features | Distributed capabilities | TensorDict for functional programming using FuncTorch | Lazy preallocation | Nesting TensorDicts | TensorClass

TensorDict is a dictionary-like class that inherits properties from tensors, such as indexing, shape operations, casting to device or point-to-point communication in distributed settings.

The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:

for i, tensordict in enumerate(dataset):
    # the model reads and writes tensordicts
    tensordict = model(tensordict)
    loss = loss_module(tensordict)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks. Each individual step of the training loop (data collection and transform, model prediction, loss computation etc.) can be tailored to the use case at hand without impacting the others. For instance, the above example can easily be reused across classification and segmentation tasks, among many others.

Features

General

A tensordict is primarily defined by its batch_size (or shape) and its key-value pairs:

>>> from tensordict import TensorDict
>>> import torch
>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])

The leading dimensions of every tensor must match the batch_size. The tensors can be of any dtype and device. Optionally, one can restrict a tensordict to a dedicated device, in which case every tensor written to it is moved to that device:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device == torch.device("cuda:0")

Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing a TensorDict is another TensorDict containing tensors indexed along the required dimension:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = tensordict[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])

Similarly, one can build tensordicts by stacking or concatenating single tensordicts:

>>> tensordicts = [TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4]) for _ in range(2)]
>>> stack_tensordict = torch.stack(tensordicts, 1)
>>> assert stack_tensordict.shape == torch.Size([3, 2, 4])
>>> assert stack_tensordict["key 1"].shape == torch.Size([3, 2, 4, 5])
>>> cat_tensordict = torch.cat(tensordicts, 0)
>>> assert cat_tensordict.shape == torch.Size([6, 4])
>>> assert cat_tensordict["key 1"].shape == torch.Size([6, 4, 5])

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(tensordict.view(-1).batch_size)
torch.Size([12])
>>> print(tensordict.reshape(-1).batch_size)
torch.Size([12])
>>> print(tensordict.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])

One can also send tensordicts from device to device, place them in shared memory, clone them, update them in-place or not, split them, unbind them, expand them etc.

If an operation is missing, it can usually be applied to every tensor of the tensordict using apply() or apply_():

tensordict_uniform = tensordict.apply(lambda tensor: tensor.uniform_())

Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings. tensordict solves that problem with synchronous and asynchronous helper methods such as recv, irecv, send and isend that behave like their torch.distributed counterparts:

>>> # on all workers
>>> data = TensorDict({"a": torch.zeros(()), ("b", "c"): torch.ones(())}, [])
>>> # on worker 1
>>> data.isend(dst=0)
>>> # on worker 0
>>> data.irecv(src=1)

When nodes share a common scratch space, the MemmapTensor backend can be used to seamlessly send, receive and read a huge amount of data.

TensorDict for functional programming using FuncTorch

We also provide an API to use TensorDict in conjunction with FuncTorch. For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:

>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import make_functional
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> # we represent the weights hierarchically
>>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(".")
>>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(".")
>>> params = make_functional(model)
>>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all()
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> out = model(x, params=params)  # params is the last arg (or kwarg)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> y = vmap(model, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])

Moreover, tensordict modules are compatible with torch.fx and torch.compile, which means that you can get the best of both worlds: a codebase that is both readable and future-proof as well as efficient and portable!

Lazy preallocation

Pre-allocating tensors can be cumbersome and hard to scale if the list of preallocated items varies according to the script configuration. TensorDict solves this in an elegant way. Assume you are working with a function foo() -> TensorDict, e.g.

def foo():
    tensordict = TensorDict({}, batch_size=[])
    tensordict["a"] = torch.randn(3)
    tensordict["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return tensordict

and you would like to call this function repeatedly. You could do this in two ways. The first would simply be to stack the calls to the function:

tensordict = torch.stack([foo() for _ in range(N)])

However, you could also choose to preallocate the tensordict:

tensordict = TensorDict({}, batch_size=[N])
for i in range(N):
    tensordict[i] = foo()

which also results in a tensordict (shown here for N = 10):

TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([10, 2]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

When i==0, your empty tensordict will automatically be populated with empty tensors of batch-size N. After that, updates will be written in-place. Note that this would also work with a shuffled series of indices (pre-allocation does not require you to go through the tensordict in an ordered fashion).

Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict should be indexable under the parent tensordict, i.e. its leading batch dimensions should match the parent batch size (its batch size may have additional trailing dimensions).

We can switch easily between hierarchical and flat representations. For instance, the following code will result in a single-level tensordict with keys "key 1" and "key 2.sub-key":

>>> tensordict = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = tensordict.flatten_keys(separator=".")

Accessing nested tensordicts can be achieved with a single index:

>>> sub_value = tensordict["key 2", "sub-key"]

TensorClass

Content flexibility comes at the cost of predictability. In some cases, developers may be looking for a data structure with more explicit behavior. tensordict provides a dataclass-like decorator that allows for the creation of custom dataclasses that support the tensordict operations:

>>> from tensordict.prototype import tensorclass
>>> import torch
>>>
>>> @tensorclass
... class MyData:
...    image: torch.Tensor
...    mask: torch.Tensor
...    label: torch.Tensor
...
...    def mask_image(self):
...        return self.image[self.mask.expand_as(self.image)].view(*self.batch_size, -1)
...
...    def select_label(self, label):
...        return self[self.label == label]
...
>>> images = torch.randn(100, 3, 64, 64)
>>> label = torch.randint(10, (100,))
>>> mask = torch.zeros(1, 64, 64, dtype=torch.bool).bernoulli_().expand(100, 1, 64, 64)
>>>
>>> data = MyData(images, mask, label=label, batch_size=[100])
>>>
>>> print(data.select_label(1))
MyData(
    image=Tensor(torch.Size([11, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([11]), dtype=torch.int64),
    mask=Tensor(torch.Size([11, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([11]),
    device=None,
    is_shared=False)
>>> print(data.mask_image().shape)
torch.Size([100, 6117])
>>> print(data.reshape(10, 10))
MyData(
    image=Tensor(torch.Size([10, 10, 3, 64, 64]), dtype=torch.float32),
    label=Tensor(torch.Size([10, 10]), dtype=torch.int64),
    mask=Tensor(torch.Size([10, 10, 1, 64, 64]), dtype=torch.bool),
    batch_size=torch.Size([10, 10]),
    device=None,
    is_shared=False)

As this example shows, one can write specific data structures with dedicated methods while still enjoying the TensorDict features such as shape operations (e.g. reshape or permutations), data manipulation (indexing, cat and stack) or calling arbitrary functions through the apply method (and many more).

Tensorclasses support nesting and, in fact, all the TensorDict features.

Installation

With Pip:

To install the latest stable version of tensordict, simply run

pip install tensordict

This will work with Python 3.7 and upward as well as PyTorch 1.12 and upward.

To enjoy the latest features, one can use

pip install tensordict-nightly

With Conda:

Install tensordict from the conda-forge channel:

conda install -c conda-forge tensordict

Citation

If you're using TensorDict, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Disclaimer

TensorDict is at the beta stage, meaning that backward-compatibility-breaking (bc-breaking) changes may be introduced, but they should come with a warranty. Hopefully these should not happen too often, as the current roadmap mostly involves adding new features and building compatibility with the broader PyTorch ecosystem.

License

TensorDict is licensed under the MIT License. See LICENSE for details.
