Repository Details

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.

TorchEval

This library is in Alpha and does not yet have a stable release. The API may change and may not be backward compatible. If you have suggestions for improvements, please open a GitHub issue. We'd love to hear your feedback.

Installing TorchEval

Requires Python >= 3.8 and PyTorch >= 1.11

From pip:

pip install torcheval

For the nightly build:

pip install --pre torcheval-nightly

From source:

git clone https://github.com/pytorch/torcheval
cd torcheval
pip install -r requirements.txt
python setup.py install
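
Before installing, it can help to confirm that the environment meets the version requirements stated above (Python >= 3.8 and PyTorch >= 1.11). The following is a minimal sketch, not part of TorchEval: the helper names `parse_version` and `meets_minimum` are hypothetical, and the parsing only handles plain dotted versions with an optional local suffix such as `+cu113`.

```python
import sys


def parse_version(version: str) -> tuple:
    """Return the leading integer components, e.g. '1.11.0+cu113' -> (1, 11, 0)."""
    parts = []
    # Drop any local build suffix ("+cu113") before splitting on dots.
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Hypothetical helper: compare a version string against a minimum tuple."""
    return parse_version(version) >= minimum


print("Python OK:", sys.version_info >= (3, 8))
try:
    import torch

    print("PyTorch OK:", meets_minimum(torch.__version__, (1, 11)))
except ImportError:
    print("PyTorch is not installed")
```

If either check reports a problem, upgrade Python or PyTorch first and then rerun the install commands.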

Quick Start

Take a look at the quickstart notebook, or fork it on Colab.

There are more examples in the examples directory:

cd torcheval
python examples/simple_example.py

Documentation

Documentation can be found at pytorch.org/torcheval

Using TorchEval

TorchEval can be run on CPU, GPU, and in a multi-process or multi-GPU setting. Metrics are provided through two interfaces: functional and class-based. The functional interfaces are found in torcheval.metrics.functional and are useful when your program runs in a single process. For multi-process or multi-GPU configurations, the class-based interfaces, found in torcheval.metrics, provide a much simpler experience. The class-based interfaces also let you defer some of the metric computation by calling update() multiple times before compute(). This can be advantageous even in a single-process setting because it reduces computation overhead.
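
To make the update()/compute() split concrete, here is a minimal pure-Python sketch of the pattern. It is an illustration only and not TorchEval's implementation: `RunningAccuracy` and `accuracy` are hypothetical names, and plain lists stand in for tensors.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a class-based metric (not TorchEval code)."""

    def __init__(self):
        self.reset()

    def update(self, predictions, targets):
        # Cheap step: only accumulate counters, no metric value yet.
        self.num_correct += sum(p == t for p, t in zip(predictions, targets))
        self.num_total += len(targets)
        return self

    def compute(self):
        # Deferred step: turn the accumulated counters into the metric value.
        if self.num_total == 0:
            return 0.0
        return self.num_correct / self.num_total

    def reset(self):
        # Discard accumulated state so later calls only see new data.
        self.num_correct = 0
        self.num_total = 0
        return self


def accuracy(predictions, targets):
    """Functional counterpart: computes immediately from a single batch."""
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


batches = [
    ([0, 1, 2, 2], [0, 1, 1, 2]),  # 3 of 4 correct
    ([1, 1, 0, 2], [1, 0, 0, 2]),  # 3 of 4 correct
    ([2, 0, 1, 1], [2, 0, 1, 1]),  # 4 of 4 correct
]

metric = RunningAccuracy()
for predictions, targets in batches:
    metric.update(predictions, targets)

deferred = metric.compute()  # one computation over all 12 samples
last_batch_only = accuracy(*batches[-1])  # what a single functional call sees
print(deferred, last_batch_only)
```

The deferred value reflects every batch passed to update(), while the functional call only reflects the batch it was given.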

Single Process

In a single-process program, the simplest approach is to use a functional metric: import the metric function and pass it the model outputs and targets. The example below shows a minimal PyTorch training loop that evaluates multiclass accuracy on every fourth batch of data.

Functional Version (immediate computation of metric)

import torch
from torcheval.metrics.functional import multiclass_accuracy

NUM_BATCHES = 16
BATCH_SIZE = 8
INPUT_SIZE = 10
NUM_CLASSES = 6
eval_frequency = 4

model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, NUM_CLASSES), torch.nn.ReLU())
optim = torch.optim.Adagrad(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

metric_history = []
for batch in range(NUM_BATCHES):
    input = torch.rand(size=(BATCH_SIZE, INPUT_SIZE))
    target = torch.randint(size=(BATCH_SIZE,), high=NUM_CLASSES)
    outputs = model(input)

    loss = loss_fn(outputs, target)
    optim.zero_grad()
    loss.backward()
    optim.step()

    # metric only computed every 4 batches,
    # data from previous three batches is lost
    if (batch + 1) % eval_frequency == 0:
        metric_history.append(multiclass_accuracy(outputs, target))

Single Process with Deferred Computation

Class Version (enables deferred computation of metric)

import torch
from torcheval.metrics import MulticlassAccuracy

NUM_BATCHES = 16
BATCH_SIZE = 8
INPUT_SIZE = 10
NUM_CLASSES = 6
eval_frequency = 4

model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, NUM_CLASSES), torch.nn.ReLU())
optim = torch.optim.Adagrad(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
metric = MulticlassAccuracy()

metric_history = []
for batch in range(NUM_BATCHES):
    input = torch.rand(size=(BATCH_SIZE, INPUT_SIZE))
    target = torch.randint(size=(BATCH_SIZE,), high=NUM_CLASSES)
    outputs = model(input)

    loss = loss_fn(outputs, target)
    optim.zero_grad()
    loss.backward()
    optim.step()

    # metric state is updated on every batch but only computed every 4 batches,
    # so data from the previous three batches is included
    metric.update(outputs, target)
    if (batch + 1) % eval_frequency == 0:
        metric_history.append(metric.compute())
        # remove old data so that the next call
        # to compute is only based off next 4 batches
        metric.reset()

Multi-Process or Multi-GPU

A minimal example for usage on multiple devices is given below. In the normal torch.distributed paradigm, each device is allocated its own process, and each process gets a unique numerical ID called a "global rank", counting up from 0.

Class Version (enables deferred computation and multi-processing)

import os

import torch
from torcheval.metrics.toolkit import sync_and_compute
from torcheval.metrics import MulticlassAccuracy

# Using torch.distributed
local_rank = int(os.environ["LOCAL_RANK"])  # rank on the local machine, i.e. unique ID within a machine
global_rank = int(os.environ["RANK"])  # rank in the global pool, i.e. unique ID within the entire process group
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes or "ranks" in the entire process group

device = torch.device(
    f"cuda:{local_rank}"
    if torch.cuda.is_available() and torch.cuda.device_count() >= world_size
    else "cpu"
)

metric = MulticlassAccuracy(device=device)
num_epochs, num_batches = 4, 8

for epoch in range(num_epochs):
    for i in range(num_batches):
        input = torch.randint(high=5, size=(10,), device=device)
        target = torch.randint(high=5, size=(10,), device=device)

        # Add data to metric locally
        metric.update(input, target)

        # metric.compute() returns the metric value from
        # all data seen on the local process since the last reset()
        local_compute_result = metric.compute()

        # sync_and_compute(metric) syncs metric data across all ranks and computes the metric value
        global_compute_result = sync_and_compute(metric)
        if global_rank == 0:
            print(global_compute_result)

    # metric.reset() clears the data on each process so that subsequent
    # calls to compute() only act on new data
    metric.reset()
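
Conceptually, a synchronized result comes from combining the metric state of every rank and computing once on the merged state, which is not the same as averaging the per-rank values. The sketch below simulates two ranks in a single process with plain Python. It is an assumption-laden illustration, not how sync_and_compute is implemented: `AccuracyState` and `merge_and_compute` are hypothetical names, and no torch.distributed communication takes place.

```python
from dataclasses import dataclass


@dataclass
class AccuracyState:
    """Hypothetical per-rank metric state: two plain counters."""

    num_correct: int = 0
    num_total: int = 0

    def update(self, predictions, targets):
        self.num_correct += sum(p == t for p, t in zip(predictions, targets))
        self.num_total += len(targets)

    def compute(self):
        return self.num_correct / self.num_total if self.num_total else 0.0


def merge_and_compute(states):
    """Combine the counters of all ranks, then compute once on the result."""
    merged = AccuracyState()
    for state in states:
        merged.num_correct += state.num_correct
        merged.num_total += state.num_total
    return merged.compute()


# Simulate a world of two ranks, each seeing different data.
rank_data = {
    0: ([0, 1, 1, 0], [0, 1, 0, 0]),  # 3 of 4 correct
    1: ([1, 1], [0, 0]),  # 0 of 2 correct
}
states = {}
for rank, (predictions, targets) in rank_data.items():
    states[rank] = AccuracyState()
    states[rank].update(predictions, targets)

local_results = {rank: state.compute() for rank, state in states.items()}
global_result = merge_and_compute(states.values())
print(local_results, global_result)
```

Because the ranks hold different numbers of samples, the merged result weights every sample equally, whereas a plain average of the local results would over-weight the smaller rank.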

See the examples directory for more examples.

Contributing

We welcome PRs! See the CONTRIBUTING file.

License

TorchEval is BSD licensed, as found in the LICENSE file.
