  • Stars: 235
  • Rank 171,079 (Top 4 %)
  • Language: Python
  • License: MIT License
  • Created over 3 years ago
  • Updated over 1 year ago

Repository Details

PyTorch implementation of Dreamer-v2, a visual model-based RL algorithm.

Dreamer-v2 PyTorch

PyTorch implementation of Mastering Atari with Discrete World Models

Installation

Dependencies:

I have added requirements.txt (generated with conda list -e > requirements.txt) and environment.yml (generated with conda env export > environment.yml), both exported from my own conda environment.
It is probably easier to create a new conda environment (or venv, etc.) and manually install the few dependencies listed above one by one.

Running experiments

  1. In the test folder, mdp.py and pomdp.py have been set up for experiments with MinAtar environments. All default hyper-parameters are stored in a dataclass in config.py. To run dreamerv2 with the default hyper-parameters on POMDP breakout using CUDA:
python pomdp.py --env breakout --device cuda
  • Training curves are logged using wandb.
  • A results folder will be created locally to store models while training: test/results/env_name+'_'+env_id+'_'+pomdp/models
  2. Experiments on other environments (using the gym API) can be run by adding another hyper-parameter dataclass in config.py.
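
As a rough illustration of step 2, a new hyper-parameter dataclass might look like the sketch below. The class name CartPoleConfig, the environment fields (env, obs_shape, action_size) and every default value are assumptions made for this example. Only the hyper-parameter names come from the description further down, so check config.py for the fields the trainer actually expects.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class CartPoleConfig:
    """Hypothetical hyper-parameter set for a new gym-API environment."""

    # Environment description (illustrative fields, not taken from config.py).
    env: str = "CartPole-v1"
    obs_shape: Tuple[int, ...] = (4,)
    action_size: int = 2

    # Names follow the hyper-parameter description in this README;
    # the values are placeholders, not the repository defaults.
    train_every: int = 50
    collect_intervals: int = 5
    seq_len: int = 50
    embedding_size: int = 200
    rssm_type: str = "categorical"  # placeholder string; see config.py for accepted values
    rssm_node_size: int = 200
    deter_size: int = 200
    stoch_size: int = 20
    class_size: int = 20
    category_size: int = 20
    horizon: int = 10
    kl_balance_scale: float = 0.8
    actor_entropy_scale: float = 1e-3


# Individual values can be overridden when the config is created.
config = CartPoleConfig(seq_len=32)
print(config.seq_len, config.horizon)
```

Keeping one dataclass per environment family means a run can be reproduced from a single object, and overrides stay explicit at the call site.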

Evaluating saved models

Trained models for all 5 games (MDP and POMDP versions of each) are uploaded to the drive link: link (64 MB)
Download and unzip the models inside the test directory.

To evaluate the saved model for the POMDP version of the breakout environment for 5 episodes, without rendering:

python eval.py --env breakout --eval_episode 5 --eval_render 0 --pomdp 1
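
For readers who want to wire similar flags into their own script, the following is a minimal argparse sketch. The flag names are copied from the command above, but the types, defaults and help strings are assumptions; the actual eval.py may parse its arguments differently.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser for the four flags shown in the command above."""
    parser = argparse.ArgumentParser(description="Evaluate a saved agent (sketch).")
    parser.add_argument("--env", type=str, default="breakout",
                        help="MinAtar game name")
    parser.add_argument("--eval_episode", type=int, default=5,
                        help="number of evaluation episodes")
    parser.add_argument("--eval_render", type=int, default=0,
                        help="1 to render, 0 to skip rendering")
    parser.add_argument("--pomdp", type=int, default=1,
                        help="1 for partial observability, 0 for full observability")
    return parser


# Parse the same arguments as the example command (passed as a list here
# so the sketch does not depend on the real command line).
args = build_parser().parse_args(
    ["--env", "breakout", "--eval_episode", "5", "--eval_render", "0", "--pomdp", "1"]
)
print(args)
```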

Evaluation Results

Average evaluation score (over 50 evaluation episodes) of models saved every 0.1 million frames. Green curves correspond to agents that have access to complete information, while red curves correspond to agents trained with partial observability.

In freeway, the agent gets stuck in a local maximum in which it learns to always move forward, because it is not penalised for crashing into cars. Probably due to policy entropy regularisation, its returns drop drastically around the 1 million frame mark and then gradually improve while maintaining the policy entropy.

Training curves

All experiments were logged using wandb. Training runs for all MDP and POMDP variants of MinAtar environments can be found on the wandb project page.

Please create an issue if you find a bug or have any queries.

Code structure:

  • test
    • pomdp.py run MinAtar experiments with partial observability.
    • mdp.py run MinAtar experiments with complete observability.
    • eval.py evaluate saved agents.
  • dreamerv2 dreamerv2 plus dreamerv1 and their combinations.
    • models neural network models.
      • actor.py discrete action model.
      • dense.py fully connected neural networks.
      • pixel.py convolutional encoder and decoder.
      • rssm.py recurrent state space model.
    • training
      • config.py hyper-parameter dataclass.
      • trainer.py training class, loss calculation.
      • evaluator.py evaluation class.
    • utils
      • algorithm.py lambda return function.
      • buffer.py replay buffers, batches of sequences.
      • module.py neural network parameters utils.
      • rssm.py recurrent state space model utils.
      • wrapper.py gym api and pomdp wrappers for MinAtar.
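
The lambda return mentioned for algorithm.py can be sketched in plain Python as below. This is a generic version of the recursion commonly used in Dreamer-style agents, written for a single trajectory with scalar values; the function name, argument names and default lam are assumptions, and the repository's own implementation (which operates on tensors) may differ in indexing and shapes.

```python
from typing import List, Sequence


def lambda_return(
    rewards: Sequence[float],
    next_values: Sequence[float],
    discounts: Sequence[float],
    bootstrap: float,
    lam: float = 0.95,
) -> List[float]:
    """Compute lambda-returns backwards over one imagined trajectory.

    rewards[t]      reward received at step t
    next_values[t]  value estimate of the state reached after step t
    discounts[t]    discount factor applied at step t
    bootstrap       value estimate used beyond the last step
    """
    returns = [0.0] * len(rewards)
    last = bootstrap
    for t in reversed(range(len(rewards))):
        # Blend the one-step target with the lambda-return of the next step.
        last = rewards[t] + discounts[t] * ((1.0 - lam) * next_values[t] + lam * last)
        returns[t] = last
    return returns


# Example call on a three-step trajectory with constant reward and value.
example = lambda_return(
    rewards=[1.0, 1.0, 1.0],
    next_values=[0.5, 0.5, 0.5],
    discounts=[0.9, 0.9, 0.9],
    bootstrap=0.5,
)
print(example)
```

With lam = 0 every target reduces to the one-step bootstrap r + discount * value, and with lam = 1 it becomes the discounted sum of rewards plus the bootstrapped tail, so lam trades bias against variance.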

Hyper-Parameter description:

  • train_every: number of frames to skip while training.
  • collect_intervals: number of batches to be sampled from the buffer at every train_every iteration.
  • seq_len: length of trajectory sequence to be sampled from buffer.
  • embedding_size: size of embedding vector that is output by observation encoder.
  • rssm_type: categorical or gaussian random variables for stochastic states.
  • rssm_node_size: size of hidden layers of temporal posteriors and priors.
  • deter_size: size of deterministic part of recurrent state.
  • stoch_size: size of stochastic part of recurrent state.
  • class_size: number of classes for each categorical random variable.
  • category_size: number of categorical random variables.
  • horizon: horizon for imagination in future latent state space.
  • kl_balance_scale: scale for kl balancing.
  • actor_entropy_scale: scale for policy entropy regularization in latent state space.
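
To make the interaction of train_every, collect_intervals and seq_len concrete, here is a toy sketch of the schedule they imply, under the assumption that training is triggered once every train_every frames. The functions sample_sequence and run_schedule and the integer "transitions" are hypothetical stand-ins; the real trainer samples batches of sequences from its replay buffer and computes losses at the marked point.

```python
import random
from typing import List


def sample_sequence(buffer: List[int], seq_len: int, rng: random.Random) -> List[int]:
    """Sample one contiguous sequence of length seq_len from the buffer."""
    start = rng.randrange(0, len(buffer) - seq_len + 1)
    return buffer[start:start + seq_len]


def run_schedule(total_frames: int, train_every: int, collect_intervals: int,
                 seq_len: int, seed: int = 0) -> int:
    """Return the number of sampled sequences (stand-ins for gradient updates)."""
    rng = random.Random(seed)
    buffer: List[int] = []
    updates = 0
    for frame in range(1, total_frames + 1):
        buffer.append(frame)  # stand-in for storing one environment transition
        # Train only every train_every frames, once enough data is stored.
        if frame % train_every == 0 and len(buffer) >= seq_len:
            for _ in range(collect_intervals):
                sequence = sample_sequence(buffer, seq_len, rng)
                assert len(sequence) == seq_len
                updates += 1  # a real trainer would compute losses here
    return updates


print(run_schedule(total_frames=100, train_every=10, collect_intervals=3, seq_len=5))
```

Under this reading, the number of updates per collected frame is roughly collect_intervals / train_every, which is the ratio to tune when trading wall-clock time against sample efficiency.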

Acknowledgments

Awesome Environments used for testing:

This code is heavily inspired by the following works: