


tch-rs

Rust bindings for the C++ api of PyTorch. The goal of the tch crate is to provide some thin wrappers around the C++ PyTorch api (a.k.a. libtorch). It aims to stay as close as possible to the original C++ api. More idiomatic Rust bindings could then be developed on top of this. The documentation can be found on docs.rs.


The code generation part for the C api on top of libtorch comes from ocaml-torch.

Getting Started

This crate requires the C++ PyTorch library (libtorch) version v2.0.0 to be available on your system. You can:

  • Use the system-wide libtorch installation (default).
  • Install libtorch manually and let the build script know about it via the LIBTORCH environment variable.
  • Use a Python PyTorch install by setting LIBTORCH_USE_PYTORCH=1.
  • When a system-wide libtorch can't be found and LIBTORCH is not set, let the build script download a pre-built binary version of libtorch by enabling the download-libtorch feature. A CPU version is used by default; set the TORCH_CUDA_VERSION environment variable to cu117 to get a pre-built binary using CUDA 11.7.
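The options above amount to a precedence order for locating libtorch. The sketch below models one plausible order as a pure function over a snapshot of environment variables; the `LibtorchSource` enum, the `resolve_source` function and the exact precedence are assumptions made for illustration, not the actual torch-sys build script.

```rust
use std::collections::HashMap;
use std::path::{Path, PathBuf};

/// Where the libtorch used for linking comes from (illustrative only).
#[derive(Debug, PartialEq)]
enum LibtorchSource {
    /// LIBTORCH_USE_PYTORCH is set: ask the active Python interpreter.
    PythonPackage,
    /// LIBTORCH points at a manually installed libtorch.
    Manual(PathBuf),
    /// A libtorch.so found in a system library directory.
    SystemWide(PathBuf),
    /// Download a pre-built binary (download-libtorch feature); `cuda` is e.g. "cu117".
    Download { cuda: Option<String> },
    /// Nothing usable was found.
    NotFound,
}

/// Resolves the libtorch source. The precedence here is an assumption:
/// Python install, then LIBTORCH, then system-wide, then download.
fn resolve_source(
    env: &HashMap<String, String>,
    system_lib_dir: &Path,
    download_feature: bool,
) -> LibtorchSource {
    if env.contains_key("LIBTORCH_USE_PYTORCH") {
        return LibtorchSource::PythonPackage;
    }
    if let Some(dir) = env.get("LIBTORCH") {
        return LibtorchSource::Manual(PathBuf::from(dir));
    }
    let system_lib = system_lib_dir.join("libtorch.so");
    if system_lib.exists() {
        return LibtorchSource::SystemWide(system_lib);
    }
    if download_feature {
        // Without TORCH_CUDA_VERSION a CPU build would be fetched.
        return LibtorchSource::Download {
            cuda: env.get("TORCH_CUDA_VERSION").cloned(),
        };
    }
    LibtorchSource::NotFound
}
```

On Linux the `system_lib_dir` argument would be `/usr/lib`, matching the system-wide location described in the next section.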

System-wide Libtorch

On Linux platforms, the build script will look for a system-wide libtorch library in /usr/lib/libtorch.so.

Python PyTorch Install

If the LIBTORCH_USE_PYTORCH environment variable is set, the active Python interpreter is called to retrieve information about the torch Python package, and the build then links against that version.
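A rough idea of what calling the active Python interpreter can look like from Rust: run a small script that prints the torch version and the header and library paths reported by `torch.utils.cpp_extension`, then parse that output. `query_python`, `parse_torch_info` and the `version:` / `include:` / `lib:` line format are hypothetical, not what tch's build script actually prints.

```rust
use std::process::Command;

/// Runs `python -c <script>` and returns its stdout. The script is an
/// illustrative assumption built on torch.utils.cpp_extension.
fn query_python(python: &str) -> std::io::Result<String> {
    let script = "import torch\n\
                  from torch.utils import cpp_extension\n\
                  print('version:', torch.__version__)\n\
                  for p in cpp_extension.include_paths(): print('include:', p)\n\
                  for p in cpp_extension.library_paths(): print('lib:', p)\n";
    let output = Command::new(python).args(["-c", script]).output()?;
    if !output.status.success() {
        // Surface Python's error message (e.g. torch not installed).
        return Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            String::from_utf8_lossy(&output.stderr).into_owned(),
        ));
    }
    Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}

/// Information extracted from the script output above.
#[derive(Debug, Default, PartialEq)]
struct TorchInfo {
    version: Option<String>,
    include_dirs: Vec<String>,
    lib_dirs: Vec<String>,
}

/// Parses the hypothetical `key: value` lines printed by the script.
fn parse_torch_info(stdout: &str) -> TorchInfo {
    let mut info = TorchInfo::default();
    for line in stdout.lines() {
        if let Some(v) = line.strip_prefix("version: ") {
            info.version = Some(v.trim().to_string());
        } else if let Some(p) = line.strip_prefix("include: ") {
            info.include_dirs.push(p.trim().to_string());
        } else if let Some(p) = line.strip_prefix("lib: ") {
            info.lib_dirs.push(p.trim().to_string());
        }
    }
    info
}
```

Usage would be `parse_torch_info(&query_python("python3")?)`, with the lib directories then passed to the linker.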

Libtorch Manual Install

  • Get libtorch from the PyTorch website download section and extract the content of the zip file.
  • For Linux and macOS users, add the following to your .bashrc or equivalent, where /path/to/libtorch is the path to the directory that was created when unzipping the file.
export LIBTORCH=/path/to/libtorch

The location of the header files can also be specified separately from the shared libraries:

# LIBTORCH_INCLUDE must contain the `include` directory.
export LIBTORCH_INCLUDE=/path/to/libtorch/
# LIBTORCH_LIB must contain the `lib` directory.
export LIBTORCH_LIB=/path/to/libtorch/
  • For Windows users, assuming that X:\path\to\libtorch is the unzipped libtorch directory:

    • Navigate to Control Panel -> View advanced system settings -> Environment variables.
    • Create the LIBTORCH variable and set it to X:\path\to\libtorch.
    • Append X:\path\to\libtorch\lib to the Path variable.

    If you prefer to set the environment variables temporarily, you can run the following in PowerShell:

$Env:LIBTORCH = "X:\path\to\libtorch"
$Env:Path += ";X:\path\to\libtorch\lib"
  • You should now be able to run some examples, e.g. cargo run --example basics.
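To make the LIBTORCH / LIBTORCH_INCLUDE / LIBTORCH_LIB convention concrete, here is a small sketch of how the include and lib directories could be derived, with the more specific variables overriding LIBTORCH. Each variable names a directory that contains `include` (resp. `lib`). The `libtorch_dirs` helper is hypothetical; only the variable names come from this README.

```rust
use std::collections::HashMap;
use std::path::PathBuf;

/// Returns (include_dir, lib_dir), or None if neither the specific variable
/// nor LIBTORCH is available for one of them.
fn libtorch_dirs(env: &HashMap<&str, &str>) -> Option<(PathBuf, PathBuf)> {
    let base = env.get("LIBTORCH").copied();
    // The specific variables take priority; LIBTORCH is the fallback.
    let include_root = env.get("LIBTORCH_INCLUDE").copied().or(base)?;
    let lib_root = env.get("LIBTORCH_LIB").copied().or(base)?;
    Some((
        PathBuf::from(include_root).join("include"),
        PathBuf::from(lib_root).join("lib"),
    ))
}
```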

Windows Specific Notes

As per the PyTorch docs, the Windows debug and release builds are not ABI-compatible. Using the wrong version of libtorch can therefore lead to segfaults.

It is recommended to use the MSVC Rust toolchain (e.g. by installing stable-x86_64-pc-windows-msvc via rustup) rather than a MinGW-based one, as PyTorch has compatibility issues with MinGW.

Static Linking

When the LIBTORCH_STATIC=1 environment variable is set, libtorch is statically linked rather than using the dynamic libraries. The pre-compiled artifacts don't seem to include libtorch.a by default, so it would have to be compiled manually, e.g. via the following:

git clone -b v2.0.0 --recurse-submodules https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
# export LIBTORCH to point at the build directory in pytorch-static.
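With static linking, the build script has to tell Cargo to link the `.a` archives instead of the shared libraries. The sketch below generates the standard `cargo:rustc-link-search` and `cargo:rustc-link-lib` build-script directives, switching between `static` and `dylib` based on LIBTORCH_STATIC. The helper names are hypothetical, and a real static libtorch link needs more libraries (in the right order) than this shows.

```rust
/// Builds the Cargo directives for linking `libs` from `lib_dir`.
fn link_directives(lib_dir: &str, is_static: bool, libs: &[&str]) -> Vec<String> {
    let kind = if is_static { "static" } else { "dylib" };
    let mut out = vec![format!("cargo:rustc-link-search=native={lib_dir}")];
    for lib in libs {
        out.push(format!("cargo:rustc-link-lib={kind}={lib}"));
    }
    out
}

/// In a build.rs, each directive is printed to stdout; LIBTORCH_STATIC=1
/// selects static linking, anything else keeps dynamic linking.
fn emit_from_env(lib_dir: &str, libs: &[&str]) {
    let is_static = std::env::var("LIBTORCH_STATIC")
        .map(|v| v == "1")
        .unwrap_or(false);
    for line in link_directives(lib_dir, is_static, libs) {
        println!("{line}");
    }
}
```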

Examples

Basic Tensor Operations

This crate provides a tensor type which wraps PyTorch tensors. Here is a minimal example of how to perform some tensor operations.

use tch::Tensor;

fn main() {
    let t = Tensor::from_slice(&[3, 1, 4, 1, 5]);
    let t = t * 2;
    t.print();
}

Training a Model via Gradient Descent

PyTorch provides automatic differentiation for most tensor operations it supports. This is commonly used to train models using gradient descent. The optimization is performed over variables which are created via a nn::VarStore by defining their shapes and initializations.

In the example below, my_module uses two variables x1 and x2 whose initial values are 0. The forward pass applied to tensor xs returns xs * x1 + exp(xs) * x2.

Once the model has been generated, a nn::Sgd optimizer is created. Then on each step of the training loop:

  • The forward pass is applied to a mini-batch of data.
  • A loss is computed as the squared error between the model output and the mini-batch ground truth, summed over all elements.
  • Finally an optimization step is performed: gradients are computed and variables from the VarStore are modified accordingly.
use tch::nn::{Module, OptimizerConfig};
use tch::{kind, nn, Device, Tensor};

fn my_module(p: nn::Path, dim: i64) -> impl nn::Module {
    let x1 = p.zeros("x1", &[dim]);
    let x2 = p.zeros("x2", &[dim]);
    nn::func(move |xs| xs * &x1 + xs.exp() * &x2)
}

fn gradient_descent() {
    let vs = nn::VarStore::new(Device::Cpu);
    let my_module = my_module(vs.root(), 7);
    let mut opt = nn::Sgd::default().build(&vs, 1e-2).unwrap();
    for _idx in 1..50 {
        // Dummy mini-batches made of zeros.
        let xs = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let ys = Tensor::zeros(&[7], kind::FLOAT_CPU);
        let loss = (my_module.forward(&xs) - ys).pow_tensor_scalar(2).sum(kind::Kind::Float);
        opt.backward_step(&loss);
    }
}
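With the all-zero dummy mini-batches above, the loss and its gradients are zero from the start, so the variables never actually move. The plain-Rust sketch below (no libtorch involved) runs the same element-wise model and summed squared-error loss on synthetic non-zero data, with the gradients derived by hand, to show what each `opt.backward_step(&loss)` amounts to for plain SGD: compute the gradient of the loss, then move each variable by `-lr * gradient`. The data, learning rate and step count are made up for illustration.

```rust
/// Element-wise model from the example: xs * x1 + exp(xs) * x2.
fn forward(xs: &[f64], x1: &[f64], x2: &[f64]) -> Vec<f64> {
    xs.iter()
        .zip(x1.iter().zip(x2))
        .map(|(x, (a, b))| x * a + x.exp() * b)
        .collect()
}

/// Summed squared error, like `.pow_tensor_scalar(2).sum(...)` above.
fn loss(pred: &[f64], ys: &[f64]) -> f64 {
    pred.iter().zip(ys).map(|(p, y)| (p - y).powi(2)).sum()
}

/// Plain gradient descent with hand-derived gradients.
/// Returns (initial loss, final loss).
fn gradient_descent(xs: &[f64], ys: &[f64], lr: f64, steps: usize) -> (f64, f64) {
    let dim = xs.len();
    // Both variables start at zero, like p.zeros(...) in the example.
    let mut x1 = vec![0.0; dim];
    let mut x2 = vec![0.0; dim];
    let initial = loss(&forward(xs, &x1, &x2), ys);
    for _ in 0..steps {
        let pred = forward(xs, &x1, &x2);
        for i in 0..dim {
            // With r = pred[i] - ys[i]:
            // d(r^2)/dx1[i] = 2 r xs[i]  and  d(r^2)/dx2[i] = 2 r exp(xs[i]).
            let r = pred[i] - ys[i];
            x1[i] -= lr * 2.0 * r * xs[i];
            x2[i] -= lr * 2.0 * r * xs[i].exp();
        }
    }
    (initial, loss(&forward(xs, &x1, &x2), ys))
}
```

For example, with seven inputs `i / 7.0` and targets `2 * x + 0.5 * exp(x)`, 500 steps at `lr = 1e-2` drive the loss close to zero.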

Writing a Simple Neural Network

The nn api can be used to create neural network architectures, e.g. the following code defines a simple model with one hidden layer and trains it on the MNIST dataset using the Adam optimizer.

use anyhow::Result;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(
            vs / "layer1",
            IMAGE_DIM,
            HIDDEN_NODES,
            Default::default(),
        ))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> Result<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
    for epoch in 1..200 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);
        let test_accuracy = net
            .forward(&m.test_images)
            .accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(())
}

More details on the training loop can be found in the detailed tutorial.
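`cross_entropy_for_logits` and `accuracy_for_logits` both operate on raw logits. As a dependency-free illustration of what they compute per batch (the mean negative log-softmax of the correct class, and the fraction of arg-max predictions matching the label), here is a plain-Rust sketch on `Vec<f64>` rows. It mirrors the math under that assumption, not tch's implementation.

```rust
/// Numerically stable log-softmax of one row of logits.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + logits.iter().map(|l| (l - max).exp()).sum::<f64>().ln();
    logits.iter().map(|l| l - log_sum_exp).collect()
}

/// Mean cross-entropy over a batch: -log p(label), averaged over rows.
fn cross_entropy_for_logits(batch: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = batch
        .iter()
        .zip(labels)
        .map(|(row, &label)| -log_softmax(row)[label])
        .sum();
    total / batch.len() as f64
}

/// Index of the largest logit (the predicted class).
fn argmax(row: &[f64]) -> usize {
    row.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .unwrap_or(0)
}

/// Fraction of rows whose arg-max matches the label.
fn accuracy_for_logits(batch: &[Vec<f64>], labels: &[usize]) -> f64 {
    let correct = batch
        .iter()
        .zip(labels)
        .filter(|pair| argmax(pair.0) == *pair.1)
        .count();
    correct as f64 / batch.len() as f64
}
```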

Using some Pre-Trained Model

The pretrained-models example illustrates how to run a pre-trained computer vision model on an image. The weights, which have been extracted from the PyTorch implementation, can be downloaded here: resnet18.ot and resnet34.ot.

The example can then be run via the following command:

cargo run --example pretrained-models -- resnet18.ot tiger.jpg

This should print the top 5 ImageNet categories for the image. The code for this example is pretty simple:

    // Bring `forward_t` and the ImageNet helpers into scope.
    use tch::nn::ModuleT;
    use tch::vision::imagenet;

    // First the image is loaded and resized to 224x224.
    let image = imagenet::load_image_and_resize224(image_file)?;

    // A variable store is created to hold the model parameters.
    let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);

    // Then the model is built on this variable store, and the weights are loaded.
    let resnet18 = tch::vision::resnet::resnet18(&vs.root(), imagenet::CLASS_COUNT);
    vs.load(weight_file)?;

    // Apply the forward pass of the model to get the logits and convert them
    // to probabilities via a softmax.
    let output = resnet18
        .forward_t(&image.unsqueeze(0), /*train=*/ false)
        .softmax(-1, tch::Kind::Float);

    // Finally print the top 5 categories and their associated probabilities.
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }
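The last two steps, softmax then top-k selection, can be sketched without libtorch as follows. `softmax` and `top_k` are illustrative stand-ins for the tensor `softmax` call and `imagenet::top`, operating on plain slices; the class names used to exercise them are made up.

```rust
/// Converts logits into probabilities that sum to 1 (max-shifted for stability).
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Returns the `k` most probable (probability, class name) pairs, highest first.
fn top_k<'a>(probs: &[f64], classes: &[&'a str], k: usize) -> Vec<(f64, &'a str)> {
    let mut pairs: Vec<(f64, &'a str)> = probs
        .iter()
        .copied()
        .zip(classes.iter().copied())
        .collect();
    // Sort by descending probability; total_cmp gives a total order on f64.
    pairs.sort_by(|a, b| b.0.total_cmp(&a.0));
    pairs.truncate(k);
    pairs
}
```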

Importing Pre-Trained Weights from PyTorch Using SafeTensors

safetensors is a simple format by Hugging Face for storing tensors. It does not rely on Python's pickle module, so the tensors are not bound to the specific classes and exact directory structure used when the model was saved. It is also zero-copy, which means that reading a file requires no more memory than the file itself.

For more information on safetensors, please check out https://github.com/huggingface/safetensors
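The on-disk layout of a safetensors file is simple: an 8-byte little-endian unsigned integer giving the size N of a JSON header, then N bytes of UTF-8 JSON describing each tensor (dtype, shape, data offsets), then the raw tensor bytes. The sketch below splits a buffer along those lines without parsing the JSON; `split_safetensors` and `build_safetensors` are illustrative helpers, not part of tch or the safetensors crate.

```rust
use std::convert::{TryFrom, TryInto};

/// Splits a safetensors buffer into (JSON header, raw tensor data).
fn split_safetensors(bytes: &[u8]) -> Result<(&str, &[u8]), String> {
    if bytes.len() < 8 {
        return Err("file too short for the 8-byte header length".into());
    }
    // First 8 bytes: header size as a little-endian u64.
    let n = u64::from_le_bytes(bytes[..8].try_into().unwrap());
    let n = usize::try_from(n).map_err(|_| "header length overflows usize")?;
    let end = 8usize
        .checked_add(n)
        .filter(|&e| e <= bytes.len())
        .ok_or("header length exceeds file size")?;
    let header = std::str::from_utf8(&bytes[8..end]).map_err(|e| e.to_string())?;
    Ok((header, &bytes[end..]))
}

/// Builds a minimal safetensors buffer (used here only to exercise the parser).
fn build_safetensors(header: &str, data: &[u8]) -> Vec<u8> {
    let mut out = (header.len() as u64).to_le_bytes().to_vec();
    out.extend_from_slice(header.as_bytes());
    out.extend_from_slice(data);
    out
}
```

Because the header is plain JSON and the data is a flat byte range, a reader can memory-map the file and hand out tensor views directly, which is where the zero-copy property comes from.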

Installing safetensors

You can install safetensors with pip:

pip install safetensors

Exporting weights in PyTorch

import torchvision
from safetensors import torch as stt

model = torchvision.models.resnet18(pretrained=True)
stt.save_file(model.state_dict(), 'resnet18.safetensors')

Note: the exported file must have a .safetensors extension for tch to decode it properly.
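Why the suffix matters can be sketched as a dispatch on the file extension. The `WeightFormat` enum, the `weight_format` function and the fallback to libtorch's own format for any other extension are assumptions made for this illustration, not tch's actual loading code.

```rust
use std::path::Path;

/// Weight-file formats distinguished in this sketch.
#[derive(Debug, PartialEq)]
enum WeightFormat {
    SafeTensors,
    /// Any other extension is assumed to be libtorch's own format (e.g. `.ot`).
    Libtorch,
}

/// Picks a decoder purely from the final extension of the path.
fn weight_format(path: &Path) -> WeightFormat {
    match path.extension().and_then(|e| e.to_str()) {
        Some("safetensors") => WeightFormat::SafeTensors,
        _ => WeightFormat::Libtorch,
    }
}
```

Under this kind of dispatch, a file saved as `resnet18.st` or `resnet18.safetensors.bak` would be handed to the wrong decoder, which is the failure the note above warns about.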

Importing weights in tch

use anyhow::Result;
use tch::{
    nn::VarStore,
    vision::{imagenet, resnet::resnet18},
    Device, Kind,
};

fn main() -> Result<()> {
    // Create the model and load the pre-trained weights.
    let mut vs = VarStore::new(Device::cuda_if_available());
    let model = resnet18(&vs.root(), 1000);
    vs.load("resnet18.safetensors")?;

    // Load the image file and resize it to the usual imagenet dimension of 224x224.
    let image = imagenet::load_image_and_resize224("dog.jpg")?
        .to_device(vs.device());

    // Apply the forward pass of the model to get the logits, then a softmax
    // to turn them into probabilities.
    let output = image
        .unsqueeze(0)
        .apply_t(&model, false)
        .softmax(-1, Kind::Float);

    // Print the top 5 categories for this image.
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }

    Ok(())
}

Further examples include:

External material:

  • A tutorial showing how to use Torch to compute option prices and greeks.
  • tchrs-opencv-webcam-inference uses tch-rs and opencv to run inference on a webcam feed for some Python trained model based on mobilenet v3.

FAQ

What are the best practices for Python to Rust model translations?

See some details in this thread.

How do I get this to work on an M1/M2 Mac?

Check this issue.

Compilation is slow, torch-sys seems to be rebuilt every time cargo gets run.

See this issue; this can be caused by rust-analyzer not knowing about the relevant environment variables, such as LIBTORCH and LD_LIBRARY_PATH.

Using Rust/tch code from Python.

It is possible to call Rust/tch code from Python via PyO3, tch-ext provides an example of such a Python extension.

Error loading shared libraries.

If you get an error about missing shared libraries when running the generated binaries (e.g. error while loading shared libraries: libtorch_cpu.so: cannot open shared object file: No such file or directory), you can try adding the following to your .bashrc, where /path/to/libtorch is the path to your libtorch install.

# For Linux
export LD_LIBRARY_PATH=/path/to/libtorch/lib:$LD_LIBRARY_PATH
# For macOS
export DYLD_LIBRARY_PATH=/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
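If you would rather not edit .bashrc, a launcher can set the variable for a single child process instead. The sketch below prepends the libtorch lib directory to LD_LIBRARY_PATH (or DYLD_LIBRARY_PATH on macOS) for one `std::process::Command`; the helper names are hypothetical, and the `:` separator assumes a Unix-like system.

```rust
use std::process::Command;

/// Prepends `lib_dir` to an existing search-path value, avoiding the stray
/// empty entry that `dir:$VAR` leaves behind when VAR is unset or empty.
fn prepend_lib_dir(lib_dir: &str, existing: Option<&str>) -> String {
    match existing {
        Some(rest) if !rest.is_empty() => format!("{lib_dir}:{rest}"),
        _ => lib_dir.to_string(),
    }
}

/// Builds (but does not run) a command for `binary` with the libtorch lib
/// directory added to `var` (LD_LIBRARY_PATH or DYLD_LIBRARY_PATH).
fn command_with_libtorch(binary: &str, var: &str, lib_dir: &str) -> Command {
    let existing = std::env::var(var).ok();
    let mut cmd = Command::new(binary);
    cmd.env(var, prepend_lib_dir(lib_dir, existing.as_deref()));
    cmd
}
```

Calling `.status()` on the returned command would then start the binary with the adjusted search path, leaving the parent shell's environment untouched.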

License

tch-rs is distributed under the terms of both the MIT license and the Apache license (version 2.0), at your option.

See LICENSE-APACHE, LICENSE-MIT for more details.
