  • Stars: 140
  • Rank: 261,473 (Top 6%)
  • Language: Python
  • License: MIT License
  • Created over 3 years ago
  • Updated over 3 years ago


Repository Details

A collection of various RL algorithms like policy gradients, DQN and PPO. The goal of this repo is to make it a go-to resource for learning about RL: how to visualize, debug and solve RL problems. I've additionally included playground.py for learning more about OpenAI Gym, etc.

Reinforcement Learning (PyTorch) 🤖 + 🍰 = ❤️

This repo will contain PyTorch implementations of various fundamental RL algorithms.
It's aimed at making it easy to start playing with and learning about RL.

The problem I came across while investigating other DQN projects is that they either:

  • Don't have any evidence that they've actually achieved the published results
  • Don't have a "smart" replay buffer (i.e. they allocate (1M, 4, 84, 84) ~ 28 GB instead of (1M, 84, 84) ~ 7 GB)
  • Lack visualizations and debugging utils

This repo will aim to solve these problems.
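The "smart" replay buffer idea can be illustrated with a minimal sketch, assuming one 84x84 grayscale `uint8` frame per transition. It is not this repo's implementation: the class and method names (`FrameReplayBuffer`, `add`, `get_state`) are hypothetical. Each frame is stored exactly once and the 4-frame state is rebuilt on demand, so 1M frames cost 1,000,000 × 84 × 84 bytes ≈ 7.06 GB instead of ≈ 28.2 GB for pre-stacked states. Frames that belong to a previous episode (or that were never written) are left black, i.e. zeroed.

```python
import numpy as np


class FrameReplayBuffer:
    """Hypothetical sketch: each frame is stored once; 4-frame states are rebuilt on demand."""

    def __init__(self, capacity, frame_shape=(84, 84), history_len=4):
        self.capacity = capacity
        self.history_len = history_len
        # (capacity, 84, 84) uint8 instead of (capacity, 4, 84, 84): 4x less memory
        self.frames = np.zeros((capacity, *frame_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=bool)
        self.next_idx = 0    # slot that will be written next (circular)
        self.num_stored = 0  # number of valid slots, saturates at capacity

    def add(self, frame, action, reward, done):
        """Store the frame the agent saw, the action it took and the outcome. Returns the slot index."""
        idx = self.next_idx
        self.frames[idx] = frame
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.dones[idx] = done  # True if this transition ended the episode
        self.next_idx = (idx + 1) % self.capacity
        self.num_stored = min(self.num_stored + 1, self.capacity)
        return idx

    def get_state(self, idx):
        """Stack frames idx-3..idx into a (4, 84, 84) state, zero-padding across episode starts."""
        state = np.zeros((self.history_len, *self.frames.shape[1:]), dtype=np.uint8)
        state[-1] = self.frames[idx]
        for offset in range(1, self.history_len):
            if self.num_stored < self.capacity and idx - offset < 0:
                break  # no older frames have been written yet
            prev_idx = (idx - offset) % self.capacity
            if self.dones[prev_idx]:
                break  # this older frame belongs to the previous episode
            state[-1 - offset] = self.frames[prev_idx]
        return state
```

Keeping frames as `uint8` and converting to float only when a batch is fed to the network is what keeps the footprint near 7 GB. A production buffer would also need to avoid stacking across the write pointer once the buffer has wrapped around.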


RL agents

DQN

This was the project that started the revolution in the RL world: the deep Q-network (Mnih et al.),
a.k.a. "Human-level control through deep reinforcement learning".

The DQN model learned to play 29 Atari games (out of the 49 it was tested on) at a super-human or human-comparable level. Here is the schematic of its CNN architecture:

The fascinating part is that it learned only from "high-dimensional" (84x84) images and (usually sparse) rewards. The same architecture was used for all 49 games, although the model had to be retrained from scratch for every single game.
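For reference, the network described in the Nature paper can be sketched in PyTorch as below: three conv layers (32 filters 8x8 stride 4, 64 filters 4x4 stride 2, 64 filters 3x3 stride 1), a 512-unit fully connected layer, and one output (Q-value) per action. This is a minimal sketch of the published architecture, not necessarily the exact model used in this repo; the class name `DQNSketch` is hypothetical. The input is assumed to be a stack of 4 grayscale 84x84 frames stored as `uint8`.

```python
import torch
import torch.nn as nn


class DQNSketch(nn.Module):
    """Hypothetical sketch of the Nature DQN network: 4x84x84 frames in, one Q-value per action out."""

    def __init__(self, num_actions, history_len=4):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(history_len, 32, kernel_size=8, stride=4), nn.ReLU(),  # 84x84 -> 20x20
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),           # 20x20 -> 9x9
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),           # 9x9 -> 7x7
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
            nn.Linear(512, num_actions),  # one Q-value per action
        )

    def forward(self, frames_uint8):
        # Scale raw pixels from [0, 255] to [0, 1] inside the model
        x = frames_uint8.float() / 255.0
        return self.fc(self.cnn(x))
```

Doing the `/ 255` normalization inside `forward` lets the replay buffer keep frames as compact `uint8` arrays.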

DQN current results

Since it takes lots of compute and time to train all 49 models, I'll consider the DQN project completed once I succeed in achieving the published results on:

  • Breakout
  • Pong

Having said that, the experiments are still in progress, so feel free to contribute!

  • For some reason the models aren't learning very well, so if you find a bug, open up a PR! ❤️
  • I'm also experiencing slowdowns, so any PRs that would improve/explain the perf are welcome!
  • If you decide to train the DQN using this repo on some other Atari game, I'll gladly check in your model!

Important note: please follow the coding guidelines of this repo before you submit a PR so that we can minimize the back-and-forth. I'm a fairly busy guy, as I assume you are.

Current results - Breakout

As you can see, the model did learn something, although it's far from being really good.

Current results - Pong

todo

Setup

Let's get this thing running! Follow the next steps:

  1. git clone https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning
  2. Open the Anaconda console and navigate into the project directory: cd path_to_repo
  3. Run conda env create from the project directory (this will create a brand new conda environment).
  4. Run activate pytorch-rl-env (for running scripts from your console) or set up the interpreter in your IDE

If you're on Windows, you'll additionally need to install gym's Atari dependencies with: pip install --no-index -f https://github.com/Kojoley/atari-py/releases atari_py

Otherwise, pip install 'gym[atari]' should do it. If it's not working, check out this and this.

That's it! It should work out of the box, since the environment.yml file takes care of the dependencies.


The PyTorch pip package comes bundled with some version of CUDA/cuDNN, but it is highly recommended that you install a system-wide CUDA beforehand, mostly because of the GPU drivers. I also recommend using the Miniconda installer to get conda on your system. Follow points 1 and 2 of this setup and use the most up-to-date versions of Miniconda and CUDA/cuDNN for your system.
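After activating the environment, a quick sanity check like the one below can confirm that the key packages are importable and whether PyTorch sees your GPU. This is a hypothetical helper (`check_setup` is not part of the repo), and the package list (`torch`, `gym`, `atari_py`) is an assumption based on the setup steps above.

```python
import importlib.util


def check_setup(packages=("torch", "gym", "atari_py")):
    """Hypothetical helper: map each package name to whether it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    status = check_setup()
    for name, ok in status.items():
        print(f"{name:10s} {'OK' if ok else 'MISSING'}")
    if status["torch"]:
        import torch

        print("CUDA available:", torch.cuda.is_available())
        print("CUDA version PyTorch was built with:", torch.version.cuda)  # None on CPU-only builds
        print("cuDNN version:", torch.backends.cudnn.version())            # None if cuDNN is unavailable
```

If `CUDA available` prints `False` on a machine with an NVIDIA GPU, the system-wide driver install mentioned above is the first thing worth checking.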

Usage

Option 1: Jupyter Notebook

Coming soon.

Option 2: Use your IDE of choice

You just need to link the Python environment you created in the setup section.

Training DQN

To run with the default settings, just run python train_DQN_script.py.

Settings you'll want to experiment with:

  • --seed - it may just so happen that I've chosen a bad one (RL is very sensitive)
  • --learning_rate - DQN originally used RMSProp; I saw that Adam with 1e-4 worked for Stable Baselines3
  • --grad_clipping_value - there was a lot of noise in the gradients, so I used this to control it
  • Try using RMSProp (I haven't yet). Adam was an improvement over RMSProp, so I doubt it's causing the issues
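To make these knobs concrete, here is a minimal sketch of one DQN optimization step, assuming a policy network, a target network and a batch sampled from the replay buffer. It is not the code in train_DQN_script.py: the function name `dqn_update`, the batch layout and the choice of `clip_grad_norm_` (clipping the gradient norm rather than individual gradient values) are assumptions. The Huber loss comes from PyTorch's `F.smooth_l1_loss`, which with its default `beta=1.0` equals the Huber loss with delta 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dqn_update(policy_net, target_net, optimizer, batch, gamma=0.99, grad_clipping_value=None):
    """One hypothetical DQN step: TD targets from the target network, Huber loss, optional clipping."""
    states, actions, rewards, next_states, dones = batch
    # Q(s, a) for the actions that were actually taken
    q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Bootstrapped target r + gamma * max_a' Q_target(s', a'); no bootstrap on terminal steps
        max_next_q = target_net(next_states).max(dim=1).values
        targets = rewards + gamma * max_next_q * (1.0 - dones.float())
    loss = F.smooth_l1_loss(q_values, targets)  # Huber loss
    optimizer.zero_grad()
    loss.backward()
    if grad_clipping_value is not None:
        nn.utils.clip_grad_norm_(policy_net.parameters(), grad_clipping_value)
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    # Tiny stand-in networks (8-dim states, 3 actions), only to exercise the update
    policy_net = nn.Linear(8, 3)
    target_net = nn.Linear(8, 3)
    target_net.load_state_dict(policy_net.state_dict())
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)
    # To try RMSProp instead, construct torch.optim.RMSprop(policy_net.parameters(), lr=...)
    batch = (torch.randn(32, 8), torch.randint(0, 3, (32,)), torch.randn(32),
             torch.randn(32, 8), torch.zeros(32, dtype=torch.bool))
    print(dqn_update(policy_net, target_net, optimizer, batch, grad_clipping_value=10.0))
```

Swapping the optimizer is a one-line change, which makes the RMSProp vs. Adam comparison above cheap to run.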

Less important settings for getting DQN to work:

  • --env_id - depending on which game you want to train on (I'd focus on the easiest one for now: Breakout)
  • --replay_buffer_size - hopefully you can train DQN with 1M, as in the original paper; if not, make it smaller
  • --dont_crash_if_no_mem - add this flag if you want to run with a 1M replay buffer even if you don't have enough RAM
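The RAM needed by a frame-based replay buffer is easy to estimate up front, which is the kind of check a flag like --dont_crash_if_no_mem would bypass. Below is a hypothetical guard (the function name and error message are illustrative, not the repo's code) that only counts one `uint8` 84x84 frame per transition and ignores the small action/reward arrays. You pass in the available RAM yourself; the `__main__` block reads it with the third-party `psutil` package if it is installed.

```python
def check_replay_buffer_memory(replay_buffer_size, available_bytes, dont_crash_if_no_mem=False,
                               frame_shape=(84, 84)):
    """Hypothetical guard: estimate the buffer's RAM (1 byte per uint8 pixel) and fail early."""
    required_bytes = replay_buffer_size * frame_shape[0] * frame_shape[1]
    if required_bytes > available_bytes and not dont_crash_if_no_mem:
        raise MemoryError(
            f"Replay buffer needs {required_bytes / 1e9:.2f} GB but only "
            f"{available_bytes / 1e9:.2f} GB is available; "
            "reduce --replay_buffer_size or pass --dont_crash_if_no_mem."
        )
    return required_bytes


if __name__ == "__main__":
    try:
        import psutil  # third-party, optional

        available = psutil.virtual_memory().available
    except ImportError:
        available = 16 * 10**9  # assumption: a 16 GB machine when psutil is missing
    needed = check_replay_buffer_memory(1_000_000, available, dont_crash_if_no_mem=True)
    print(f"{needed / 1e9:.2f} GB")
```

For the paper's 1M buffer this gives 7,056,000,000 bytes, i.e. the ~7 GB figure quoted in the hardware requirements.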

The training script will:

  • Dump checkpoint *.pth models into models/checkpoints/
  • Dump the best (highest reward) *.pth model into models/binaries/ <- TODO
  • Periodically write some training metadata to the console
  • Save tensorboard metrics into runs/; to use them, check out the visualization section
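Dumping `*.pth` checkpoints like this can be sketched with `torch.save`. The `models/checkpoints/` directory matches the list above, but the file-naming scheme and the dictionary keys (`state_dict`, `optimizer_state`, `env_id`, `steps`) are hypothetical and may differ from what the training script actually writes.

```python
import os

import torch
import torch.nn as nn


def save_checkpoint(model, optimizer, env_id, steps, checkpoint_dir=os.path.join("models", "checkpoints")):
    """Hypothetical checkpoint writer: one .pth file per call, named after the env and step count."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"dqn_{env_id}_steps_{steps}.pth")
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "env_id": env_id,
        "steps": steps,
    }, path)
    return path


def load_checkpoint(path, model, optimizer=None):
    """Restore weights (and optionally the optimizer state) from a checkpoint written above."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint
```

Saving the optimizer state alongside the weights makes it possible to resume a multi-day run without resetting Adam's moment estimates.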

Visualization and debugging tools

You can visualize the metrics during training by calling tensorboard --logdir=runs from your console and pasting the http://localhost:6006/ URL into your browser.

I'm currently visualizing the Huber loss (and you can see there is something weird going on):

Rewards and steps taken per episode (there is a fair bit of correlation between these 2):

And gradient L2 norms of weights and biases of every CNN/FC layer as well as the complete grad vector:

As well as epsilon (from the epsilon-greedy algorithm) but that plot is not that informative so I'll omit it here.
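For reference, epsilon-greedy exploration with the linear annealing used in the Nature DQN paper (epsilon going from 1.0 down to 0.1 over the first 1M frames, then held at 0.1) can be sketched like this. The function names are hypothetical, and this repo's own schedule may use different values.

```python
import random


def linear_epsilon(step, start=1.0, end=0.1, decay_steps=1_000_000):
    """Linearly anneal epsilon from `start` to `end` over `decay_steps`, then hold it at `end`."""
    fraction = min(step / decay_steps, 1.0)
    return start + fraction * (end - start)


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, otherwise the greedy (argmax-Q) action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

Because the schedule is a straight line that then flattens out, its plot carries little information once training is past the decay phase.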

As you can see, the plots are super noisy! That's what I expected, but the progress just stagnates from a certain point onwards, and that's what I'm trying to debug at the moment.
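The per-layer gradient L2 norms mentioned above can be computed right after `loss.backward()`. Below is a minimal sketch; the function name and the tag format are hypothetical. Logging through `SummaryWriter.add_scalar` from `torch.utils.tensorboard` (which requires the `tensorboard` package) is shown only in a comment.

```python
import torch
import torch.nn as nn


def grad_l2_norms(model):
    """Return {param_name: L2 norm of its gradient} plus the norm of the full gradient vector."""
    norms = {}
    squared_total = 0.0
    for name, param in model.named_parameters():
        if param.grad is None:
            continue  # e.g. frozen params, or no backward pass yet
        norm = param.grad.detach().norm(2).item()
        norms[name] = norm
        squared_total += norm ** 2
    # Norm of all gradients concatenated = sqrt(sum of squared per-tensor norms)
    norms["total"] = squared_total ** 0.5
    return norms


# Example logging (hypothetical tags), assuming writer = SummaryWriter() from torch.utils.tensorboard:
# for name, value in grad_l2_norms(policy_net).items():
#     writer.add_scalar(f"grad_norm/{name}", value, global_step)
```

Computing the total from per-tensor norms avoids concatenating every gradient into one large temporary tensor.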


To enter debug mode, add the --debug flag to your console's or IDE's list of script arguments.

It'll visualize the current state that's being fed into the RL agent. Sometimes the state will have some black frames prepended, since not enough frames have been experienced yet in the current episode:

But most of the time, all 4 frames will be there:

And it will start rendering the game frames (Pong and Breakout shown here from left to right):
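A debug view of the agent's input can be approximated by laying the 4 frames of a state side by side. The helper below is a hypothetical sketch, not the repo's --debug code; it works with any `(4, 84, 84)` `uint8` array, and missing (black) frames simply show up as dark panels.

```python
import numpy as np


def state_to_strip(state, separator_width=2):
    """Concatenate a (num_frames, H, W) state horizontally, with white separator columns."""
    num_frames, height, _ = state.shape
    separator = np.full((height, separator_width), 255, dtype=state.dtype)
    pieces = []
    for i in range(num_frames):
        pieces.append(state[i])
        if i < num_frames - 1:
            pieces.append(separator)
    return np.concatenate(pieces, axis=1)
```

To view the result, pass it to `matplotlib.pyplot.imshow(strip, cmap="gray", vmin=0, vmax=255)` if matplotlib is installed.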

Hardware requirements

You'll need some decent hardware to train the DQN in reasonable time so that you can iterate fast:

  1. 16+ GB of RAM (the replay buffer takes ~7 GB of RAM).
  2. The faster your GPU is, the better! 😅 Having said that, VRAM is not the bottleneck; you'll need 2+ GB of VRAM.

With 16 GB of RAM and an RTX 2080, it takes ~5 days to train DQN on my machine. I'm experiencing some slowdowns which I haven't debugged yet. Here is the FPS (frames-per-second) metric I'm logging:

The shorter green one is the experiment I'm currently running; the red one took over 5 days to train.
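An FPS metric like this can be computed by counting frames between two wall-clock readings. The class below is a hypothetical sketch; whether "frames" means agent steps or emulator frames (with the usual frame skip of 4 they differ by a factor of 4) is an assumption you should match to your own logging. The injectable `clock` makes it easy to test.

```python
import time


class FPSTracker:
    """Hypothetical helper: frames processed per second of wall-clock time since the last reading."""

    def __init__(self, clock=time.perf_counter):
        self.clock = clock
        self.last_time = clock()
        self.frames_since_last = 0

    def step(self, num_frames=1):
        """Count frames as they are processed."""
        self.frames_since_last += num_frames

    def fps(self):
        """Return FPS over the window since the previous call and start a new window."""
        now = self.clock()
        elapsed = now - self.last_time
        value = self.frames_since_last / elapsed if elapsed > 0 else 0.0
        self.last_time = now
        self.frames_since_last = 0
        return value
```

Because each call measures only the most recent window, a gradual slowdown shows up as a falling curve rather than being averaged away over the whole run.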

Future todos

  1. Debug DQN and achieve the published results
  2. Add Vanilla PG
  3. Add PPO

Learning material

Here are some videos I made on RL which may help you to better understand how DQN and other RL algorithms work:

DQN paper explained

And some other ones:

And in this one I tried to film through the process while the project was not nearly as polished as it is now:

I'll soon write a blog post on how to get started with RL, so stay tuned for that!

Acknowledgements

I found these resources useful while developing this project, sorted (approximately) by usefulness:

Citation

If you find this code useful, please cite the following:

@misc{Gordić2021PyTorchLearnReinforcementLearning,
  author = {Gordić, Aleksa},
  title = {pytorch-learn-reinforcement-learning},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/pytorch-learn-reinforcement-learning}},
}

Licence

License: MIT
