  • Stars: 311
  • Rank: 134,521 (top 3%)
  • Language: Python
  • License: MIT License
  • Created almost 5 years ago
  • Updated over 1 year ago

Repository Details

PyTorch implementation of MuZero

muzero-pytorch

PyTorch implementation of MuZero ("Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model"), based on the pseudocode provided by the authors.
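MuZero plans with three learned functions described in the paper: a representation function h that encodes an observation into a hidden state, a dynamics function g that maps a hidden state and an action to the next hidden state and a predicted reward, and a prediction function f that outputs a policy and a value. The sketch below wires up toy stand-ins for h, g and f (plain Python arithmetic, not neural networks, and not this repository's model classes) to show how an action sequence is unrolled entirely inside the learned model. The method names initial_inference and recurrent_inference are modeled on the authors' pseudocode; ToyMuZeroModel, Step and unroll are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple

Hidden = List[float]


@dataclass
class Step:
    hidden: Hidden       # hidden state s_k
    reward: float        # predicted reward r_k (0.0 at the root)
    policy: List[float]  # predicted policy p_k
    value: float         # predicted value v_k


class ToyMuZeroModel:
    """Hypothetical stand-in for MuZero's h, g and f; not this repo's network."""

    def __init__(self, num_actions: int) -> None:
        self.num_actions = num_actions

    def representation(self, observation: Sequence[float]) -> Hidden:
        # h: observation -> initial hidden state (toy: scale by 0.5)
        return [0.5 * x for x in observation]

    def dynamics(self, hidden: Hidden, action: int) -> Tuple[Hidden, float]:
        # g: (hidden state, action) -> (next hidden state, predicted reward)
        next_hidden = [x + 0.1 * (action + 1) for x in hidden]
        reward = sum(next_hidden) / len(next_hidden)
        return next_hidden, reward

    def prediction(self, hidden: Hidden) -> Tuple[List[float], float]:
        # f: hidden state -> (policy, value); toy uniform policy
        policy = [1.0 / self.num_actions] * self.num_actions
        return policy, sum(hidden)

    def initial_inference(self, observation: Sequence[float]) -> Step:
        # Root of the search: h followed by f
        hidden = self.representation(observation)
        policy, value = self.prediction(hidden)
        return Step(hidden, 0.0, policy, value)

    def recurrent_inference(self, hidden: Hidden, action: int) -> Step:
        # One imagined step: g followed by f
        next_hidden, reward = self.dynamics(hidden, action)
        policy, value = self.prediction(next_hidden)
        return Step(next_hidden, reward, policy, value)


def unroll(model: ToyMuZeroModel, observation: Sequence[float],
           actions: Sequence[int]) -> List[Step]:
    """Imagine a trajectory inside the model; the real environment is never called."""
    steps = [model.initial_inference(observation)]
    for action in actions:
        steps.append(model.recurrent_inference(steps[-1].hidden, action))
    return steps
```

During planning, MuZero's tree search calls initial_inference once for the root and recurrent_inference each time it expands a node, which is why the environment simulator is not needed at search time.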

Note: This implementation has only been tested on CartPole-v1 and would require modifications (in the config folder) for other environments.

Installation

  • Python 3.8 or 3.9
  • Install the dependencies:
      cd muzero-pytorch
      pip install -r requirements.txt

Usage:

  • Train: python main.py --env CartPole-v1 --case classic_control --opr train --force
  • Test: python main.py --env CartPole-v1 --case classic_control --opr test
  • Visualize results:
    • tensorboard --logdir=<result_dir_path>
    • If --use_wandb was passed, results can also be viewed in wandb.
Required Arguments                    Description
--env                                 Name of the environment
--case {atari,classic_control,box2d}  Used to switch between different domains (default: None)
--opr {train,test}                    Selects the operation to perform
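The required arguments map naturally onto argparse options with a fixed set of choices. The following is a hypothetical reconstruction for illustration only, assuming the flags are declared with argparse; it is not copied from this repository's main.py, and build_required_parser is an invented name.

```python
import argparse


def build_required_parser() -> argparse.ArgumentParser:
    """Hypothetical sketch of the three required flags (not the repo's main.py)."""
    parser = argparse.ArgumentParser(description="MuZero CLI sketch")
    parser.add_argument("--env", required=True,
                        help="Name of the environment, e.g. CartPole-v1")
    # choices restricts --case to the documented domains
    parser.add_argument("--case", required=True,
                        choices=["atari", "classic_control", "box2d"],
                        help="Domain of the environment")
    # choices restricts --opr to the documented operations
    parser.add_argument("--opr", required=True, choices=["train", "test"],
                        help="Operation to perform")
    return parser
```

With this setup, argparse rejects a value outside choices (for example --case mujoco) by printing an error and exiting with status 2, so a typo fails fast instead of silently loading the wrong config.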
Optional Arguments                    Description
--value_loss_coeff                    Scale for the value loss (default: None)
--revisit_policy_search_rate          Rate at which the target policy is re-estimated; only valid if --use_target_model is enabled (default: None)
--use_priority                        Uses priority-based sampling in the replay buffer; the priority of new data is computed from the loss (default: False)
--use_max_priority                    Forces max priority assignment for new incoming data in the replay buffer; only valid if --use_priority is enabled (default: False)
--use_target_model                    Uses a target model for bootstrap value estimation (default: False)
--result_dir                          Directory path for storing results (default: current working directory)
--no_cuda                             Disables CUDA usage (default: False)
--no_mps                              Disables MPS (Metal Performance Shaders) usage (default: False)
--debug                               If enabled, logs additional values (default: False)
--render                              Renders the environment (default: False)
--force                               Overwrites past results (default: False)
--seed                                Random seed (default: 0)
--num_actors                          Number of actors running concurrently (default: 32)
--test_episodes                       Number of evaluation episodes (default: 10)
--use_wandb                           Logs console and TensorBoard data to wandb (default: False)
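The --use_priority and --use_max_priority flags in the table describe prioritized replay: samples are drawn in proportion to a priority, and new data gets either a loss-based priority or the current maximum. A minimal sketch of that idea, assuming proportional prioritization (probability proportional to priority ** alpha); the class PrioritizedReplay, its alpha parameter and its methods are hypothetical and are not this repository's replay buffer.

```python
import random
from typing import Any, List, Optional, Sequence


class PrioritizedReplay:
    """Hypothetical proportional prioritized buffer (not the repo's implementation)."""

    def __init__(self, alpha: float = 1.0, use_max_priority: bool = False,
                 seed: int = 0) -> None:
        self.alpha = alpha
        self.use_max_priority = use_max_priority
        self.items: List[Any] = []
        self.priorities: List[float] = []
        self.rng = random.Random(seed)

    def add(self, item: Any, loss_priority: Optional[float] = None) -> None:
        # Max-priority mode (or no loss available): use the largest priority seen so far,
        # falling back to 1.0 for an empty buffer. Otherwise use the loss-based priority.
        if self.use_max_priority or loss_priority is None:
            priority = max(self.priorities, default=1.0)
        else:
            priority = loss_priority
        self.items.append(item)
        self.priorities.append(priority)

    def probabilities(self) -> List[float]:
        # Sampling probability proportional to priority ** alpha
        scaled = [p ** self.alpha for p in self.priorities]
        total = sum(scaled)
        return [s / total for s in scaled]

    def sample(self, batch_size: int) -> List[int]:
        # Draw indices with replacement according to the priorities
        return self.rng.choices(range(len(self.items)),
                                weights=self.probabilities(), k=batch_size)

    def update(self, indices: Sequence[int], new_priorities: Sequence[float]) -> None:
        # Refresh priorities after a training step (e.g. from new losses)
        for i, p in zip(indices, new_priorities):
            self.priorities[i] = p
```

Assigning the maximum priority to fresh data guarantees that each new sample is likely to be replayed at least once before its loss-based priority is known.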

Note: "default: None" means the value is loaded from the corresponding config.
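The "default: None" convention means a command-line value overrides the config only when it is actually passed. A minimal sketch of that resolution step, assuming argparse flags with default=None; the CONFIG dictionary and its values are placeholders, and build_optional_parser and resolve are hypothetical helpers rather than this repository's config module.

```python
import argparse
from typing import Any, Dict

# Hypothetical per-environment config values (placeholders, not the repo's numbers)
CONFIG: Dict[str, Any] = {"value_loss_coeff": 0.25, "revisit_policy_search_rate": 0.0}


def build_optional_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # default=None signals "take this value from the config"
    parser.add_argument("--value_loss_coeff", type=float, default=None)
    parser.add_argument("--revisit_policy_search_rate", type=float, default=None)
    # A concrete default is used as-is
    parser.add_argument("--seed", type=int, default=0)
    return parser


def resolve(args: argparse.Namespace, config: Dict[str, Any]) -> Dict[str, Any]:
    """Use the CLI value when given, otherwise fall back to the config value."""
    resolved: Dict[str, Any] = {}
    for key, value in vars(args).items():
        resolved[key] = config.get(key) if value is None else value
    return resolved
```

Checking "is None" rather than truthiness matters here: a legitimate value such as --seed 0 must not be mistaken for a missing one.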

Training

CartPole-v1

  • Curves represent model evaluation over 5 episodes at intervals of 100 training steps.
  • Each curve is the mean score over 5 runs (seeds: [0, 100, 200, 300, 400]).
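The averaging described above (one evaluation point every 100 training steps, then a mean over the five seeds) can be sketched as follows. This assumes each seed yields a list of evaluation scores, one per evaluation point, with the first point at step 0; mean_curve is a hypothetical helper and the scores in the usage example are made-up placeholders, not results from this repository.

```python
from statistics import mean
from typing import Dict, List, Tuple

EVAL_INTERVAL = 100               # training steps between evaluations
SEEDS = [0, 100, 200, 300, 400]   # seeds listed in the README


def mean_curve(scores_per_seed: Dict[int, List[float]],
               interval: int = EVAL_INTERVAL) -> List[Tuple[int, float]]:
    """Average per-seed evaluation scores at each evaluation point.

    scores_per_seed[seed][i] is assumed to be the mean score over the evaluation
    episodes run at training step i * interval.
    """
    runs = [scores_per_seed[s] for s in sorted(scores_per_seed)]
    length = min(len(r) for r in runs)  # truncate to the shortest run
    return [(i * interval, mean(r[i] for r in runs)) for i in range(length)]
```

For example, with placeholder scores {0: [10.0, 50.0], 100: [20.0, 60.0], 200: [30.0, 70.0], 300: [40.0, 80.0], 400: [50.0, 90.0]}, the curve is [(0, 30.0), (100, 70.0)].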

More Repositories

  • ma-gym (Python, 477 stars): A collection of multi-agent environments based on OpenAI Gym
  • mmn (Python, 47 stars): Moore Machine Networks (MMN): Learning Finite-State Representations of Recurrent Policy Networks
  • minimal-marl (Python, 43 stars): Minimal implementation of multi-agent reinforcement learning algorithms
  • visTorch (Python, 21 stars): Interacting with the latent space of an autoencoder
  • dream-and-search (Python, 10 stars): Code for "Dream and Search to Control: Latent Space Planning for Continuous Control"
  • conformal (Python, 9 stars): Conformal prediction is a framework for providing accuracy guarantees on the predictions of a base predictor
  • gym-cartpole-continuous (Python, 7 stars): CartPole environment with a continuous action space
  • marl-pytorch (Python, 5 stars): PyTorch implementations of multi-agent reinforcement learning (MARL) algorithms
  • gym_x (Python, 5 stars): Gym environments for capturing properties of the hidden states (hx) of recurrent networks
  • policybazaar (Python, 3 stars): A collection of multi-quality policies for continuous control tasks
  • opcc (Python, 3 stars): Benchmark for "Offline Policy Comparison with Confidence"
  • deep-conformal (Python, 3 stars): Applying conformal prediction over deep neural nets
  • variable-td3 (Python, 2 stars): Learning n-step actions for control tasks
  • pfa (Python, 1 star): Policy Fusion Architecture (PFA): We investigate policy gradient approaches for reward decomposition in reinforcement learning
  • opcc-baselines (Python, 1 star): Baselines for "Offline Policy Comparison with Confidence"
  • vpn (Python, 1 star): PyTorch implementation of Value Prediction Network (VPN) 🚧 👷