  • Stars: 1,331
  • Rank: 34,018 (Top 0.7%)
  • Language: Python
  • License: Apache License 2.0
  • Created about 3 years ago
  • Updated 11 months ago

Repository Details

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.

torchtyping

Type annotations for a tensor's shape, dtype, names, ...

Welcome! For new projects I now strongly recommend using my newer jaxtyping project instead. It supports PyTorch, doesn't actually depend on JAX, and unlike TorchTyping it is compatible with static type checkers. :)


Turn this:

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)

into this:

def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

with programmatic checking that the shape (dtype, ...) specification is met.

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like # x has shape (batch, hidden_state) or statements like assert x.shape == y.shape , just to keep track of what shape everything is, then this is for you.


Installation

pip install torchtyping

Requires Python >=3.7 and PyTorch >=1.7.0.

If using typeguard then it must be a version <3.0.0.

Usage

torchtyping allows for type annotating:

  • shape: size, number of dimensions;
  • dtype (float, integer, etc.);
  • layout (dense, sparse);
  • names of dimensions as per named tensors;
  • arbitrary number of batch dimensions with ...;
  • ...plus anything else you like, as torchtyping is highly extensible.

If typeguard is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.

# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.

typeguard also has an import hook that can be used to automatically test an entire module, without needing to manually add @typeguard.typechecked decorators.

If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes. If you're not already using typeguard for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both typeguard and torchtyping also integrate with pytest, so if you're concerned about any performance penalty then they can be enabled during tests only.
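
The 'batch' error in the example above comes from binding each dimension name to a single size per call. As a rough illustration of that idea only, the toy decorator below does the same for annotations written as plain tuples of names. check_dims is a hypothetical helper, not part of torchtyping or typeguard, and any object with a .shape attribute stands in for a tensor.

```python
import functools
import inspect
from types import SimpleNamespace


def check_dims(func):
    """Toy check: dimensions sharing a name must have one size within a call."""
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        sizes = {}  # dimension name -> size seen so far in this call

        def check(names, value):
            if len(names) != len(value.shape):
                raise TypeError(
                    f"Expected {len(names)} dimensions, got {len(value.shape)}."
                )
            for name, size in zip(names, value.shape):
                # setdefault records the first size seen for this name
                if sizes.setdefault(name, size) != size:
                    raise TypeError(
                        f"Dimension '{name}' of inconsistent size. "
                        f"Got both {size} and {sizes[name]}."
                    )

        # Check every annotated argument, then the annotated return value.
        for arg_name, value in bound.arguments.items():
            names = func.__annotations__.get(arg_name)
            if names is not None:
                check(names, value)
        result = func(*args, **kwargs)
        if "return" in func.__annotations__:
            check(func.__annotations__["return"], result)
        return result

    return wrapper


@check_dims
def first(x: ("batch",), y: ("batch",)) -> ("batch",):
    return x


# Any object with a .shape attribute works as a stand-in tensor here.
first(SimpleNamespace(shape=(3,)), SimpleNamespace(shape=(3,)))  # passes
```

Calling first with shapes (3,) and (1,) raises a TypeError naming 'batch'. The real library does considerably more (dtype, layout, names, ...), so treat this purely as a mental model.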

API

torchtyping.TensorType[shape, dtype, layout, details]

The core of the library.

Each of shape, dtype, layout, and details is optional.

  • The shape argument can be any of:
    • An int: the dimension must be of exactly this size. If it is -1 then any size is allowed.
    • A str: the size of the dimension passed at runtime will be bound to this name, and all tensors checked that the sizes are consistent.
    • A ...: An arbitrary number of dimensions of any sizes.
    • A str: int pair (technically it's a slice), combining both str and int behaviour. (Just a str on its own is equivalent to str: -1.)
    • A str: str pair, in which case the size of the dimension passed at runtime will be bound to both names, and all dimensions with either name must have the same size. (Some people like to use this as a way to associate multiple names with a dimension, for extra documentation purposes.)
    • A str: ... pair, in which case the multiple dimensions corresponding to ... will be bound to the name specified by str, and again checked for consistency between arguments.
    • None, which when used in conjunction with is_named below, indicates a dimension that must not have a name in the sense of named tensors.
    • A None: int pair, combining both None and int behaviour. (Just a None on its own is equivalent to None: -1.)
    • A None: str pair, combining both None and str behaviour. (That is, it must not have a named dimension, but must be of a size consistent with other uses of the string.)
    • A typing.Any: Any size is allowed for this dimension (equivalent to -1).
    • Any tuple of the above. For example, TensorType["batch": ..., "length": 10, "channels", -1]. If you just want to specify the number of dimensions then use, for example, TensorType[-1, -1, -1] for a three-dimensional tensor.
  • The dtype argument can be any of:
    • torch.float32, torch.float64 etc.
    • int, bool, float, which are converted to their corresponding PyTorch types. float is specifically interpreted as torch.get_default_dtype(), which is usually float32.
  • The layout argument can be either torch.strided or torch.sparse_coo, for dense and sparse tensors respectively.
  • The details argument offers a way to pass an arbitrary number of additional flags that customise and extend torchtyping. Two flags are built-in by default. torchtyping.is_named causes the names of tensor dimensions to be checked, and torchtyping.is_float can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. TensorType[torch.float32].) For discussion on how to customise torchtyping with your own details, see the further documentation.
  • Check multiple things at once by just putting them all together inside a single []. For example TensorType["batch": ..., "length", "channels", float, is_named].
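
To make the shape rules above concrete, here is a rough pure-Python sketch of how such a specification could be matched against a shape. It is not torchtyping's implementation: check_shape and bindings are hypothetical names, a (name, size) tuple stands in for the "name": size slice syntax, and only the int, str, str: int and ... forms are covered.

```python
def check_shape(spec, shape, bindings):
    """Return True if `shape` matches `spec`.

    Sizes bound to names are recorded in `bindings`, so that later calls
    can be checked for consistency against earlier ones.
    """
    spec = list(spec)
    shape = list(shape)

    if Ellipsis in spec:
        # "..." absorbs any number of dimensions: drop them before comparing.
        i = spec.index(Ellipsis)
        n_after = len(spec) - i - 1
        if len(shape) < len(spec) - 1:
            return False
        shape = shape[:i] + shape[len(shape) - n_after:]
        spec = spec[:i] + spec[i + 1:]
    elif len(spec) != len(shape):
        return False

    for entry, size in zip(spec, shape):
        if isinstance(entry, tuple):    # ("length", 10) stands in for "length": 10
            name, expected = entry
        elif isinstance(entry, str):    # "batch" is equivalent to "batch": -1
            name, expected = entry, -1
        else:                           # a plain int; -1 allows any size
            name, expected = None, entry
        if expected != -1 and expected != size:
            return False
        # A name may only ever be bound to one size.
        if name is not None and bindings.setdefault(name, size) != size:
            return False
    return True


bindings = {}
check_shape(["batch", ("length", 10), -1], (4, 10, 7), bindings)  # True
check_shape(["batch"], (3,), bindings)  # False: "batch" is already bound to 4
```

Sharing one bindings dictionary across the arguments of a call is what turns per-tensor checks into cross-tensor consistency checks. Note that this sketch may leave partial bindings behind when a check fails.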

torchtyping.patch_typeguard()

torchtyping integrates with typeguard to perform runtime type checking. torchtyping.patch_typeguard() should be called at the global level, and will patch typeguard to check TensorTypes.

This function is safe to run multiple times. (It does nothing after the first run.)

  • If using @typeguard.typechecked, then torchtyping.patch_typeguard() should be called any time before using @typeguard.typechecked. For example you could call it at the start of each file using torchtyping.
  • If using typeguard.importhook.install_import_hook, then torchtyping.patch_typeguard() should be called any time before defining the functions you want checked. For example you could call torchtyping.patch_typeguard() just once, at the same time as the typeguard import hook. (The order of the hook and the patch doesn't matter.)
  • If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes.

pytest --torchtyping-patch-typeguard

torchtyping offers a pytest plugin to automatically run torchtyping.patch_typeguard() before your tests. pytest will automatically discover the plugin; you just need to pass the --torchtyping-patch-typeguard flag to enable it. Packages can then be passed to typeguard as normal, either by using @typeguard.typechecked, typeguard's import hook, or the pytest flag --typeguard-packages="your_package_here".

Further documentation

See the further documentation for:

  • FAQ;
    • Including flake8 and mypy compatibility;
  • How to write custom extensions to torchtyping;
  • Resources and links to other libraries and materials on this topic;
  • More examples.
