VIN: Value Iteration Networks
A quick thank you
A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following:
- @avivt (Paper Author, MATLAB implementation)
- @zuoxingdong (TensorFlow implementation, PyTorch implementation)
- @TheAbhiKumar (TensorFlow implementation)
- @onlytailei (PyTorch implementation)
Why another VIN implementation?
- The PyTorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than others I have found (both TensorFlow and PyTorch).
- This is not simply an implementation of the VIN model in PyTorch; it is also a full Python implementation of the gridworld environments used in the original MATLAB implementation.
- It provides a more extensible research base for others to build on, without having to get past a possible MATLAB paywall.
Installation
This repository requires the following packages:
- SciPy >= 0.19.0
- Python >= 2.7 (if using Python 3.x: python3-tk should be installed)
- Numpy >= 1.12.1
- Matplotlib >= 2.0.0
- PyTorch >= 0.1.11
Use pip to install the necessary dependencies:
pip install -U -r requirements.txt
Note that PyTorch cannot be installed directly from PyPI; refer to http://pytorch.org/ for custom installation instructions specific to your needs.
How to train
8x8 gridworld
python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
16x16 gridworld
python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128
28x28 gridworld
python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128
Flags:
- datafile: The path to the data files.
- imsize: The size of input images. One of: [8, 16, 28]
- lr: Learning rate with RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
- epochs: Number of epochs to train. Default: 30
- k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
- l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
- l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
- l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
- batch_size: Batch size. Default: 128
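The recommended `k` grows with the grid size. One way to see why is that each value-iteration sweep carries value information only one move further from the goal. The sketch below illustrates this with plain tabular value iteration in pure Python. It is not the VIN model and not code from this repository; the 4-connected move set, the reward values, and the function name `value_iteration` are all assumptions made for illustration.

```python
def value_iteration(obstacles, goal, k, step_cost=-1.0, goal_reward=10.0):
    """Run k synchronous sweeps of value iteration on a 4-connected grid.

    obstacles: 2D list of 0/1 values (1 = blocked cell).
    goal: (row, col) of the goal cell.
    Returns a 2D list of values; cells not reached within k sweeps are None.
    The move set and reward values are illustrative assumptions.
    """
    rows, cols = len(obstacles), len(obstacles[0])
    values = [[None] * cols for _ in range(rows)]
    values[goal[0]][goal[1]] = goal_reward
    moves = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    for _ in range(k):
        new_values = [row[:] for row in values]
        for r in range(rows):
            for c in range(cols):
                if obstacles[r][c] or (r, c) == goal:
                    continue
                best = None
                for dr, dc in moves:
                    nr, nc = r + dr, c + dc
                    in_grid = 0 <= nr < rows and 0 <= nc < cols
                    # Only neighbours that already hold a value can pass one on.
                    if in_grid and values[nr][nc] is not None:
                        candidate = step_cost + values[nr][nc]
                        if best is None or candidate > best:
                            best = candidate
                if best is not None:
                    new_values[r][c] = best
        values = new_values
    return values


# On an empty 8x8 grid with the goal at (0, 0), the far corner (7, 7) is
# 14 moves away, so it only receives a value once k reaches 14.
empty_grid = [[0] * 8 for _ in range(8)]
too_few = value_iteration(empty_grid, (0, 0), k=13)
enough = value_iteration(empty_grid, (0, 0), k=14)
print(too_few[7][7], enough[7][7])
```

With too few sweeps, cells far from the goal carry no useful value, so a policy read from those values has nothing to follow there. The exact number of sweeps needed depends on the move set and the obstacle layout.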
How to test / visualize paths (requires training first)
8x8 gridworld
python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10
16x16 gridworld
python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20
28x28 gridworld
python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36
To visualize the optimal and predicted paths simply pass:
--plot
Flags:
- weights: Path to trained weights.
- imsize: The size of input images. One of: [8, 16, 28]
- plot: If supplied, the optimal and predicted paths will be plotted.
- k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
- l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
- l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
- l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
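For readers who want to wire these flags into their own script, here is a minimal `argparse` sketch that mirrors the list above. It is a hypothetical parser, not the one in `test.py`. The defaults for `l_i`, `l_h`, and `l_q` come from the list; the defaults for `imsize` and `k`, and the name `build_test_parser`, are assumptions.

```python
import argparse


def build_test_parser():
    """Hypothetical parser mirroring the documented test flags."""
    parser = argparse.ArgumentParser(description="Evaluate a trained VIN model")
    parser.add_argument("--weights", required=True,
                        help="Path to trained weights")
    # Default of 8 is an assumption; the allowed sizes are documented.
    parser.add_argument("--imsize", type=int, choices=[8, 16, 28], default=8,
                        help="Size of input images")
    parser.add_argument("--plot", action="store_true",
                        help="Plot the optimal and predicted paths")
    # Default of 10 is an assumption (the recommended value for 8x8).
    parser.add_argument("--k", type=int, default=10,
                        help="Number of value iterations")
    parser.add_argument("--l_i", type=int, default=2,
                        help="Number of channels in input layer")
    parser.add_argument("--l_h", type=int, default=150,
                        help="Number of channels in first convolutional layer")
    parser.add_argument("--l_q", type=int, default=10,
                        help="Number of channels in q layer in VI-module")
    return parser


# Parse an explicit argument list instead of sys.argv so the example is
# reproducible anywhere.
args = build_test_parser().parse_args(
    ["--weights", "trained/vin_16x16.pth", "--imsize", "16", "--k", "20", "--plot"]
)
print(args.imsize, args.k, args.plot, args.l_h)
```

Using `action="store_true"` makes `--plot` a switch that takes no value, which matches the "if supplied" wording above.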
Results
[Figure: two sample results (Sample One, Sample Two) for each of the 8x8, 16x16, and 28x28 gridworlds]
Datasets
Each data sample consists of an obstacle image and a goal image, followed by the (x, y) coordinates of the current state in the gridworld.
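To make that layout concrete, the sketch below packs one sample into a flat record and unpacks it again. It illustrates the ordering described above and is not the format of the `.npz` files in this repository. The row-major flattening, the 0/1 pixel encoding, and the helper names `pack_sample` and `unpack_sample` are assumptions.

```python
def pack_sample(obstacle_image, goal_image, state_xy):
    """Flatten both images row by row, then append the (x, y) state."""
    flat = [pixel for row in obstacle_image for pixel in row]
    flat += [pixel for row in goal_image for pixel in row]
    flat += list(state_xy)
    return flat


def unpack_sample(flat, imsize):
    """Split a flat record back into obstacle image, goal image, and state."""
    n = imsize * imsize
    obstacle = [flat[i * imsize:(i + 1) * imsize] for i in range(imsize)]
    goal = [flat[n + i * imsize:n + (i + 1) * imsize] for i in range(imsize)]
    state_xy = tuple(flat[2 * n:2 * n + 2])
    return obstacle, goal, state_xy


# A tiny 3x3 example: one obstacle in the centre, goal in the bottom-right
# corner, agent currently at (0, 0). The 0/1 encoding is an assumption.
obstacle_image = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
goal_image = [[0, 0, 0], [0, 0, 0], [0, 0, 1]]
record = pack_sample(obstacle_image, goal_image, (0, 0))
print(len(record))  # two 3x3 images plus two coordinates
```

Under this layout an 8x8 sample has 2 * 64 + 2 = 130 values.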
Dataset size | 8x8 | 16x16 | 28x28 |
---|---|---|---|
Train set | 81337 | 456309 | 1529584 |
Test set | 13846 | 77203 | 251755 |
The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the dataset/make_training_data.py script. Note that this script is not optimized: it runs rather slowly and uses a lot of memory.
Performance: Success Rate
This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).
Success Rate | 8x8 | 16x16 | 28x28 |
---|---|---|---|
PyTorch | 99.69% | 96.99% | 91.07% |
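A rollout-based success rate of this kind can be sketched as follows. This is a simplified illustration, not the evaluation code from this repository. It assumes a rollout succeeds only if the policy reaches the goal within a step budget without leaving the grid or entering an obstacle. The function names, the `(obstacles, start, goal)` domain tuple, and the `max_steps` budget are hypothetical.

```python
def rollout_succeeds(policy, obstacles, start, goal, max_steps):
    """Follow the policy from start; True only if the goal is reached in time."""
    rows, cols = len(obstacles), len(obstacles[0])
    state = start
    for _ in range(max_steps):
        if state == goal:
            return True
        d_row, d_col = policy(state)
        n_row, n_col = state[0] + d_row, state[1] + d_col
        in_grid = 0 <= n_row < rows and 0 <= n_col < cols
        if not in_grid or obstacles[n_row][n_col]:
            return False  # left the grid or walked into an obstacle
        state = (n_row, n_col)
    return state == goal


def success_rate(policy, domains, max_steps=50):
    """Fraction of (obstacles, start, goal) domains solved by the policy."""
    if not domains:
        raise ValueError("need at least one domain")
    successes = sum(
        rollout_succeeds(policy, obstacles, start, goal, max_steps)
        for obstacles, start, goal in domains
    )
    return successes / float(len(domains))


# Toy check with a policy that always moves one column to the right.
def always_right(state):
    return (0, 1)


empty = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
domains = [
    (empty, (0, 0), (0, 2)),  # goal lies straight ahead: success
    (empty, (1, 0), (2, 2)),  # goal is on another row: runs off the grid
]
print(success_rate(always_right, domains))
```

A single wrong action anywhere along a trajectory can fail the whole rollout, which is why this metric is stricter than a per-state comparison.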
Performance: Test Accuracy
NOTE: This is the accuracy on the test set. It is different from the table in the paper, which reports the success rate from rollouts of the learned policy in the environment.
Test Accuracy | 8x8 | 16x16 | 28x28 |
---|---|---|---|
PyTorch | 99.83% | 94.84% | 88.54% |
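For contrast with the rollout metric, test accuracy can be sketched as a per-sample comparison. The snippet assumes accuracy means the fraction of test states where the predicted action equals the labelled action. That definition and the function name `action_accuracy` are assumptions, not taken from this repository's code.

```python
def action_accuracy(predicted_actions, expert_actions):
    """Fraction of samples where the predicted action matches the label."""
    if len(predicted_actions) != len(expert_actions):
        raise ValueError("prediction and label lists must have equal length")
    if not expert_actions:
        raise ValueError("need at least one sample")
    correct = sum(p == e for p, e in zip(predicted_actions, expert_actions))
    return correct / float(len(expert_actions))


# Three of four predicted actions agree with the labels.
print(action_accuracy([0, 1, 2, 3], [0, 1, 2, 0]))
```

Each state is scored independently here, so this number does not say whether a full path from start to goal would succeed.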