  • Language
    Python
  • License
    MIT License
  • Created over 4 years ago
  • Updated 4 months ago



A differentiable cosmology library in JAX

jax-cosmo

Join the chat at https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo

Finally a differentiable cosmology library, and it's in JAX!

Have a look at the GitHub issues to see what is needed or to share any thoughts on the design, and don't hesitate to join the Gitter room for discussions.

TL;DR

This is what jax-cosmo aims to do:

import jax
import jax_cosmo
# ... (setup of cosmo, ell, probes and data elided)
def likelihood(cosmo):
    # Compute mean and covariance of angular Cls, for specific probes
    mu, cov = jax_cosmo.angular_cl.gaussian_cl_covariance_and_mean(cosmo, ell, probes)
    # Return likelihood value
    return jax_cosmo.likelihood.gaussian_log_likelihood(data, mu, cov)

# Compute derivatives of the likelihood with respect to cosmological parameters
g = jax.grad(likelihood)(cosmo)

# Compute Fisher matrix of cosmological parameters
F = - jax.hessian(likelihood)(cosmo)

This is how you can compute gradients and Hessians of any function in jax-cosmo, all without any finite differences.
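The gradient and Fisher-matrix pattern from the TL;DR can be tried out in pure JAX, without touching the jax-cosmo API. This is a toy sketch under stated assumptions: the straight-line model, the parameter names `slope` and `intercept`, the noise level, and the mock data are all hypothetical stand-ins for a cosmology and its predicted Cls. As with a cosmology object, the parameters live in a pytree (a dict here), so `jax.grad` returns a gradient with the same structure. `jax.flatten_util.ravel_pytree` flattens the parameters into a vector so that `jax.hessian` yields an ordinary matrix.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical toy model standing in for a cosmological prediction.
x = jnp.linspace(0.0, 1.0, 10)
sigma = 0.1               # assumed Gaussian noise per data point
data = 2.0 * x + 1.0      # noiseless mock data at slope=2, intercept=1


def log_likelihood(params):
    """Gaussian log-likelihood (up to a constant) of the straight-line model."""
    mu = params["slope"] * x + params["intercept"]
    return -0.5 * jnp.sum(((data - mu) / sigma) ** 2)


params = {"slope": jnp.array(2.0), "intercept": jnp.array(1.0)}

# Exact gradient; it is a dict with the same keys as `params`.
g = jax.grad(log_likelihood)(params)

# Flatten the pytree (dict keys flatten in sorted order: intercept, slope)
# so that the Hessian comes out as a plain 2x2 array.
flat_params, unravel = ravel_pytree(params)
F = -jax.hessian(lambda p: log_likelihood(unravel(p)))(flat_params)

# Fisher forecast: marginalized 1-sigma errors from the inverse Fisher matrix.
cov = jnp.linalg.inv(F)
errors = jnp.sqrt(jnp.diag(cov))
```

Because this toy model is linear in its parameters, F should equal JᵀJ/σ², where J has columns 1 and x (matching the sorted key order intercept, slope), which makes a convenient sanity check on the automatic derivatives.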

Check out a full example here: colab link

Have a look at the design document to learn more about the structure of the code.

What is JAX?

JAX = NumPy + autodiff + GPU

JAX is a framework for automatic differentiation (like TensorFlow or PyTorch) that follows the NumPy API and uses the GPU/TPU-enabled XLA backend.

What does that mean?

  • You write plain Python/NumPy code; there is no need to learn a different language.
  • It runs on GPU without you having to do anything in particular.
  • You can take derivatives of any quantity with respect to any parameter by automatic differentiation.
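The three points above fit in a few lines. The sketch below is an illustrative example, not jax-cosmo code: it writes the comoving distance of a flat ΛCDM model (radiation neglected, with an assumed h = 0.7 and a hand-rolled trapezoid rule) as ordinary NumPy-style code, then uses `jax.grad` to differentiate it with respect to Ω_m and `jax.jit` to compile it with XLA for whichever backend is available.

```python
import jax
import jax.numpy as jnp

C_KM_S = 299792.458  # speed of light [km/s]


def comoving_distance(omega_m, z, h=0.7, n=1024):
    """Comoving distance [Mpc] to redshift z in a flat LambdaCDM model,
    neglecting radiation:  chi(z) = int_0^z c / H(z') dz',
    with H(z) = 100 h sqrt(Omega_m (1 + z)^3 + 1 - Omega_m) km/s/Mpc.
    The integral uses the trapezoid rule on n evenly spaced points."""
    zs = jnp.linspace(0.0, z, n)
    hubble = 100.0 * h * jnp.sqrt(omega_m * (1.0 + zs) ** 3 + 1.0 - omega_m)
    f = C_KM_S / hubble
    dz = zs[1] - zs[0]
    return dz * (jnp.sum(f) - 0.5 * (f[0] + f[-1]))


# Derivative with respect to Omega_m (the first argument), compiled by XLA
# for the available device (CPU, GPU or TPU).
d_chi_d_om = jax.jit(jax.grad(comoving_distance))

chi = comoving_distance(0.3, 1.0)
slope = d_chi_d_om(0.3, 1.0)
```

With Ω_m = 0.3 the distance to z = 1 comes out at roughly 3300 Mpc, and the derivative is negative: more matter means a faster expansion rate H(z) and hence a shorter distance. In practice one would rely on the library's own, more complete background functions rather than this toy.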

Check out the JAX project page to learn more!

Install

jax-cosmo is pure Python, so installing is a breeze:

$ pip install jax-cosmo

Philosophy

Here are some of the design guidelines:

  • Implementation of equations should be human-readable, and documentation should always live next to the implementation.
  • Should always be trivially installable: external dependencies should be kept to a minimum, especially those that require compilation or have restrictive licenses.
  • Keep API and implementation simple and intuitive, minimize user and developer surprise.
  • “Debugging is twice as hard as writing the code in the first place. Therefore, if you write the code as cleverly as possible, you are, by definition, not smart enough to debug it.” -Brian Kernighan, quote stolen from here.

Contributing

jax-cosmo aims to be a community effort; contributions are most welcome and can come in several forms:

  • Bug reports
  • API design suggestions
  • (Pull) requests for more features
  • Examples and notebooks of cool things that can be done with the code

You can chime in on any aspect of the design by proposing a PR to the design document. The issues page is a good place to start, but don't hesitate to come chat in the Gitter room.

Please take a look at the Contributing Document for more information.

This project follows the All Contributors guidelines, which aim to recognize and value contributions at every level.

Contributors ✨

Thanks go to these wonderful people (emoji key: 💻 code, 🐛 bug reports):


  • Francois Lanusse 💻
  • Santiago Casas 🐛 💻
  • Austin Peel 💻
  • Minas Karamanis 💻
  • David Kirkby 💻 🐛
  • Alexandre Boucaud 💻
  • Denise Lanzieri 💻
  • jecampagne 🐛
  • Yin Li 💻 🐛

This project follows the all-contributors specification. Contributions of any kind are welcome!