  • Language: Python
  • License: MIT License


Sparse nonlinear least squares for JAX

jaxfg

jaxfg is a factor graph-based nonlinear least squares library for JAX. Typical applications include sensor fusion, SLAM, bundle adjustment, and optimal control.

The premise: we provide a high-level interface for defining probability densities as factor graphs. MAP inference reduces to nonlinear optimization, which we accelerate by analyzing the structure of the graph. Operations on repeated factor and variable types are vectorized, and the sparsity of graph connections is translated into sparse matrix operations.
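To make the "MAP inference reduces to nonlinear optimization" step concrete: when every factor is a Gaussian density over a residual, maximizing the product of factors is the same as minimizing the sum of squared residuals. The sketch below is not jaxfg's API. It is a dense Gauss-Newton loop in plain JAX on a made-up range-measurement problem, and all names in it (LANDMARKS, residuals, gauss_newton) are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: locate a 2D point from range measurements to
# three known landmarks. Each measurement plays the role of one factor.
LANDMARKS = jnp.array([[0.0, 0.0], [4.0, 0.0], [0.0, 3.0]])
TRUE_POINT = jnp.array([1.0, 2.0])
RANGES = jnp.linalg.norm(LANDMARKS - TRUE_POINT, axis=1)  # noise-free for clarity


def residuals(point):
    """Stacked residual vector: predicted range minus measured range."""
    predicted = jnp.linalg.norm(LANDMARKS - point, axis=1)
    return predicted - RANGES


def gauss_newton(point, num_steps=10):
    """Dense Gauss-Newton: solve the normal equations J^T J dx = -J^T r each step."""
    jacobian_fn = jax.jacfwd(residuals)  # autodiff Jacobian of the residual vector
    for _ in range(num_steps):
        r = residuals(point)
        J = jacobian_fn(point)
        delta = jnp.linalg.solve(J.T @ J, -J.T @ r)
        point = point + delta
    return point


solution = gauss_newton(jnp.array([2.0, 1.0]))
```

This version forms a dense Jacobian, which is only reasonable for tiny problems. The library's stated advantage is that it exploits graph sparsity instead.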

Features:

  • Autodiff-powered sparse Jacobians.
  • Automatic vectorization for repeated factor and variable types.
  • Manifold definition interface, with implementations provided for SO(2), SE(2), SO(3), and SE(3) Lie groups.
  • Support for standard JAX function transformations: jit, vmap, pmap, grad, etc.
  • Nonlinear optimizers: Gauss-Newton, Levenberg-Marquardt, Dogleg.
  • Sparse linear solvers: conjugate gradient (Jacobi-preconditioned), sparse Cholesky (via CHOLMOD).
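As a rough illustration of the first two features, the sketch below uses only standard JAX transformations (jax.vmap and jax.jacfwd) to evaluate a batch of same-typed factors and their Jacobian blocks in one call. It does not use jaxfg's interfaces. The between_residual function, the edge list, and the data are assumptions made up for this example.

```python
import jax
import jax.numpy as jnp


def between_residual(x_a, x_b, measured_offset):
    """Residual of one hypothetical 'between' factor on 2D positions."""
    return (x_b - x_a) - measured_offset


# Per-factor Jacobian blocks with respect to both connected variables.
between_jacobians = jax.jacfwd(between_residual, argnums=(0, 1))

# Vectorize over a batch of factors that all share the same type.
batched_residual = jax.vmap(between_residual)
batched_jacobians = jax.vmap(between_jacobians)

positions = jnp.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.5]])
edges = jnp.array([[0, 1], [1, 2]])  # each row: (index of x_a, index of x_b)
offsets = jnp.array([[1.0, 0.0], [1.0, 0.0]])

# Gather the variable values each factor touches, then evaluate all at once.
x_a = positions[edges[:, 0]]
x_b = positions[edges[:, 1]]
residual_batch = batched_residual(x_a, x_b, offsets)  # shape (2, 2)
jac_a, jac_b = batched_jacobians(x_a, x_b, offsets)   # each has shape (2, 2, 2)
```

Together with the edge indices, the small dense blocks in jac_a and jac_b describe a block-sparse Jacobian: each factor contributes nonzero entries only in the columns of the two variables it connects.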

This library is released as part of our IROS 2021 paper (more info in our core experiment repository) and borrows heavily from a wide set of existing libraries, including GTSAM, Ceres Solver, minisam, SwiftFusion, and g2o. For technical background and concepts, GTSAM has a great set of tutorials.

Installation

scikit-sparse requires SuiteSparse:

sudo apt update
sudo apt install -y libsuitesparse-dev

Then, from your environment of choice:

git clone https://github.com/brentyi/jaxfg.git
cd jaxfg
pip install -e .

Example scripts

Toy pose graph optimization:

python scripts/pose_graph_simple.py

Pose graph optimization from .g2o files:

python scripts/pose_graph_g2o.py  # For options, pass in a --help flag

Development

If you're interested in extending this library to define your own factor graphs, we'd recommend first familiarizing yourself with:

  1. Pytrees in JAX: https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html
  2. Python dataclasses: https://docs.python.org/3/library/dataclasses.html
    • We currently take a "make everything a dataclass" philosophy for software engineering in this library. This is convenient for several reasons, but notably makes it easy for objects to be registered as pytree nodes. See jax_dataclasses for details on this.
  3. Type annotations: https://docs.python.org/3/library/typing.html
    • We rely on generics (typing.Generic and typing.TypeVar) particularly heavily. If you're familiar with C++, this should come naturally: generics play a role similar to templates.
  4. Explicit decorators for overrides/inheritance: https://github.com/mkorpela/overrides
    • The @overrides and @final decorators signal which methods override a parent method and which shouldn't be overridden. The same goes for @abc.abstractmethod.
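The conventions in items 2 to 4 can be sketched with the standard library alone. The classes below are hypothetical and are not taken from jaxfg. typing.final stands in for the @final decorator from the overrides package, and registering the dataclass as a pytree node (what jax_dataclasses handles) is left out.

```python
import abc
import dataclasses
import math
from typing import Generic, Tuple, TypeVar, final

VariableValue = TypeVar("VariableValue")


@dataclasses.dataclass(frozen=True)
class VariableBase(abc.ABC, Generic[VariableValue]):
    """Hypothetical variable interface, parameterized by its value type."""

    name: str

    @abc.abstractmethod
    def retract(self, value: VariableValue, delta: Tuple[float, ...]) -> VariableValue:
        """Apply a tangent-space update to a value. Subclasses must implement this."""

    @final
    def describe(self) -> str:
        """Marked final: subclasses are not expected to override this."""
        return f"{type(self).__name__}({self.name})"


@dataclasses.dataclass(frozen=True)
class Angle(VariableBase[float]):
    """Planar rotation angle, wrapped to [-pi, pi)."""

    def retract(self, value: float, delta: Tuple[float, ...]) -> float:
        # Add the update, then wrap back into [-pi, pi).
        return (value + delta[0] + math.pi) % (2 * math.pi) - math.pi


theta = Angle(name="theta")
wrapped = theta.retract(3.0, (0.5,))  # 3.5 rad wraps around to 3.5 - 2*pi
```

A static type checker such as mypy reads VariableBase[float] much like a C++ template instantiation: retract on Angle must accept and return float. Note that typing.final is enforced by the type checker only, not at runtime.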

From there, we have a few references for defining your own factor graphs, factors, and manifolds:

Current limitations

  1. In XLA, JIT compilation needs to happen for each unique set of input shapes. Modifying graph structures can thus introduce significant re-compilation overheads; this can restrict applications that are dynamic or online.
  2. Our marginalization implementation is not very good.
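The first limitation is easy to observe in plain JAX. In the sketch below, a Python side effect inside a jitted function runs only while JAX traces it, so the log shows when a new compilation is triggered. The function and variable names are made up for this example and nothing here is specific to jaxfg.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape each time JAX re-traces the function


@jax.jit
def sum_of_squares(residuals):
    # Python side effects run only during tracing, not on cached calls.
    trace_log.append(residuals.shape)
    return jnp.sum(residuals**2)


sum_of_squares(jnp.ones(10))   # first call: trace and compile
sum_of_squares(jnp.zeros(10))  # same shape and dtype: compiled version is reused
sum_of_squares(jnp.ones(11))   # new shape: trace and compile again
```

After these three calls, trace_log holds two entries, one per distinct input shape. A graph that gains a factor or variable changes array shapes in the same way, which is why online or incremental problems pay a recompilation cost.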

To-do

This library's still in development mode! Here's our TODO list:

  • Preliminary graph, variable, factor interfaces
  • Real vector variable types
  • Refactor into package
  • Nonlinear optimization for MAP inference
    • Conjugate gradient linear solver
    • CHOLMOD linear solver
      • Basic implementation. JIT-able, but no vmap, pmap, or autodiff support.
      • Custom VJP rule? vmap support?
    • Gauss-Newton implementation
    • Termination criteria
    • Damped least squares
    • Dogleg
    • Inexact Newton steps
    • Revisit termination criteria
    • Reduce redundant code
    • Robust losses
  • Marginalization
    • Prototype using sksparse/CHOLMOD (works but fairly slow)
    • JAX implementation?
  • Validate g2o example
  • Performance
    • More intentional JIT compilation
    • Re-implement parallel factor computation
    • Vectorized linearization
    • Basic (Jacobi) CGLS preconditioning
  • Manifold optimization (mostly offloaded to jaxlie)
    • Basic interface
    • Manifold optimization on SO2
    • Manifold optimization on SE2
    • Manifold optimization on SO3
    • Manifold optimization on SE3
  • Usability + code health (low priority)
    • Basic cleanup/refactor
      • Better parallel factor interface
      • Separate out utils, lie group helpers
      • Put things in folders
    • Resolve typing errors
    • Cleanup/refactor (more)
    • Package cleanup: dependencies, etc
    • Add CI:
      • mypy
      • lint
      • build
      • coverage
    • More comprehensive tests
    • Clean up docstrings
    • New name
