Deep Reinforcement Learning with pytorch & visdom
- Sample testing of trained agents (DQN on Breakout, A3C on Pong, Double DQN on CartPole, continuous A3C on InvertedPendulum (MuJoCo)):
- Sample on-line plotting while training an A3C agent on Pong (with 16 learner processes):
- Sample logging while training a DQN agent on CartPole (we currently use `WARNING` as the logging level to suppress the `INFO` printouts from visdom):
```
[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900
[WARNING ] (MainProcess) <===================================> DQN
[WARNING ] (MainProcess) <-----------------------------------> Env
[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123
[INFO ] (MainProcess) Making new env: CartPole-v0
[WARNING ] (MainProcess) Action Space: [0, 1]
[WARNING ] (MainProcess) State Space: 4
[WARNING ] (MainProcess) <-----------------------------------> Model
[WARNING ] (MainProcess) MlpModel (
(fc1): Linear (4 -> 16)
(rl1): ReLU ()
(fc2): Linear (16 -> 16)
(rl2): ReLU ()
(fc3): Linear (16 -> 16)
(rl3): ReLU ()
(fc4): Linear (16 -> 2)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Validation Data @ Step: 501
[WARNING ] (MainProcess) Start Training @ Step: 501
[WARNING ] (MainProcess) Reporting @ Step: 2500 | Elapsed Time: 5.32397913933
[WARNING ] (MainProcess) Training Stats: epsilon: 0.972
[WARNING ] (MainProcess) Training Stats: total_reward: 2500.0
[WARNING ] (MainProcess) Training Stats: avg_reward: 21.7391304348
[WARNING ] (MainProcess) Training Stats: nepisodes: 115
[WARNING ] (MainProcess) Training Stats: nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats: repisodes_solved: 0.991304347826
[WARNING ] (MainProcess) Evaluating @ Step: 2500
[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539
[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488
[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107
[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106
[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607
[WARNING ] (MainProcess) Saving Model @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...
[WARNING ] (MainProcess) Saved Model @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 2500
...
```
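For reference, here is a minimal sketch of a logging setup that produces output in this style, using Python's standard `logging` module (the format string and handler choice are illustrative, not necessarily what this repo uses):

```python
import logging

# Log at WARNING level so that visdom's INFO printouts are suppressed.
logger = logging.getLogger()
logger.setLevel(logging.WARNING)
handler = logging.StreamHandler()
# Illustrative format, mimicking the "[WARNING ] (MainProcess) ..." lines above.
handler.setFormatter(logging.Formatter("[%(levelname)-8s] (%(processName)s) %(message)s"))
logger.addHandler(handler)

logger.warning("<===================================> Training ...")
```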
What is included?
This repo currently contains the following agents:
- Deep Q Learning (DQN) [1], [2]
- Double DQN [3]
- Dueling network DQN (Dueling DQN) [4]
- Asynchronous Advantage Actor-Critic (A3C) (w/ both discrete/continuous action space support) [5], [6]
- Sample Efficient Actor-Critic with Experience Replay (ACER) (currently w/ discrete action space support (Truncated Importance Sampling, 1st Order TRPO)) [7], [8]
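To give a flavor of the kind of update the agents above implement, below is a minimal sketch of the Double DQN [3] target computation. It is written against the modern PyTorch API for readability (the repo itself targets PyTorch 0.2 with `Variable`'s), and all tensor and function names here are illustrative:

```python
import torch

def double_dqn_target(q_net, target_net, next_state_ts, reward_ts, done_ts, gamma=0.99):
    # Double DQN: the online network *selects* the greedy action, while the
    # (periodically synced) target network *evaluates* it, which reduces the
    # overestimation bias of vanilla DQN.
    with torch.no_grad():
        next_action_ts = q_net(next_state_ts).argmax(dim=1, keepdim=True)
        next_q_ts = target_net(next_state_ts).gather(1, next_action_ts).squeeze(1)
    # No bootstrapping past terminal transitions.
    return reward_ts + gamma * next_q_ts * (1.0 - done_ts)
```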
Work in progress:
- Testing ACER
Future Plans:
Code structure & Naming conventions:
NOTE: we follow the exact same code structure as pytorch-dnc so as to make the code easily transplantable.
- `./utils/factory.py`
We suggest users refer to `./utils/factory.py`, where all the integrated `Env`, `Model`, `Memory`, and `Agent` classes are listed in `Dict`'s. All four core classes are implemented in `./core/`. The factory pattern in `./utils/factory.py` keeps the code clean: no matter what type of `Agent` you want to train, or which type of `Env` you want to train on, you only need to modify some parameters in `./utils/options.py` and `./main.py` will do the rest (NOTE: `./main.py` never needs to be modified).
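To illustrate the pattern (the class names, constructor arguments, and stub bodies below are hypothetical placeholders, not the repo's actual ones; see `./utils/factory.py` for the real entries):

```python
# Illustrative sketch of the factory pattern; stubbed-out classes stand in
# for the real implementations under ./core/.
class GymEnv(object):
    def __init__(self, game="CartPole-v0"):
        self.game = game

class DQNAgent(object):
    def __init__(self, env_prototype, game):
        self.env = env_prototype(game=game)

class A3CAgent(object):
    def __init__(self, env_prototype, game):
        self.env = env_prototype(game=game)

EnvDict = {"gym": GymEnv}
AgentDict = {"dqn": DQNAgent, "a3c": A3CAgent}

# main.py can then stay completely generic: it just looks up the classes
# by the string keys configured in ./utils/options.py.
agent = AgentDict["dqn"](env_prototype=EnvDict["gym"], game="CartPole-v0")
```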
- Naming conventions
To make the code cleaner and more readable, we name variables using the following pattern (mainly in inherited `Agent`'s):
  - `*_vb`: `torch.autograd.Variable`'s, or a list of such objects
  - `*_ts`: `torch.Tensor`'s, or a list of such objects
  - otherwise: normal Python datatypes
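For example, a state might move through these suffixes as follows (a minimal sketch using the old `torch.autograd.Variable` API this repo targets; variable names are illustrative):

```python
import torch
from torch.autograd import Variable

state = [0.1, 0.2, 0.3, 0.4]                # plain Python datatype: no suffix
state_ts = torch.Tensor(state)              # torch.Tensor            -> *_ts
state_vb = Variable(state_ts.unsqueeze(0))  # torch.autograd.Variable -> *_vb
```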
Dependencies
- Python 2.7
- PyTorch >=v0.2.0
- Visdom
- OpenAI Gym >=v0.9.0 (for lower versions, just change to the available game versions, e.g. change PongDeterministic-v4 to PongDeterministic-v3)
- mujoco-py (optional: for training the continuous version of A3C)
How to run:
You only need to modify some parameters in `./utils/options.py` to train a new configuration.
- Configure your training in `./utils/options.py`:
  - `line 14`: add an entry into `CONFIGS` to define your training (`agent_type`, `env_type`, `game`, `model_type`, `memory_type`); a hypothetical sketch of such an entry is shown after this list
  - `line 33`: choose the entry you just added
  - `line 29-30`: fill in your machine/cluster ID (`MACHINE`) and timestamp (`TIMESTAMP`) to define your training signature (`MACHINE_TIMESTAMP`); the model file and the log file of this training will be saved under this signature (`./models/MACHINE_TIMESTAMP.pth` & `./logs/MACHINE_TIMESTAMP.log` respectively). The visdom visualization will also be displayed under this signature (first start the visdom server by typing in bash: `python -m visdom.server &`, then open this address in your browser: `http://localhost:8097/env/MACHINE_TIMESTAMP`)
  - `line 32`: to train a model, set `mode=1` (training visualization will be under `http://localhost:8097/env/MACHINE_TIMESTAMP`); to test the model from this training, simply set `mode=2` (testing visualization will be under `http://localhost:8097/env/MACHINE_TIMESTAMP_test`)
- Run: `python main.py`
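As promised above, here is a hypothetical sketch of what a `CONFIGS` entry and training signature could look like; the exact format in `./utils/options.py` may differ, and all values and variable names below are illustrative:

```python
# ./utils/options.py (illustrative sketch only)
CONFIGS = [
    # agent_type, env_type, game,          model_type, memory_type
    ["dqn",       "gym",    "CartPole-v0", "mlp",      "sequential"],  # 0
    ["a3c",       "atari",  "Pong",        "a3c-mlp",  "none"],        # 1
]

config = 0                # index of the CONFIGS entry to use
MACHINE = "machine1"      # your machine/cluster ID
TIMESTAMP = "17080801"    # timestamp of this run
# => signature "machine1_17080801": model saved to ./models/machine1_17080801.pth,
#    log written to ./logs/machine1_17080801.log
mode = 1                  # 1: train | 2: test
```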
Bonus Scripts :)
We also provide two additional scripts for quickly evaluating your results after training. (Dependency: lmj-plot)
- `plot.sh` (e.g., plot from log file `logs/machine1_17080801.log`):
  `./plot.sh machine1 17080801`
  - the generated figures will be saved into `figs/machine1_17080801/`
- `plot_compare.sh` (e.g., compare log files `logs/machine1_17080801.log` and `logs/machine2_17080802.log`):
  `./plot_compare.sh 00 machine1 17080801 machine2 17080802`
  - the generated figures will be saved into `figs/compare_00/`
  - the color coding will be in the order of: `red green blue magenta yellow cyan`
Repos we referred to during the development of this repo:
- matthiasplappert/keras-rl
- transedward/pytorch-dqn
- ikostrikov/pytorch-a3c
- onlytailei/A3C-PyTorch
- Kaixhin/ACER
- And a private implementation of A3C from @stokasto
Citation
If you find this library useful and would like to cite it, the following would be appropriate:
```
@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}
```