State Space Models library in JAX

Welcome to DYNAMAX!

Dynamax is a library for probabilistic state space models (SSMs) written in JAX. It has code for inference (state estimation) and learning (parameter estimation) in a variety of SSMs, including:

  • Hidden Markov Models (HMMs)
  • Linear Gaussian State Space Models (aka Linear Dynamical Systems)
  • Nonlinear Gaussian State Space Models
  • Generalized Gaussian State Space Models (with non-Gaussian emission models)

The library consists of a set of core, functionally pure, low-level inference algorithms, as well as a set of model classes which provide a more user-friendly, object-oriented interface. It is compatible with other libraries in the JAX ecosystem, such as optax (used for estimating parameters using stochastic gradient descent), and Blackjax (used for computing the parameter posterior using Hamiltonian Monte Carlo (HMC) or sequential Monte Carlo (SMC)).

Documentation

For tutorials and API documentation, see: https://probml.github.io/dynamax/.

Installation and Testing

To install the latest release of dynamax from PyPI:

pip install dynamax                 # Install dynamax and core dependencies, or
pip install dynamax[notebooks]      # Install with demo notebook dependencies

To install the latest development branch:

pip install git+https://github.com/probml/dynamax.git

Finally, if you're a developer, you can install dynamax along with the test and documentation dependencies with:

git clone git@github.com:probml/dynamax.git
cd dynamax
pip install -e '.[dev]'

To run the tests:

pytest dynamax                         # Run all tests
pytest dynamax/hmm/inference_test.py   # Run a specific test
pytest -k lgssm                        # Run tests with lgssm in the name

What are state space models?

A state space model or SSM is a partially observed Markov model, in which the hidden state, $z_t$, evolves over time according to a Markov process, possibly conditional on external inputs / controls / covariates, $u_t$, and generates an observation, $y_t$. This is illustrated in the graphical model below.

The corresponding joint distribution has the following form (in dynamax, we restrict attention to discrete time systems):

$$p(y_{1:T}, z_{1:T} | u_{1:T}) = p(z_1 | u_1) p(y_1 | z_1, u_1) \prod_{t=2}^T p(z_t | z_{t-1}, u_t) p(y_t | z_t, u_t)$$

Here $p(z_t | z_{t-1}, u_t)$ is called the transition or dynamics model, and $p(y_t | z_{t}, u_t)$ is called the observation or emission model. In both cases, the inputs $u_t$ are optional; furthermore, the observation model may have auto-regressive dependencies, in which case we write $p(y_t | z_{t}, u_t, y_{1:t-1})$.
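
As a minimal sketch of this factorization, the plain-Python snippet below samples from a toy two-state HMM with scalar Gaussian emissions and evaluates the joint log probability of a state/observation path. It does not use the dynamax API: the parameter values and the helper names `sample` and `joint_log_prob` are assumptions chosen for illustration, and the inputs $u_t$ are omitted.

```python
import math
import random

# Hypothetical 2-state HMM with scalar Gaussian emissions (no inputs u_t).
INITIAL = [0.5, 0.5]                    # p(z_1)
TRANSITION = [[0.9, 0.1], [0.2, 0.8]]   # p(z_t | z_{t-1}); each row sums to 1
MEANS = [-1.0, 1.0]                     # emission mean for each state
STD = 0.5                               # shared emission standard deviation


def gaussian_logpdf(y, mean, std):
    """Log density of a scalar Gaussian."""
    return -0.5 * math.log(2 * math.pi * std**2) - 0.5 * ((y - mean) / std) ** 2


def sample(num_timesteps, rng):
    """Ancestral sampling: draw z_1, then alternate emission and transition."""
    states, emissions = [], []
    z = rng.choices([0, 1], weights=INITIAL)[0]
    for _ in range(num_timesteps):
        states.append(z)
        emissions.append(rng.gauss(MEANS[z], STD))
        z = rng.choices([0, 1], weights=TRANSITION[z])[0]
    return states, emissions


def joint_log_prob(states, emissions):
    """log p(y_{1:T}, z_{1:T}): initial term, then one transition and one emission per step."""
    lp = math.log(INITIAL[states[0]])
    lp += gaussian_logpdf(emissions[0], MEANS[states[0]], STD)
    for t in range(1, len(states)):
        lp += math.log(TRANSITION[states[t - 1]][states[t]])
        lp += gaussian_logpdf(emissions[t], MEANS[states[t]], STD)
    return lp


rng = random.Random(0)
states, emissions = sample(5, rng)
print(states, joint_log_prob(states, emissions))
```

The first time step is handled separately because $z_1$ has no predecessor; every later step contributes exactly one transition term and one emission term.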

We assume that we observe $y_{1:T}$ and want to infer the hidden states, either using online filtering (i.e., computing $p(z_t|y_{1:t})$) or offline smoothing (i.e., computing $p(z_t|y_{1:T})$). We may also be interested in predicting future states, $p(z_{t+h}|y_{1:t})$, or future observations, $p(y_{t+h}|y_{1:t})$, where $h$ is the forecast horizon. (Note that by using a hidden state to summarize the past observations, the model can have "infinite" memory, unlike a standard auto-regressive model.) All of these computations can be done efficiently using our library, which can also estimate the parameters of the transition and emission models, as we discuss below.
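
To make filtering, smoothing, and forecasting concrete, here is a minimal plain-Python sketch for a toy two-state HMM with scalar Gaussian emissions. It is not the dynamax implementation (which is written in JAX): the parameter values and the function names `forward_filter`, `backward_smoother`, and `forecast` are assumptions made for illustration, and normalized probabilities are used instead of log-space arithmetic for brevity.

```python
import math

# Hypothetical 2-state HMM parameters, for illustration only.
INITIAL = [0.5, 0.5]                    # p(z_1)
TRANSITION = [[0.9, 0.1], [0.2, 0.8]]   # p(z_{t+1} = k | z_t = j) at [j][k]
MEANS = [-1.0, 1.0]                     # emission mean for each state
STD = 0.5                               # shared emission standard deviation


def likelihoods(y):
    """p(y | z = k) for each state k (scalar Gaussian emission)."""
    norm = 1.0 / (STD * math.sqrt(2 * math.pi))
    return [norm * math.exp(-0.5 * ((y - m) / STD) ** 2) for m in MEANS]


def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]


def predict(belief):
    """One-step prediction: push a belief over z_t through the transition model."""
    K = len(belief)
    return [sum(belief[j] * TRANSITION[j][k] for j in range(K)) for k in range(K)]


def forward_filter(emissions):
    """Return filtered[t][k] = p(z_t = k | y_{1:t})."""
    filtered, prior = [], INITIAL
    for y in emissions:
        belief = normalize([p * l for p, l in zip(prior, likelihoods(y))])
        filtered.append(belief)
        prior = predict(belief)  # becomes the prior for the next time step
    return filtered


def backward_smoother(filtered):
    """Return smoothed[t][k] = p(z_t = k | y_{1:T}) from the filtered marginals."""
    T, K = len(filtered), len(filtered[0])
    smoothed = [None] * T
    smoothed[-1] = filtered[-1]
    for t in range(T - 2, -1, -1):
        pred = predict(filtered[t])  # p(z_{t+1} | y_{1:t})
        smoothed[t] = [
            filtered[t][j]
            * sum(TRANSITION[j][k] * smoothed[t + 1][k] / pred[k] for k in range(K))
            for j in range(K)
        ]
    return smoothed


def forecast(belief, horizon):
    """p(z_{t+h} | y_{1:t}): apply the transition model `horizon` times."""
    for _ in range(horizon):
        belief = predict(belief)
    return belief


emissions = [-1.1, -0.8, 0.9, 1.2, 1.0]
filtered = forward_filter(emissions)
smoothed = backward_smoother(filtered)
print(filtered[-1], forecast(filtered[-1], horizon=3))
```

Filtering runs forward in a single pass, so it can be applied online as observations arrive, whereas smoothing needs a second, backward pass over the whole sequence. As the horizon grows, the forecast over states approaches the stationary distribution of the transition matrix.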

More information can be found in these books:

Example usage

Dynamax includes classes for many kinds of SSM. You can use these models to simulate data, and you can fit the models using standard learning algorithms like expectation-maximization (EM) and stochastic gradient descent (SGD). Below we illustrate the high-level (object-oriented) API for the case of an HMM with Gaussian emissions. (See this notebook for a runnable version of this code.)

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from dynamax.hidden_markov_model import GaussianHMM

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
num_states = 3
emission_dim = 2
num_timesteps = 1000

# Make a Gaussian HMM and sample data from it
hmm = GaussianHMM(num_states, emission_dim)
true_params, _ = hmm.initialize(key1)
true_states, emissions = hmm.sample(true_params, key2, num_timesteps)

# Initialize new parameters with k-means and fit them with EM
params, props = hmm.initialize(key3, method="kmeans", emissions=emissions)
params, lls = hmm.fit_em(params, props, emissions, num_iters=20)

# Plot the marginal log probs across EM iterations
plt.plot(lls)
plt.xlabel("EM iterations")
plt.ylabel("marginal log prob.")

# Use fitted model for posterior inference
post = hmm.smoother(params, emissions)
print(post.smoothed_probs.shape) # (1000, 3)

JAX allows you to easily vectorize these operations with vmap. For example, you can sample and fit to a batch of emissions as shown below.

from functools import partial
from jax import vmap

num_seq = 200
batch_true_states, batch_emissions = \
    vmap(partial(hmm.sample, true_params, num_timesteps=num_timesteps))(
        jr.split(key2, num_seq))
print(batch_true_states.shape, batch_emissions.shape) # (200,1000) and (200,1000,2)

# Initialize new parameters with k-means on the whole batch and fit them with EM
params, props = hmm.initialize(key3, method="kmeans", emissions=batch_emissions)
params, lls = hmm.fit_em(params, props, batch_emissions, num_iters=20)

These examples demonstrate the dynamax models, but we can also call the low-level inference code directly.

Contributing

Please see this page for details on how to contribute.

About

Core team: Peter Chang, Giles Harper-Donnelly, Aleyna Kara, Xinglong Li, Scott Linderman, Kevin Murphy.

Other contributors: Adrien Corenflos, Elizabeth DuPre, Gerardo Duran-Martin, Colin Schlager, Libby Zhang, and other people listed here.

MIT License. 2022
