
TD-MPC2

Official implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control by Nicklas Hansen, Hao Su*, Xiaolong Wang* (UC San Diego).


[Website] [Paper] [Models] [Dataset]


Overview

TD-MPC2 is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across 104 continuous control tasks spanning multiple domains, using a single set of hyperparameters. We further demonstrate the scalability of TD-MPC2 by training a single 317M-parameter agent to perform 80 tasks across multiple domains, embodiments, and action spaces.
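
For intuition, the sketch below illustrates the core idea behind TD-MPC2's decision-time planning: sampling-based trajectory optimization (MPPI) over a learned latent-space world model, with a learned value function bootstrapping returns beyond the planning horizon. All model components here (encoder, dynamics, reward, value) are hypothetical placeholders and the hyperparameter values are illustrative only; see the paper and this codebase for the actual implementation.

import torch

def plan(encoder, dynamics, reward_fn, value_fn, obs,
         horizon=3, num_samples=512, num_iters=6, num_elites=64,
         action_dim=6, temperature=0.5):
    """Schematic MPPI-style planner over a learned latent world model."""
    z0 = encoder(obs)                        # encode observation into a latent state
    mean = torch.zeros(horizon, action_dim)  # mean of the action-sequence distribution
    std = torch.ones(horizon, action_dim)    # std of the action-sequence distribution
    for _ in range(num_iters):
        # Sample candidate action sequences around the current distribution.
        noise = torch.randn(num_samples, horizon, action_dim)
        actions = (mean + std * noise).clamp(-1, 1)
        # Roll out each candidate entirely in latent space, accumulating reward.
        z = z0.expand(num_samples, -1)
        total_return = torch.zeros(num_samples)
        for t in range(horizon):
            total_return += reward_fn(z, actions[:, t])
            z = dynamics(z, actions[:, t])
        total_return += value_fn(z)          # bootstrap with the learned value function
        # Re-fit the sampling distribution to the softmax-weighted elite sequences.
        elite_idx = total_return.topk(num_elites).indices
        weights = torch.softmax(total_return[elite_idx] / temperature, dim=0)
        elites = actions[elite_idx]
        mean = (weights[:, None, None] * elites).sum(0)
        std = ((weights[:, None, None] * (elites - mean) ** 2).sum(0)).sqrt().clamp_min(0.05)
    return mean[0]                           # execute the first planned action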


This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC2 agents. We additionally open-source 300+ model checkpoints (including 12 multi-task models) across 4 task domains: DMControl, Meta-World, ManiSkill2, and MyoSuite, as well as our 30-task and 80-task datasets used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.


Getting started

You will need a machine with a GPU and at least 12 GB of RAM for single-task online RL with TD-MPC2, and 128 GB of RAM for multi-task offline RL on our provided 80-task dataset. A GPU with at least 8 GB of memory is recommended for single-task online RL and for evaluation of the provided multi-task models (up to 317M parameters). Training of the 317M parameter model requires a GPU with at least 24 GB of memory.
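
If you are unsure whether your GPU meets these requirements, a quick check with PyTorch (assuming it is installed) is:

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory / 1e9:.1f} GB of GPU memory")
else:
    print("No CUDA-capable GPU detected")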

We provide a Dockerfile for easy installation. You can build the Docker image by running

cd docker && docker build . -t <user>/tdmpc2:0.1.0
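
Once the image is built, you can start an interactive container with GPU access (assuming the NVIDIA Container Toolkit is installed) by running

docker run --gpus all -it <user>/tdmpc2:0.1.0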

If you prefer to install dependencies manually, create a conda environment by running one of the following commands:

conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml

The environment.yaml file installs dependencies required for all environments, whereas environment_minimal.yaml only installs dependencies for training on DMControl tasks.

If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running

python -m mani_skill2.utils.download_asset all

which downloads assets to ./data. You may move these assets to any location. Then add the following line to your ~/.bashrc (and restart your terminal afterwards):

export MS2_ASSET_DIR=<path>/<to>/<data>
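
After restarting your terminal, a quick sanity check (using only the Python standard library) that the variable is set and points to an existing directory is:

import os

asset_dir = os.environ.get("MS2_ASSET_DIR")
assert asset_dir is not None, "MS2_ASSET_DIR is not set"
assert os.path.isdir(asset_dir), f"{asset_dir} does not exist"
print(f"ManiSkill2 assets found at {asset_dir}")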

Meta-World additionally requires MuJoCo 2.1.0. We host the unrestricted MuJoCo 2.1.0 license (courtesy of Google DeepMind) at https://www.tdmpc2.com/files/mjkey.txt. You can download the license by running

wget https://www.tdmpc2.com/files/mjkey.txt -O ~/.mujoco/mjkey.txt

See docker/Dockerfile for installation instructions if you do not already have MuJoCo 2.1.0 installed. MyoSuite requires gym==0.13.0, which is incompatible with Meta-World and ManiSkill2; if desired, install it separately with pip install myosuite. Depending on your existing system packages, you may need to install additional dependencies. See docker/Dockerfile for a list of recommended system packages.


Supported tasks

This codebase currently supports 104 continuous control tasks from DMControl, Meta-World, ManiSkill2, and MyoSuite. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, covering all tasks used in the paper. See the table below for the expected task name format in each domain:

domain      task
dmcontrol   dog-run
dmcontrol   cheetah-run-backwards
metaworld   mw-assembly
metaworld   mw-pick-place-wall
maniskill   pick-cube
maniskill   pick-ycb
myosuite    myo-key-turn
myosuite    myo-key-turn-hard

These tasks can be run by specifying the task argument for evaluate.py. Multi-task training and evaluation are specified by setting task=mt80 or task=mt30 for the 80-task and 30-task sets, respectively.

As of Dec 27, 2023, the TD-MPC2 codebase also supports pixel observations for DMControl tasks; use the argument obs=rgb if you wish to train visual policies.

Example usage

Below, we provide examples of how to evaluate our provided TD-MPC2 checkpoints, as well as how to train your own TD-MPC2 agents.

Evaluation

The examples below show how to evaluate downloaded single-task and multi-task checkpoints.

$ python evaluate.py task=mt80 model_size=48 checkpoint=/path/to/mt80-48M.pt
$ python evaluate.py task=mt30 model_size=317 checkpoint=/path/to/mt30-317M.pt
$ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true

All single-task checkpoints expect model_size=5. Multi-task checkpoints are available in multiple model sizes. Available arguments are model_size={1, 5, 19, 48, 317}. Note that single-task evaluation of multi-task checkpoints is currently not supported. See config.yaml for a full list of arguments.
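
If you would like to inspect a downloaded checkpoint before running evaluation, the minimal sketch below (assuming the file is a standard PyTorch checkpoint, and reusing the hypothetical path from the example above) prints its top-level contents:

import torch

# Load on CPU so no GPU is needed just to inspect the file.
ckpt = torch.load("/path/to/dog-1.pt", map_location="cpu")
if isinstance(ckpt, dict):
    for key, value in ckpt.items():
        desc = tuple(value.shape) if torch.is_tensor(value) else type(value).__name__
        print(key, desc)
else:
    print(type(ckpt).__name__)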

Training

The examples below show how to train TD-MPC2 on a single task (online RL) and on the multi-task datasets (offline RL). We recommend configuring Weights and Biases (wandb) in config.yaml to track training progress.

$ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000
$ python train.py task=walker-walk obs=rgb

We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (model_size=5). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are model_size={1, 5, 19, 48, 317}. See config.yaml for a full list of arguments.


Citation

If you find our work useful, please consider citing the paper as follows:

@misc{hansen2023tdmpc2,
	title={TD-MPC2: Scalable, Robust World Models for Continuous Control}, 
	author={Nicklas Hansen and Hao Su and Xiaolong Wang},
	year={2023},
	eprint={2310.16828},
	archivePrefix={arXiv},
	primaryClass={cs.LG}
}

Contributing

You are very welcome to contribute to this project. Feel free to open an issue or pull request if you have any suggestions or bug reports, but please review our guidelines first. Our goal is to build a codebase that can easily be extended to new environments and tasks, and we would love to hear about your experience!


License

This project is licensed under the MIT License - see the LICENSE file for details. Note that the repository relies on third-party code, which is subject to the respective third-party licenses.
