
Enabling PyTorch on XLA Devices (e.g. Google TPU)

PyTorch/XLA


PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Take a look at one of our Kaggle notebooks to get started:

Getting Started

To install PyTorch/XLA on a new VM:

pip install torch~=2.0.0 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xm.xla_device())
+
+  # MpDeviceLoader preloads data to the XLA device
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
-    optimizer.step()
+
+    # `xm.optimizer_step` combines gradients across replicas
+    xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp

 def _mp_fn(rank, world_size):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='pjrt://')
+
+  model.to(xm.xla_device())
+  # `gradient_as_bucket_view=True` is required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = ddp_model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
     optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, GPU, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Available docker images and wheels

Wheel

| Version | Cloud TPU VMs Wheel |
| --- | --- |
| 2.0 (Python 3.8) | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
| nightly >= 2023/04/25 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
| nightly >= 2023/04/25 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl |

older versions:

| Version | Cloud TPU VMs Wheel |
| --- | --- |
| 1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
| 1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl |
| 1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl |
| 1.10 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp38-cp38-linux_x86_64.whl |
| nightly <= 2023/04/25 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |

Note: For TPU Pod customers using XRT (our legacy runtime), we have custom wheels for torch, torchvision, and torch_xla at https://storage.googleapis.com/tpu-pytorch/wheels/xrt.

| Package | Cloud TPU VMs Wheel (XRT on Pod, Legacy Only) |
| --- | --- |
| torch_xla | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
| torch | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch-2.0-cp38-cp38-linux_x86_64.whl |
| torchvision | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torchvision-2.0-cp38-cp38-linux_x86_64.whl |

| Version | GPU Wheel + Python 3.8 |
| --- | --- |
| 2.0 + CUDA 11.8 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
| 2.0 + CUDA 11.7 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-2.0-cp38-cp38-linux_x86_64.whl |
| 1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp38-cp38-linux_x86_64.whl |
| nightly + CUDA 11.7 <= 2023/04/25 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
| nightly + CUDA 11.7 >= 2023/04/25 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.7/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
| nightly + CUDA 11.8 <= 2023/04/25 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
| nightly + CUDA 11.8 >= 2023/04/25 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |

| Version | GPU Wheel + Python 3.7 |
| --- | --- |
| 1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl |
| 1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl |
| 1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl |
| nightly | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-nightly-cp37-cp37-linux_x86_64.whl |

| Version | Colab TPU Wheel |
| --- | --- |
| 2.0 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl |

You can also append +yyyymmdd after torch_xla-nightly to get the nightly wheel for a specific date. To get the companion torch or torchvision nightly wheel, replace torch_xla with torch or torchvision in the wheel links above.
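As a small illustration of the naming rule above, the helper below rewrites a generic nightly wheel URL to pin a build date. This is a sketch: `dated_nightly_url` is not part of torch_xla, just string manipulation following the stated convention.

```python
# Sketch: derive a date-pinned nightly wheel URL from the generic one.
# `dated_nightly_url` is an illustrative helper, not part of torch_xla.

NIGHTLY_URL = (
    "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/"
    "torch_xla-nightly-cp38-cp38-linux_x86_64.whl"
)

def dated_nightly_url(url: str, yyyymmdd: str, package: str = "torch_xla") -> str:
    """Append +yyyymmdd after '<package>-nightly' to pin a nightly build date."""
    generic = f"{package}-nightly"
    if generic not in url:
        raise ValueError(f"{url!r} does not look like a {package} nightly wheel URL")
    return url.replace(generic, f"{generic}+{yyyymmdd}", 1)

print(dated_nightly_url(NIGHTLY_URL, "20230501"))
# .../torch_xla-nightly+20230501-cp38-cp38-linux_x86_64.whl
```

The same transformation applies to the companion torch and torchvision wheels by passing a different `package` name.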

Installing libtpu (before PyTorch/XLA 2.0)

For PyTorch/XLA releases r2.0 and earlier, and when developing PyTorch/XLA, install the libtpu pip package with the following command:

pip3 install torch_xla[tpuvm]

This is only required on Cloud TPU VMs.

Docker

| Version | Cloud TPU VMs Docker |
| --- | --- |
| 2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm |
| 1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm |
| nightly python 3.10 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
| nightly python 3.8 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm |
| nightly python 3.10 (>= 2023/04/25) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_YYYYMMDD |
| nightly python 3.8 (>= 2023/04/25) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_YYYYMMDD |
| nightly at date (< 2023/04/25) | gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm_YYYYMMDD |

| Version | GPU CUDA 11.8 + Python 3.8 Docker |
| --- | --- |
| 2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8 |
| nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8 |
| nightly at date (>= 2023/04/25) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD |
| nightly at date (< 2023/04/25) | gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8_YYYYMMDD |

| Version | GPU CUDA 11.7 + Python 3.8 Docker |
| --- | --- |
| 2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7 |
| nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7 |
| nightly at date (>= 2023/04/25) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.7_YYYYMMDD |
| nightly at date (< 2023/04/25) | gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.7_YYYYMMDD |

| Version | GPU CUDA 11.2 + Python 3.8 Docker |
| --- | --- |
| 1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2 |

| Version | GPU CUDA 11.2 + Python 3.7 Docker |
| --- | --- |
| 1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2 |
| 1.12 | gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2 |

Use the CUDA images above to run on compute instances with GPUs.
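A minimal sketch of pulling one of the CUDA images listed above and starting an interactive container with GPU access (this assumes the NVIDIA container toolkit is installed on the host; the image tag is taken from the 2.0 row of the CUDA 11.8 table):

```shell
# Pull a CUDA-enabled PyTorch/XLA image (tag from the table above).
docker pull gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8

# Start an interactive shell with all host GPUs visible to the container.
docker run -it --gpus all gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8 /bin/bash
```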

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Facebook, and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to [email protected]. For questions directed at Google, please send an email to [email protected]. For all other questions, please open an issue in this repository.

Additional Reads

You can find additional useful reading materials in
