• Stars: 2,241
  • Rank: 20,558 (Top 0.5 %)
  • Language: Python
  • License: MIT License
  • Created almost 3 years ago
  • Updated about 2 months ago


A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.


TorchRL

Documentation | TensorDict | Features | Examples, tutorials and demos | Citation | Installation | Asking a question | Contributing

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

It provides PyTorch- and Python-first, low- and high-level abstractions for RL that are intended to be efficient, modular, documented and properly tested. The code is aimed at supporting research in RL. Most of it is written in Python in a highly modular way, so that researchers can easily swap components, transform them or write new ones with little effort.

This repo attempts to align with the existing PyTorch ecosystem libraries in that it has a dataset pillar (torchrl/envs), transforms, models, data utilities (e.g. collectors and containers), etc. TorchRL aims to have as few dependencies as possible (the Python standard library, NumPy and PyTorch). Common environment libraries (e.g. OpenAI Gym) are optional.

On the low-level end, TorchRL comes with a set of highly reusable functionals for cost functions, returns and data processing.

TorchRL aims at (1) high modularity and (2) good runtime performance. Read the full paper for a more detailed description of the library.

Documentation and knowledge base

The TorchRL documentation can be found here. It contains tutorials and the API reference.

TorchRL also provides an RL knowledge base to help you debug your code, or simply learn the basics of RL. Check it out here.

We have some introductory videos to help you get to know the library better. Check them out:

Writing simplified and portable RL codebase with TensorDict

RL algorithms are very heterogeneous, and it can be hard to recycle a codebase across settings (e.g. from online to offline, or from state-based to pixel-based learning). TorchRL solves this problem through TensorDict, a convenient data structure(1) that can be used to streamline one's RL codebase. With this tool, one can write a complete PPO training script in fewer than 100 lines of code!

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
model = TensorDictModule(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 2),
        NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"]
)
critic = ValueOperator(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
actor = ProbabilisticActor(
    model,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    # recent TorchRL versions may expect "low"/"high" instead of "min"/"max"
    distribution_kwargs={"min": -1.0, "max": 1.0},
    return_log_prob=True,
)
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
)
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=1000,
    total_frames=1_000_000,
)
loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data.view(-1))
        for i in range(20):  # consume data
            sample = buffer.sample(50)  # mini-batch
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items() if key.startswith("loss")
            )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
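For readers curious about what the GAE module computes in the loop above, here is a minimal pure-Python sketch of generalized advantage estimation over a single trajectory, using the recursion delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lmbda * (1 - d_t) * A_{t+1}. It is not TorchRL's implementation, which is vectorized and operates on tensordicts. The helper name gae_sketch and its list-based inputs are hypothetical. Every done flag is treated as a true termination (recent TorchRL versions distinguish terminated from truncated). The optional standardization step is our assumption of what average_gae=True requests.

```python
import math
from typing import List, Sequence


def gae_sketch(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.99,
    lmbda: float = 0.95,
    standardize: bool = False,
) -> List[float]:
    """Generalized advantage estimation for one trajectory (hypothetical helper)."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards: A_t = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error; bootstrapping is cut when the episode ends at step t.
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    if standardize:
        # Assumed equivalent of average_gae=True: zero mean, unit variance.
        mean = sum(advantages) / len(advantages)
        var = sum((a - mean) ** 2 for a in advantages) / len(advantages)
        std = math.sqrt(var)
        advantages = [(a - mean) / (std + 1e-8) for a in advantages]
    return advantages


if __name__ == "__main__":
    adv = gae_sketch(
        rewards=[1.0, 1.0, 1.0],
        values=[0.5, 0.5, 0.5],
        next_values=[0.5, 0.5, 0.0],
        dones=[False, False, True],
        gamma=0.9,
        lmbda=0.8,
    )
    print(adv)
```

Setting lmbda=0 reduces each advantage to the one-step TD error. Larger values of lmbda blend in longer-horizon returns at the cost of higher variance.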

Here is an example of how the environment API relies on TensorDict to carry data from one function to another during a rollout execution:

[Figure: data carried by a TensorDict through the functions of a rollout]

TensorDict makes it easy to re-use pieces of code across environments, models and algorithms.


For instance, here's how to code a rollout in TorchRL:

- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
    model,
    in_keys=["observation_pixels", "observation_vector"],
    out_keys=["action"],
)
out = []
for i in range(n_steps):
-     action, log_prob = policy(obs)
-     next_obs, reward, done, info = env.step(action)
-     out.append((obs, next_obs, action, log_prob, reward, done))
-     obs = next_obs
+     tensordict = policy(tensordict)
+     tensordict = env.step(tensordict)
+     out.append(tensordict)
+     tensordict = step_mdp(tensordict)  # makes the next observation(s) the current one(s) for the following step
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0)  # TensorDict supports multiple tensor operations
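The same pattern can be shown without TorchRL. The sketch below replays the loop with plain Python dicts standing in for TensorDict: one record flows from the policy (which writes "action") to the environment (which writes the "next" entries), is stored, and is then turned into the starting record of the next step. The toy environment CountingEnv, the policy function and step_mdp_sketch are hypothetical illustrations. The real step_mdp handles nested keys, done/terminated flags and key exclusions in ways this sketch does not attempt to reproduce.

```python
from typing import Any, Dict, List

Record = Dict[str, Any]


class CountingEnv:
    """Toy environment (hypothetical): the observation is a step counter."""

    def __init__(self, max_steps: int = 5) -> None:
        self.max_steps = max_steps
        self.t = 0

    def reset(self) -> Record:
        self.t = 0
        return {"observation": 0, "done": False}

    def step(self, record: Record) -> Record:
        # Read the action written by the policy and write the results under "next".
        self.t += 1
        record["next"] = {
            "observation": self.t,
            "reward": float(record["action"]),
            "done": self.t >= self.max_steps,
        }
        return record


def policy(record: Record) -> Record:
    # Reads "observation" and writes "action" into the same record.
    record["action"] = record["observation"] % 2
    return record


def step_mdp_sketch(record: Record) -> Record:
    # Build a fresh record from the "next" entries, so stored records stay untouched.
    nxt = record["next"]
    return {"observation": nxt["observation"], "done": nxt["done"]}


def rollout(env: CountingEnv, n_steps: int) -> List[Record]:
    out: List[Record] = []
    record = env.reset()
    for _ in range(n_steps):
        record = policy(record)
        record = env.step(record)
        out.append(record)
        if record["next"]["done"]:
            break
        record = step_mdp_sketch(record)
    return out


if __name__ == "__main__":
    for rec in rollout(CountingEnv(max_steps=3), n_steps=10):
        print(rec)
```

Because every function reads and writes named entries of one record, swapping the policy or the environment does not change the loop's signature. This is the property the TensorDict-based rollout relies on.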

Using this, TorchRL abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing all primitives to be easily recycled across settings.


Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):

- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
-     replay_buffer.add((obs, next_obs, action, hidden_state, reward, done))
+     replay_buffer.extend(tensordict)  # a collector yields a batch, hence extend rather than add
    for j in range(num_optim_steps):
-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+         tensordict = replay_buffer.sample(batch_size)
+         loss = loss_fn(tensordict)
        loss.backward()
        optim.step()
        optim.zero_grad()

This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
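The loop above only needs the buffer to accept a batch of records and to return a batch of records. Below is a minimal sketch of such a buffer. It assumes a fixed-capacity FIFO ring storage and sampling without replacement within each pass over the data. This is loosely in the spirit of combining LazyTensorStorage with SamplerWithoutReplacement, but it is not their implementation. The class CircularBufferSketch is hypothetical and stores plain dicts instead of tensordicts.

```python
import random
from typing import Any, Dict, List, Optional

Record = Dict[str, Any]


class CircularBufferSketch:
    """Fixed-capacity FIFO storage with sampling without replacement (hypothetical)."""

    def __init__(self, capacity: int, seed: Optional[int] = None) -> None:
        self.capacity = capacity
        self._storage: List[Record] = []
        self._cursor = 0  # next slot to overwrite once the storage is full
        self._rng = random.Random(seed)
        self._perm: List[int] = []  # indices not yet sampled in the current pass

    def __len__(self) -> int:
        return len(self._storage)

    def extend(self, records: List[Record]) -> None:
        for record in records:
            if len(self._storage) < self.capacity:
                self._storage.append(record)
            else:
                self._storage[self._cursor] = record  # overwrite the oldest entry
            self._cursor = (self._cursor + 1) % self.capacity
        self._perm = []  # contents changed: start a fresh sampling pass

    def sample(self, batch_size: int) -> List[Record]:
        if batch_size > len(self._storage):
            raise ValueError("batch_size exceeds the number of stored items")
        if len(self._perm) < batch_size:
            # Start a new shuffled pass when the current one cannot fill a batch.
            self._perm = list(range(len(self._storage)))
            self._rng.shuffle(self._perm)
        batch_idx, self._perm = self._perm[:batch_size], self._perm[batch_size:]
        return [self._storage[i] for i in batch_idx]


if __name__ == "__main__":
    buffer = CircularBufferSketch(capacity=4, seed=0)
    buffer.extend([{"step": i} for i in range(6)])  # steps 0 and 1 get overwritten
    print(buffer.sample(2), buffer.sample(2))
```

Dropping the leftover indices when they cannot fill a batch mimics a drop_last-style behaviour. It keeps every batch the same size at the cost of occasionally skipping a few items in a pass.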

TensorDict supports multiple tensor operations on its device and shape (the shape of a TensorDict, or its batch size, is the arbitrary number N of leading dimensions shared by all the tensors it contains):

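# each line below is an independent example
# stack and cat (list_of_tensordicts: TensorDicts with matching keys and shapes)
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)  # requires a 3-dimensional batch size
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location (.cuda() and .to() return a TensorDict, like tensors do)
tensordict = tensordict.cuda()
tensordict = tensordict.to("cuda:1")
tensordict.share_memory_()  # in-place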
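The batch-size rule stated above can be checked mechanically: a batch size is valid only if it is a prefix of every contained tensor's shape. For instance, source and target sequences of different lengths stored in a (sequence, batch, feature) layout share no leading dimension at all. The helpers below are hypothetical and operate on plain shape tuples rather than tensors. They only illustrate the rule; TensorDict performs its own checks.

```python
from typing import Dict, Sequence, Tuple


def max_batch_size(shapes: Dict[str, Sequence[int]]) -> Tuple[int, ...]:
    """Longest leading-dimension prefix shared by every entry (hypothetical helper)."""
    shape_list = [tuple(s) for s in shapes.values()]
    if not shape_list:
        return ()
    prefix = []
    for dims in zip(*shape_list):  # zip stops at the shortest shape
        if all(d == dims[0] for d in dims):
            prefix.append(dims[0])
        else:
            break
    return tuple(prefix)


def is_valid_batch_size(shapes: Dict[str, Sequence[int]], batch_size: Sequence[int]) -> bool:
    """A batch size is valid if it is a prefix of every entry's shape."""
    n = len(batch_size)
    return all(tuple(s[:n]) == tuple(batch_size) for s in shapes.values())


if __name__ == "__main__":
    rollout_shapes = {"observation": (4, 20, 3), "reward": (4, 20, 1)}
    print(max_batch_size(rollout_shapes))  # 4 envs x 20 steps
    transformer_shapes = {"src": (10, 32, 512), "tgt": (20, 32, 512)}
    print(is_valid_batch_size(transformer_shapes, [20, 32]))
```

Any prefix of the value returned by max_batch_size, including the empty one, is a valid batch size.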

TensorDict comes with a dedicated tensordict.nn module that contains everything you might need to write your model with it. And it is functorch and torch.compile compatible!

transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[])  # src (10, ...) and tgt (20, ...) share no leading dim
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]

The TensorDictSequential class allows one to chain sequences of nn.Module instances in a highly modular way. For instance, here is an implementation of a transformer using the encoder and decoder blocks:

encoder_module = TransformerEncoder(...)
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]

TensorDictSequential allows one to isolate subgraphs by querying a set of desired input / output keys:

transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
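The key bookkeeping in these snippets follows a simple rule that can be sketched in plain Python. A sequence's input keys are those read before any module writes them, and its output keys are all the keys written. A sub-sequence producing given output keys can be found by walking backwards through the modules. KeyedModule and the three helpers are hypothetical. They illustrate the rule on the encoder/decoder example above and are not TensorDictSequential's code. Only selection by output keys is covered.

```python
from dataclasses import dataclass
from typing import List, Sequence, Set


@dataclass
class KeyedModule:
    """Stand-in for a module that reads in_keys and writes out_keys (hypothetical)."""

    name: str
    in_keys: List[str]
    out_keys: List[str]


def sequence_in_keys(modules: Sequence[KeyedModule]) -> List[str]:
    # Keys that must come from outside: read before any earlier module writes them.
    written: Set[str] = set()
    needed: List[str] = []
    for m in modules:
        for k in m.in_keys:
            if k not in written and k not in needed:
                needed.append(k)
        written.update(m.out_keys)
    return needed


def sequence_out_keys(modules: Sequence[KeyedModule]) -> List[str]:
    # Every key written by some module, in order of first appearance.
    out: List[str] = []
    for m in modules:
        for k in m.out_keys:
            if k not in out:
                out.append(k)
    return out


def select_by_out_keys(modules: Sequence[KeyedModule], out_keys: Sequence[str]) -> List[KeyedModule]:
    # Walk backwards, keeping only modules that produce a key that is still needed.
    needed = set(out_keys)
    kept: List[KeyedModule] = []
    for m in reversed(modules):
        if needed & set(m.out_keys):
            kept.append(m)
            needed.difference_update(m.out_keys)  # now produced...
            needed.update(m.in_keys)  # ...but its own inputs are needed
    return list(reversed(kept))


if __name__ == "__main__":
    encoder = KeyedModule("encoder", ["src", "src_mask"], ["memory"])
    decoder = KeyedModule("decoder", ["tgt", "memory"], ["output"])
    print(sequence_in_keys([encoder, decoder]), sequence_out_keys([encoder, decoder]))
    print([m.name for m in select_by_out_keys([encoder, decoder], ["memory"])])
```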

Check TensorDict tutorials to learn more!

Features

  • A common interface for environments which supports common libraries (OpenAI Gym, DeepMind Control Suite, etc.)(1) and stateless execution (e.g. model-based environments). The batched environment containers allow parallel execution(2). A common PyTorch-first tensor-specification class is also provided. TorchRL's environment API is simple but stringent and specific. Check the documentation and tutorial to learn more!

    env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
    env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
    tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
    assert tensordict.shape == torch.Size([4, 20])  # 4 envs, 20 steps rollout
    env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
  • multiprocess and distributed data collectors(2) that work synchronously or asynchronously. Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised learning (although the "dataloader", i.e. the data collector, is modified on the fly):

    env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
    collector = MultiaSyncDataCollector(
        [env_make, env_make],
        policy=policy,
        devices=["cuda:0", "cuda:0"],
        total_frames=10000,
        frames_per_batch=50,
        ...
    )
    for i, tensordict_data in enumerate(collector):
        loss = loss_module(tensordict_data)
        loss.backward()
        optim.step()
        optim.zero_grad()
        collector.update_policy_weights_()

    Check our distributed collector examples to learn more about ultra-fast data collection with TorchRL.

  • efficient(2) and generic(1) replay buffers with modularized storage:

    storage = LazyMemmapStorage(  # memory-mapped (physical) storage
        cfg.buffer_size,
        scratch_dir="/tmp/"
    )
    buffer = TensorDictPrioritizedReplayBuffer(
        alpha=0.7,
        beta=0.5,
        collate_fn=lambda x: x,
        pin_memory=device != torch.device("cpu"),
        prefetch=10,  # multi-threaded sampling
        storage=storage
    )

    Replay buffers are also offered as wrappers around common datasets for offline RL:

    from torchrl.data.replay_buffers import SamplerWithoutReplacement
    from torchrl.data.datasets.d4rl import D4RLExperienceReplay
    data = D4RLExperienceReplay(
        "maze2d-open-v0",
        split_trajs=True,
        batch_size=128,
        sampler=SamplerWithoutReplacement(drop_last=True),
    )
    for sample in data:  # or alternatively sample = data.sample()
        fun(sample)
  • cross-library environment transforms(1), executed on device and in a vectorized fashion(2), which process and prepare the data coming out of the environments to be used by the agent:

    env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
    env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
    env = TransformedEnv(
        env_base,
        Compose(
            ToTensorImage(),
            ObservationNorm(loc=0.5, scale=1.0)),  # executes the transforms once and on device
    )
    tensordict = env.reset()
    assert tensordict.device == torch.device("cuda:0")

    Other transforms include reward scaling (RewardScaling), shape operations (concatenation of tensors, unsqueezing, etc.), concatenation of successive frames (CatFrames), resizing (Resize) and many more.

    Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it easy to add and remove them at will:

    env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at the index 0

    Nevertheless, transforms can access and execute operations on the parent environment:

    transform = env.transform[1]  # retrieves the second transform of the list
    parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
  • various tools for distributed learning (e.g. memory mapped tensors)(2);

  • various architectures and models (e.g. actor-critic)(1):

    # create an nn.Module
    common_module = ConvNet(
        bias_last_layer=True,
        depth=None,
        num_cells=[32, 64, 64],
        kernel_sizes=[8, 4, 3],
        strides=[4, 2, 1],
    )
    # Wrap it in a SafeModule, indicating what key to read in and where to
    # write out the output
    common_module = SafeModule(
        common_module,
        in_keys=["pixels"],
        out_keys=["hidden"],
    )
    # Wrap the policy module in NormalParamsWrapper, such that the output
    # tensor is split in loc and scale, and scale is mapped onto a positive space
    policy_module = SafeModule(
        NormalParamsWrapper(
            MLP(num_cells=[64, 64], out_features=32, activation_class=nn.ELU)
        ),
        in_keys=["hidden"],
        out_keys=["loc", "scale"],
    )
    # Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
    # SafeProbabilisticModule, indicating how to build the
    # torch.distribution.Distribution object and what to do with it
    policy_module = SafeProbabilisticTensorDictSequential(  # stochastic policy
        policy_module,
        SafeProbabilisticModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=TanhNormal,
        ),
    )
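    # Wrap the value network so that it reads "hidden" and writes a state value
    value_module = ValueOperator(
        MLP(
            num_cells=[64, 64],
            out_features=1,
            activation_class=nn.ELU,
        ),
        in_keys=["hidden"],
    )
    # Wrap the policy and value function in a common module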
    actor_value = ActorValueOperator(common_module, policy_module, value_module)
    # standalone policy from this
    standalone_policy = actor_value.get_policy_operator()
  • exploration wrappers and modules to easily swap between exploration and exploitation(1):

    policy_explore = EGreedyWrapper(policy)
    with set_exploration_type(ExplorationType.RANDOM):
        tensordict = policy_explore(tensordict)  # will use eps-greedy
    with set_exploration_type(ExplorationType.MODE):
        tensordict = policy_explore(tensordict)  # will not use eps-greedy
  • A series of efficient loss modules and highly vectorized functional return and advantage computation.


    Loss modules

    from torchrl.objectives import DQNLoss
    loss_module = DQNLoss(value_network=value_network, gamma=0.99)
    tensordict = replay_buffer.sample(batch_size)
    loss = loss_module(tensordict)

    Advantage computation

    from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
    # TD(lambda) value target; subtracting the current state value yields the advantage
    value_target = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done)
  • a generic trainer class(1) that executes the aforementioned training loop. Through a hooking mechanism, it also supports any logging or data transformation operation at any given time.

  • various recipes to build models that correspond to the environment being deployed.
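The advantage bullet above refers to vectorized TD(lambda) estimators. The recursion they evaluate, G_t = r_t + gamma * (1 - d_t) * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1}), with the last step bootstrapped from V(s_T), can be written as a short backward loop. The sketch below is a pure-Python reference under the assumption that done flags cut bootstrapping. td_lambda_return_sketch is a hypothetical name and not part of torchrl.objectives.value.functional.

```python
from typing import List, Sequence


def td_lambda_return_sketch(
    gamma: float,
    lmbda: float,
    next_values: Sequence[float],
    rewards: Sequence[float],
    dones: Sequence[bool],
) -> List[float]:
    """Backward-recursive TD(lambda) returns for one trajectory (hypothetical helper)."""
    returns = [0.0] * len(rewards)
    g_next = next_values[-1]  # bootstrap beyond the last step with V(s_T)
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # Mix the one-step bootstrap V(s_{t+1}) with the longer return G_{t+1}.
        g = rewards[t] + gamma * not_done * ((1.0 - lmbda) * next_values[t] + lmbda * g_next)
        returns[t] = g
        g_next = g
    return returns


if __name__ == "__main__":
    print(td_lambda_return_sketch(0.5, 0.0, [2.0, 3.0, 4.0], [1.0, 1.0, 1.0], [False] * 3))
    print(td_lambda_return_sketch(0.5, 1.0, [2.0, 3.0, 4.0], [1.0, 1.0, 1.0], [False] * 3))
```

With lmbda=0 each entry is the one-step TD target, and with lmbda=1 it is the discounted return bootstrapped only at the end of the trajectory. Subtracting V(s_t) from these targets gives the corresponding advantage.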
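The difference between a flat list of transforms and nested wrappers, described in the transforms bullet above, can be illustrated with a small plain-Python sketch. Transforms are applied in list order, inserting one is a plain list operation, and the parent seen by transform i is the base environment followed by transforms 0 to i-1. TransformedEnvSketch is hypothetical. It only mimics that described behaviour on dicts and is not TorchRL's TransformedEnv.

```python
from typing import Any, Callable, Dict, List

Record = Dict[str, Any]
Transform = Callable[[Record], Record]


class TransformedEnvSketch:
    """Base env plus a flat, editable list of transforms (hypothetical)."""

    def __init__(self, base_reset: Callable[[], Record], transforms: List[Transform]) -> None:
        self.base_reset = base_reset
        self.transforms = list(transforms)

    def insert_transform(self, index: int, transform: Transform) -> None:
        # No re-wrapping needed: the transforms simply form a list.
        self.transforms.insert(index, transform)

    def parent(self, index: int) -> "TransformedEnvSketch":
        # The parent of transform `index` is the base env plus every transform before it.
        return TransformedEnvSketch(self.base_reset, self.transforms[:index])

    def reset(self) -> Record:
        record = self.base_reset()
        for transform in self.transforms:  # applied in list order
            record = transform(record)
        return record


if __name__ == "__main__":
    env = TransformedEnvSketch(
        lambda: {"pixels": 10.0},
        [
            lambda r: {**r, "pixels": r["pixels"] / 10.0},  # scale
            lambda r: {**r, "pixels": r["pixels"] - 0.5},  # shift
        ],
    )
    print(env.reset(), env.parent(1).reset())
```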

If you feel a feature is missing from the library, please submit an issue! If you would like to contribute to new features, check our call for contributions and our contribution page.

Examples, tutorials and demos

A series of illustrative examples is provided, and many more are to come!

Check the examples markdown directory for more details about handling the various configuration settings.

We also provide tutorials and demos that give a sense of what the library can do.

Citation

If you're using TorchRL, please refer to this BibTeX entry to cite this work:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch}, 
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Installation

Create a conda environment where the packages will be installed.

conda create --name torch_rl python=3.9
conda activate torch_rl

PyTorch

Depending on how you intend to use functorch, you may want to install either the latest (nightly) PyTorch release or the latest stable version of PyTorch. See here for a detailed list of commands, including pip3 and Windows/macOS-compatible installation commands.

TorchRL

You can install the latest stable release by using

pip3 install torchrl

This should work on Linux and macOS (not M1). On Windows and M1/M2 machines, the library should be installed locally (see below).

The nightly build can be installed via

pip install torchrl-nightly

To install extra dependencies, call

pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils]"

or a subset of these.

Alternatively, as the library is at an early stage, it may be wise to install it in develop mode, as this makes it possible to pull the latest changes and benefit from them immediately. Start by cloning the repo:

git clone https://github.com/pytorch/rl

Go to the directory where you cloned the torchrl repo and install it:

cd /path/to/torchrl/
pip install -e .

On M1 machines, this should work out of the box with the nightly build of PyTorch. If building this artifact on macOS M1 does not work correctly, or if the message (mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e')) appears at execution time, try:

ARCHFLAGS="-arch arm64" python setup.py develop

To run a quick sanity check, leave that directory (e.g. by executing cd ~/) and try to import the library.

python -c "import torchrl"

This should not return any warning or error.

Optional dependencies

The following libraries can be installed depending on the usage one wants to make of torchrl:

# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher

# rendering
pip3 install moviepy

# deepmind control suite
pip3 install dm_control

# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame

# tests
pip3 install pytest pyyaml pytest-instafail

# tensorboard
pip3 install tensorboard

# wandb
pip3 install wandb

Troubleshooting

If a "ModuleNotFoundError: No module named 'torchrl._torchrl'" error occurs, it means that the C++ extensions were not installed or not found. One common reason is that you are trying to import torchrl from within the git repo location. Indeed, the following code snippet should return an error if torchrl has not been installed in develop mode:

cd ~/path/to/rl/repo
python -c 'from torchrl.envs.libs.gym import GymEnv'

If this is the case, consider executing torchrl from another location.

On macOS, we recommend installing Xcode first. On Apple Silicon (M1) machines, make sure you are using an arm64 build of Python (e.g. here). Running the following lines of code

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
python collect_env.py

should display

OS: macOS *** (arm64)

and not

OS: macOS **** (x86_64)

Versioning issues can cause error messages such as "undefined symbol". For these, refer to the versioning issues document for a complete explanation and proposed workarounds.

Asking a question

If you spot a bug in the library, please raise an issue in this repo.

If you have a more generic question regarding RL in PyTorch, post it on the PyTorch forum.

Contributing

Internal collaborations on torchrl are welcome! Feel free to fork, submit issues and PRs. You can check out the detailed contribution guide here. As mentioned above, a list of open contributions can be found here.

Contributors are recommended to install pre-commit hooks (using pre-commit install). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending -n to your commit command: git commit -m <commit message> -n

Disclaimer

This library is released as a PyTorch beta feature. BC-breaking changes are likely to happen, but they will only be introduced after a deprecation period of a few release cycles.

License

TorchRL is licensed under the MIT License. See LICENSE for details.
