  • Stars: 630
  • Rank: 71,328 (Top 2%)
  • Language: Python
  • License: Apache License 2.0
  • Created almost 4 years ago
  • Updated 4 months ago

Repository Details

RL Environments in JAX 🌍


Reinforcement Learning Environments in JAX 🌍

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of different environments including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. evosax). We provide training & checkpoints for both PPO & ES in gymnax-blines. Get started here 👉 Colab.

Basic gymnax API Usage 🍲

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
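
Because environment settings live in the explicit env_params container rather than in hidden object state, you can override them functionally and reuse the same pure reset/step functions. A minimal sketch (the field name max_steps_in_episode is assumed for illustration; inspect env_params for the fields your environment actually exposes):

# Derive a modified configuration without mutating the environment object.
# NOTE: `max_steps_in_episode` is an assumed field name for illustration.
short_params = env_params.replace(max_steps_in_episode=50)

# The same functional reset/step API works with the new configuration.
obs, state = env.reset(key_reset, short_params)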

Implemented Accelerated Environments 🏎️

| Environment Name | Reference | Source | 🤖 Ckpt (Return) | Secs/1M steps 🦶 A100 (2k envs 🌎) |
|---|---|---|---|---|
| Acrobot-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -80) | 0.07 |
| Pendulum-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -130) | 0.07 |
| CartPole-v1 | Brockman et al. (2016) | Click | PPO, ES (R: 500) | 0.05 |
| MountainCar-v0 | Brockman et al. (2016) | Click | PPO, ES (R: -118) | 0.07 |
| MountainCarContinuous-v0 | Brockman et al. (2016) | Click | PPO, ES (R: 92) | 0.09 |
| Asterix-MinAtar | Young & Tian (2019) | Click | PPO (R: 15) | 0.92 |
| Breakout-MinAtar | Young & Tian (2019) | Click | PPO (R: 28) | 0.19 |
| Freeway-MinAtar | Young & Tian (2019) | Click | PPO (R: 58) | 0.87 |
| SpaceInvaders-MinAtar | Young & Tian (2019) | Click | PPO (R: 131) | 0.33 |
| Catch-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.15 |
| DeepSea-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0) | 0.22 |
| MemoryChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0.1) | 0.13 |
| UmbrellaChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.08 |
| DiscountingChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1.1) | 0.06 |
| MNISTBandit-bsuite | Osband et al. (2019) | Click | - | - |
| SimpleBandit-bsuite | Osband et al. (2019) | Click | - | - |
| FourRooms-misc | Sutton et al. (1999) | Click | PPO, ES (R: 1) | 0.07 |
| MetaMaze-misc | Micconi et al. (2020) | Click | ES (R: 32) | 0.09 |
| PointRobot-misc | Dorfman et al. (2021) | Click | ES (R: 10) | 0.08 |
| BernoulliBandit-misc | Wang et al. (2017) | Click | ES (R: 90) | 0.08 |
| GaussianBandit-misc | Lange & Sprekeler (2022) | Click | ES (R: 0) | 0.07 |
| Reacher-misc | Lenton et al. (2021) | Click | - | - |
| Swimmer-misc | Lenton et al. (2021) | Click | - | - |
| Pong-misc | Kirsch (2018) | Click | - | - |
* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using jit-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the gymnax-blines documentation.

Installation ⏳

The latest gymnax release can be installed directly from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

To run JAX on your accelerators (GPU/TPU), see the JAX documentation for installation details.
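
Note that JAX itself installs CPU-only by default; the accelerator-specific build has to be installed separately. A minimal sketch for a CUDA 12 machine (the exact extra depends on your CUDA/TPU setup, so double-check the JAX installation guide):

pip install -U "jax[cuda12]"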

Examples 📖

Key Selling Points 💵

  • Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):

    # Jit-accelerated step transition
    jit_step = jax.jit(env.step)
    
    # map (vmap/pmap) across random keys for batch rollouts
    reset_rng = jax.vmap(env.reset, in_axes=(0, None))
    step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    
    # map (vmap/pmap) across env parameters (e.g. for meta-learning)
    reset_params = jax.vmap(env.reset, in_axes=(None, 0))
    step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))
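
    # Usage sketch (assumes the env & PRNG keys from the basic example above):
    # one PRNG key per environment, shared env_params across the batch.
    num_envs = 8
    batch_keys = jax.random.split(key_reset, num_envs)
    obs_b, state_b = reset_rng(batch_keys, env_params)
    actions_b = jax.vmap(env.action_space(env_params).sample)(
        jax.random.split(key_act, num_envs)
    )
    n_obs_b, n_state_b, reward_b, done_b, _ = step_rng(
        jax.random.split(key_step, num_envs), state_b, actions_b, env_params
    )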

    For speed comparisons with standard vectorized NumPy environments check out gymnax-blines.

  • Scan through entire episode rollouts: You can also lax.scan through entire reset, step episode loops for fast compilation:

    def rollout(rng_input, policy_params, env_params, steps_in_episode):
        """Rollout a jitted gymnax episode with lax.scan."""
        # Reset the environment
        rng_reset, rng_episode = jax.random.split(rng_input)
        obs, state = env.reset(rng_reset, env_params)
    
        def policy_step(state_input, tmp):
            """lax.scan compatible step transition in jax env."""
            obs, state, policy_params, rng = state_input
            rng, rng_step, rng_net = jax.random.split(rng, 3)
            action = model.apply(policy_params, obs)
            next_obs, next_state, reward, done, _ = env.step(
                rng_step, state, action, env_params
            )
            carry = [next_obs, next_state, policy_params, rng]
            return carry, [obs, action, reward, next_obs, done]
    
        # Scan over episode step loop
        _, scan_out = jax.lax.scan(
            policy_step,
            [obs, state, policy_params, rng_episode],
            (),
            steps_in_episode
        )
        # Return the stacked per-step trajectories (obs, action, reward, next_obs, done)
        obs, action, reward, next_obs, done = scan_out
        return obs, action, reward, next_obs, done
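
    # Usage sketch (assumes a flax-style `model` whose apply() matches policy_params):
    # steps_in_episode fixes the static scan length, hence static_argnums=3.
    rollout_jit = jax.jit(rollout, static_argnums=3)
    obs, action, reward, next_obs, done = rollout_jit(
        rng, policy_params, env_params, 200
    )
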
  • Built-in visualization tools: You can also smoothly generate GIF animations using the Visualizer tool, which covers all classic_control, MinAtar and most misc environments:

    import jax.numpy as jnp
    from gymnax.visualize import Visualizer
    
    state_seq, reward_seq = [], []
    rng, rng_reset = jax.random.split(rng)
    obs, env_state = env.reset(rng_reset, env_params)
    while True:
        state_seq.append(env_state)
        rng, rng_act, rng_step = jax.random.split(rng, 3)
        action = env.action_space(env_params).sample(rng_act)
        next_obs, next_env_state, reward, done, info = env.step(
            rng_step, env_state, action, env_params
        )
        reward_seq.append(reward)
        if done:
            break
        else:
            obs = next_obs
            env_state = next_env_state
    
    cum_rewards = jnp.cumsum(jnp.array(reward_seq))
    vis = Visualizer(env, env_params, state_seq, cum_rewards)
    vis.animate(f"docs/anim.gif")
  • Training pipelines & pretrained agents: Check out gymnax-blines for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.

  • Simple batch agent evaluation: Work-in-progress.

    from gymnax.experimental import RolloutWrapper
    
    # Define rollout manager for pendulum env
    manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")
    
    # Simple single episode rollout for policy
    obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)
    
    # Multiple rollouts for same network (different rng, e.g. eval)
    rng_batch = jax.random.split(rng, 10)
    obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
        rng_batch, policy_params
    )
    
    # Multiple rollouts for different networks + rng (e.g. for ES)
    batch_params = jax.tree_map(  # Stack the same parameters (or provide a population of different ones)
        lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape), policy_params
    )
    obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
        rng_batch, batch_params
    )

Resources & Other Great Tools 📝

  • 💻 Brax: JAX-based library for rigid body physics by Google Brain with JAX-style MuJoCo substitutes.
  • 💻 envpool: Vectorized parallel environment execution engine.
  • 💻 Jumanji: A suite of diverse and challenging RL environments in JAX.
  • 💻 Pgx: JAX-based classic board game environments.

Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}

We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.

More Repositories

1. evosax - Evolution Strategies in JAX 🦎 (Python, 475 stars)
2. code-and-blog - Some small scale experiments for my blog posts 📝 (Jupyter Notebook, 78 stars)
3. gymnax-blines - Baselines for gymnax 🤖 (Jupyter Notebook, 57 stars)
4. reading-notes-ml - Progress, Notes, Summaries and a lot of Questions on Machine Learning (55 stars)
5. spinningup-workspace - Reading notes & PyTorch experiments on OpenAI's "Spinning Up in DRL" tutorial (Python, 36 stars)
6. deep-rl-tutorial - A Tutorial on Deep Reinforcement Learning in PyTorch (Jupyter Notebook, 29 stars)
7. flexible-learning-group - A curated list of papers presented in the 📖 "Flexible Learning Reading Group" @ TU Berlin. Join us! 🤗 (27 stars)
8. es-lottery - Lottery Tickets in Evolutionary Optimization (Lange & Sprekeler, ICML 2023) (Jupyter Notebook, 13 stars)
9. minimal-meta-rl - Minimal A2C/A3C example of an LSTM-based meta-learner (Python, 13 stars)
10. gym-hanoi - A Towers of Hanoi environment in OpenAI Gym style (Python, 12 stars)
11. SequentialBayesianLearning - Sequential Bayesian Learning agents learning the data-generating process of a binary sequence (Jupyter Notebook, 11 stars)
12. automata-perturbation-lstm - Code accompanying Gómez-Nava et al. (2023, Nature Physics) (Python, 10 stars)
13. gym-swarm - A Swarm environment in OpenAI gym style (Jupyter Notebook, 8 stars)
14. StochVol_HMM - Stochastic Volatility Modelling using a Hidden Markov Model (Jupyter Notebook, 6 stars)
15. algonauts-2021 - Algonauts 2021 Challenge Mini-Track submission based on SimCLR-v2 features & Bayesian Optimization (Jupyter Notebook, 6 stars)
16. action-grammars-hrl - Action Grammars for Hierarchical Reinforcement Learning (Jupyter Notebook, 5 stars)
17. Reading_Notes_Neuro - Progress, Notes, Summaries and a lot of Questions on Neuroscience (5 stars)
18. evojax-benchmarks - Benchmarking Utilities for EvoJAX (Jupyter Notebook, 4 stars)
19. Bio-Plausible-DeepLearning - Who needs backprop? - Guerguiev et al. (2017) - Reproduction & Robustness Checks (Jupyter Notebook, 3 stars)
20. RobertTLange (2 stars)
21. ModelsNeuralSystems - Computer Practical Coursework for BCCN Berlin course "Models of Neural Systems" (2018/2019) (Jupyter Notebook, 2 stars)
22. Migration_and_Technology_Diffusion - Bachelor Thesis tackling the Impact of Skilled-Worker Immigration on the Diffusion of New Technologies (Stata, 1 star)
23. RandGLM - RandNLA methods for estimating GLMs with Big Datasets (R, 1 star)