Reinforcement Learning Environments in JAX
Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of different environments including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. evosax). We provide training & checkpoints for both PPO & ES in gymnax-blines. Get started here.
Basic gymnax API Usage
import jax
import gymnax
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)
# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")
# Reset the environment.
obs, state = env.reset(key_reset, env_params)
# Sample a random action.
action = env.action_space(env_params).sample(key_act)
# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
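The explicit `(key, state, action, params)` signature is what makes each transition a pure function that JAX can compile and map. The following is a minimal sketch of that pattern, assuming a hypothetical toy environment (`toy_reset`, `toy_step` and the `noise_scale`/`max_steps` parameters are illustrative names, not part of gymnax):

```python
import jax
import jax.numpy as jnp


# Hypothetical toy environment (NOT a gymnax API): a noisy 1-D point mass.
def toy_reset(key, params):
    """Return (obs, state), mirroring the reset call shape used above."""
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    state = {"pos": pos, "t": jnp.int32(0)}
    return pos, state


def toy_step(key, state, action, params):
    """Return (obs, state, reward, done, info), mirroring the step call shape."""
    noise = params["noise_scale"] * jax.random.normal(key, ())
    pos = state["pos"] + action + noise
    t = state["t"] + 1
    reward = -jnp.abs(pos)
    done = t >= params["max_steps"]
    return pos, {"pos": pos, "t": t}, reward, done, {}


params = {"noise_scale": jnp.float32(0.1), "max_steps": jnp.int32(10)}
key_reset, key_step = jax.random.split(jax.random.PRNGKey(0))
obs, state = toy_reset(key_reset, params)
action = jnp.float32(0.5)

# Same key, state, action and params always give the same transition.
jit_step = jax.jit(toy_step)
out_a = jit_step(key_step, state, action, params)
out_b = jit_step(key_step, state, action, params)

# Hyperparameters are ordinary inputs, so they can be batched with vmap.
batch_params = {
    "noise_scale": jnp.array([0.0, 0.1, 1.0]),
    "max_steps": jnp.full((3,), 10),
}
step_over_params = jax.vmap(toy_step, in_axes=(None, None, None, 0))
batch_obs, _, _, _, _ = step_over_params(key_step, state, action, batch_params)
```

Because nothing is stored on an object, changing a seed or a hyperparameter only means passing a different argument; no environment needs to be re-instantiated.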
Implemented Accelerated Environments
* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using jit-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons on different accelerators (CPU, RTX 2080Ti) and MLP policies, please refer to the gymnax-blines documentation.
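To make the measurement setup concrete, here is a rough sketch of how such a timing could be taken: 2000 workers times 500 steps gives 1M transitions, the batched rollout is compiled once in a warm-up call, and only the second call is timed. The toy dynamics in `toy_step` are a hypothetical stand-in, so the resulting number says nothing about gymnax environments themselves.

```python
import time

import jax
import jax.numpy as jnp

NUM_WORKERS = 2000       # environment workers, as in the text above
STEPS_PER_WORKER = 500   # 2000 * 500 = 1M step transitions


def toy_step(pos, key):
    """Hypothetical stand-in dynamics driven by a uniformly random action."""
    action = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    new_pos = 0.9 * pos + action
    return new_pos, -jnp.abs(new_pos)


def rollout(key):
    """Roll out one worker with lax.scan and return its summed reward."""
    keys = jax.random.split(key, STEPS_PER_WORKER)
    _, rewards = jax.lax.scan(toy_step, jnp.float32(0.0), keys)
    return rewards.sum()


batched_rollout = jax.jit(jax.vmap(rollout))
worker_keys = jax.random.split(jax.random.PRNGKey(0), NUM_WORKERS)

# Warm-up call: triggers compilation so it is excluded from the timing.
batched_rollout(worker_keys).block_until_ready()

start = time.perf_counter()
returns = batched_rollout(worker_keys).block_until_ready()
elapsed = time.perf_counter() - start

total_steps = NUM_WORKERS * STEPS_PER_WORKER
print(f"{total_steps} steps in {elapsed:.4f} s")
```

Calling `block_until_ready()` matters because JAX dispatches work asynchronously; without it the timer would stop before the computation has finished.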
Installation
The latest gymnax release can be installed directly from PyPI:
pip install gymnax
If you want to get the most recent commit, please install directly from the repository:
pip install git+https://github.com/RobertTLange/gymnax.git@main
To run JAX on your accelerators, follow the installation instructions in the JAX documentation.
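After installing, a quick way to check which backend JAX actually picked up is to list the visible devices. This is a small sketch using only core JAX calls; the output depends on your machine, so none is shown here.

```python
import jax

# Lists the devices JAX can see; on a CPU-only install this is a single CPU device.
devices = jax.devices()
backend = jax.default_backend()
print(backend, devices)
```

If only a CPU device shows up on a machine with a GPU or TPU, the accelerator-enabled JAX build is most likely not installed.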
Examples
- Environment API - Get started with the basic gymnax API.
- Distributed Anakin Agent - Train an Anakin (Hessel et al., 2021) agent on SpaceInvaders-MinAtar.
- ES with gymnax - Meta-evolve an LSTM controller that controls 2-link pendula of different lengths.
- Bandit A2C Meta-RL - Meta-learn an A2C LSTM that learns to explore/exploit in multi-arm bandit tasks.
- Trained baselines - Check out the trained baseline agents (PPO/ES) in gymnax-blines.
Key Selling Points
- Environment vectorization & acceleration: Easy composition of JAX primitives (e.g. jit, vmap, pmap):

  # Jit-accelerated step transition
  jit_step = jax.jit(env.step)
  # map (vmap/pmap) across random keys for batch rollouts
  reset_rng = jax.vmap(env.reset, in_axes=(0, None))
  step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None))
  # map (vmap/pmap) across env parameters (e.g. for meta-learning)
  reset_params = jax.vmap(env.reset, in_axes=(None, 0))
  step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))

  For speed comparisons with standard vectorized NumPy environments, check out gymnax-blines.
- Scan through entire episode rollouts: You can also lax.scan through entire reset, step episode loops for fast compilation:

  def rollout(rng_input, policy_params, env_params, steps_in_episode):
      """Rollout a jitted gymnax episode with lax.scan."""
      # Reset the environment
      rng_reset, rng_episode = jax.random.split(rng_input)
      obs, state = env.reset(rng_reset, env_params)

      def policy_step(state_input, tmp):
          """lax.scan compatible step transition in jax env."""
          obs, state, policy_params, rng = state_input
          rng, rng_step, rng_net = jax.random.split(rng, 3)
          # `model` is assumed to be a policy network defined elsewhere
          action = model.apply(policy_params, obs)
          next_obs, next_state, reward, done, _ = env.step(
              rng_step, state, action, env_params
          )
          carry = [next_obs, next_state, policy_params, rng]
          return carry, [obs, action, reward, next_obs, done]

      # Scan over episode step loop
      _, scan_out = jax.lax.scan(
          policy_step, [obs, state, policy_params, rng_episode], (), steps_in_episode
      )
      # Return the per-step transitions collected over the episode
      obs, action, reward, next_obs, done = scan_out
      return obs, action, reward, next_obs, done
- Built-in visualization tools: You can also smoothly generate GIF animations using the Visualizer tool, which covers all classic_control, MinAtar and most misc environments:

  import jax.numpy as jnp
  from gymnax.visualize import Visualizer

  state_seq, reward_seq = [], []
  rng, rng_reset = jax.random.split(rng)
  obs, env_state = env.reset(rng_reset, env_params)
  while True:
      state_seq.append(env_state)
      rng, rng_act, rng_step = jax.random.split(rng, 3)
      action = env.action_space(env_params).sample(rng_act)
      next_obs, next_env_state, reward, done, info = env.step(
          rng_step, env_state, action, env_params
      )
      reward_seq.append(reward)
      if done:
          break
      else:
          obs = next_obs
          env_state = next_env_state

  cum_rewards = jnp.cumsum(jnp.array(reward_seq))
  vis = Visualizer(env, env_params, state_seq, cum_rewards)
  vis.animate("docs/anim.gif")
- Training pipelines & pretrained agents: Check out gymnax-blines for trained agents, expert rollout visualizations and PPO/ES pipelines. The agents are minimally tuned, but can help you get up and running.
- Simple batch agent evaluation: Work-in-progress.

  import jax.numpy as jnp
  from gymnax.experimental import RolloutWrapper

  # Define rollout manager for pendulum env
  manager = RolloutWrapper(model.apply, env_name="Pendulum-v1")

  # Simple single episode rollout for policy
  obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params)

  # Multiple rollouts for same network (different rng, e.g. eval)
  rng_batch = jax.random.split(rng, 10)
  obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(
      rng_batch, policy_params
  )

  # Multiple rollouts for different networks + rng (e.g. for ES)
  # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias
  batch_params = jax.tree_util.tree_map(  # Stack parameters or use different
      lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape), policy_params
  )
  obs, action, reward, next_obs, done, cum_ret = manager.population_rollout(
      rng_batch, batch_params
  )
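The scan-based rollout and the batched evaluation patterns from the list above can be combined into one self-contained sketch. It assumes a hypothetical toy environment and a hypothetical linear policy (`toy_reset`, `toy_step`, `linear_policy` and `episode_return` are illustrative names, not gymnax or RolloutWrapper APIs). Since `lax.scan` always runs a fixed number of steps, rewards after the first `done` are masked out when accumulating the return.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy environment and policy (NOT gymnax APIs).
def toy_reset(key, env_params):
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    return pos, {"pos": pos, "t": jnp.int32(0)}


def toy_step(key, state, action, env_params):
    noise = env_params["noise"] * jax.random.normal(key, ())
    pos = state["pos"] + action + noise
    t = state["t"] + 1
    done = t >= env_params["max_steps"]
    return pos, {"pos": pos, "t": t}, -jnp.abs(pos), done, {}


def linear_policy(policy_params, obs):
    return policy_params["w"] * obs + policy_params["b"]


def episode_return(key, policy_params, env_params, num_steps=20):
    """Masked sum of rewards over a fixed-length lax.scan rollout."""
    key_reset, key_episode = jax.random.split(key)
    obs, state = toy_reset(key_reset, env_params)

    def body(carry, _):
        obs, state, key, ret, alive = carry
        key, key_step = jax.random.split(key)
        action = linear_policy(policy_params, obs)
        obs, state, reward, done, _ = toy_step(key_step, state, action, env_params)
        ret = ret + reward * alive      # count rewards only while the episode is live
        alive = alive * (1.0 - done)    # stays 0 after the first termination
        return (obs, state, key, ret, alive), None

    init = (obs, state, key_episode, jnp.float32(0.0), jnp.float32(1.0))
    (_, _, _, ret, _), _ = jax.lax.scan(body, init, None, length=num_steps)
    return ret


env_params = {"noise": jnp.float32(0.05), "max_steps": jnp.int32(10)}
policy_params = {"w": jnp.float32(-0.5), "b": jnp.float32(0.0)}
keys = jax.random.split(jax.random.PRNGKey(0), 8)

# Same policy, different random keys (e.g. evaluation).
eval_fn = jax.jit(jax.vmap(episode_return, in_axes=(0, None, None)))
batch_returns = eval_fn(keys, policy_params, env_params)

# Population of 5 policies: stack every parameter leaf along a new leading axis.
scales = (0.0, 0.5, 1.0, 1.5, 2.0)
population = jax.tree_util.tree_map(
    lambda x: jnp.stack([x * s for s in scales]), policy_params
)

# Outer vmap over policies, inner vmap over keys (e.g. for ES fitness evaluation).
population_fn = jax.jit(
    jax.vmap(jax.vmap(episode_return, in_axes=(0, None, None)), in_axes=(None, 0, None))
)
population_returns = population_fn(keys, population, env_params)
```

Here `batch_returns` has one entry per key and `population_returns` has one row per policy. Nesting two `vmap` calls is one possible design; sharing the same keys across the population keeps the comparison between policies on equal footing.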
Resources & Other Great Tools
- Brax: JAX-based library for rigid body physics by Google Brain with JAX-style MuJoCo substitutes.
- envpool: Vectorized parallel environment execution engine.
- Jumanji: A suite of diverse and challenging RL environments in JAX.
- Pgx: JAX-based classic board game environments.
Acknowledgements & Citing gymnax
If you use gymnax in your research, please cite it as follows:
@software{gymnax2022github,
author = {Robert Tjarko Lange},
title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
url = {http://github.com/RobertTLange/gymnax},
version = {0.0.4},
year = {2022},
}
We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.
Development
You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing.