Implementation of the DRAW network architecture
This repository contains a reimplementation of the Deep Recurrent Attentive Writer (DRAW) network architecture introduced by K. Gregor, I. Danihelka, A. Graves and D. Wierstra. The original paper can be found at
http://arxiv.org/pdf/1502.04623
Dependencies
- Blocks: follow the install instructions. This will install all the other dependencies for you (Theano, Fuel, etc.).
- Theano
- Fuel
- picklable_itertools
DRAW currently works with the "cutting-edge development version". Since the API is subject to change, you may prefer to install this known-to-be-supported version:
You also need to install:
- Bokeh 0.8.1+
- ipdb
- ImageMagick
Data
You need to set the location of your data directory:
export FUEL_DATA_PATH=/home/user/data
fuel-download and fuel-convert are used to obtain and convert training datasets, e.g. for binarized MNIST:
cd $FUEL_DATA_PATH
fuel-download binarized_mnist
fuel-convert binarized_mnist
or similarly for SVHN
cd $FUEL_DATA_PATH
fuel-download svhn -d . 2
fuel-convert svhn -d . 2
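Before starting a long training run it can help to confirm that the data directory is set up as described above. The following is a minimal sketch, not part of this repository: the helper name list_converted_datasets is hypothetical, and it assumes that FUEL_DATA_PATH points at a single directory and that fuel-convert writes files ending in .hdf5.

```python
import os
from pathlib import Path


def list_converted_datasets(env_var="FUEL_DATA_PATH"):
    """Return the sorted names of .hdf5 files in the Fuel data directory.

    Hypothetical helper: assumes the variable holds one directory and that
    converted datasets use the .hdf5 extension.
    """
    data_path = os.environ.get(env_var)
    if not data_path:
        raise RuntimeError(f"{env_var} is not set; export it before training")

    directory = Path(data_path)
    if not directory.is_dir():
        raise RuntimeError(f"{data_path} is not a directory")

    # Only converted datasets are reported; raw downloads are ignored.
    return sorted(p.name for p in directory.glob("*.hdf5"))
```

Calling list_converted_datasets() after the fuel-convert steps should list one file per converted dataset; an empty list suggests the conversion has not been run in that directory.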
Training with attention
To train a model with a 2x2 read and a 5x5 write attention window run
cd draw
./train-draw.py --dataset=bmnist --attention=2,5 --niter=64 --lr=3e-4 --epochs=100
On an Amazon g2xlarge instance, Theano's compilation takes more than 40 minutes before training starts. If you enable the bokeh-server, you can follow the live plots once training starts. Training the model takes about 2 days.
After each epoch it will save the following files:
- a pickle of the model
- a pickle of the log
- sampled output image for that epoch
- animation of sampled output
Generating animations
To generate sampled output including an animation run
python sample.py svhn_model.pkl --channels 3 --size 32
Note that all dependencies are needed in order to load a model and generate samples. Unfortunately, this also includes the GPU, because Python cannot unpickle CudaNdarray objects without it. This is a known problem for which we don't yet have a general solution.
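When a saved model cannot be loaded, the raw pickle error is often hard to interpret. The sketch below wraps the standard-library pickle call and attaches a hint about the requirement described above. It is an illustration only: load_pickled_model is a hypothetical name, the code uses Python 3 syntax, and the set of exceptions raised when the GPU backend is missing is an assumption.

```python
import pickle


def load_pickled_model(path):
    """Unpickle a saved model, adding a hint when loading fails.

    Hypothetical helper; the caught exception types are an assumption about
    how a missing dependency shows up during unpickling.
    """
    try:
        with open(path, "rb") as handle:
            return pickle.load(handle)
    except (ImportError, AttributeError, pickle.UnpicklingError) as err:
        raise RuntimeError(
            f"could not unpickle {path!r}: all training-time dependencies, "
            "including a GPU-enabled Theano, must be available"
        ) from err
```

This does not remove the GPU requirement; it only makes the failure easier to diagnose.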
SVHN
To train a model on SVHN
python train-draw.py --name=my_svhn --dataset=svhn2 \
--attention=5,5 --niter=32 --lr=3e-4 --epochs=100 \
--enc-dim 512 --dec-dim 512
After 100-200 epochs, the model above achieved a test_nll_bound of 1825.82.
Log
Run
python plot-kl.py [pickle-of-log]
to create a visualization of the KL divergence plotted over inference iterations and epochs.
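For readers unfamiliar with the quantity being plotted, the sketch below computes the KL divergence between a diagonal Gaussian posterior and a standard normal prior for each inference iteration, which is the latent loss term used in the DRAW paper. It is not the code in plot-kl.py: the function name, the argument layout (n_iter, batch, z_dim) and the use of log standard deviations are assumptions made for illustration.

```python
import numpy as np


def kl_per_iteration(mu, log_sigma):
    """KL(N(mu, sigma^2) || N(0, 1)) for each inference iteration.

    mu, log_sigma: arrays of shape (n_iter, batch, z_dim) (assumed layout).
    Returns an array of shape (n_iter,), summed over latent dimensions and
    averaged over the batch.
    """
    mu = np.asarray(mu, dtype=float)
    log_sigma = np.asarray(log_sigma, dtype=float)
    sigma_sq = np.exp(2.0 * log_sigma)

    # Closed-form KL for a diagonal Gaussian against a unit Gaussian.
    kl = 0.5 * np.sum(mu ** 2 + sigma_sq - 2.0 * log_sigma - 1.0, axis=-1)
    return kl.mean(axis=-1)
```

Summing the returned values over iterations gives the total latent cost for one epoch, so plotting one such vector per epoch yields a curve per epoch over the inference iterations.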
Testing
Run
nosetests -v tests
to execute the test suite. Run
cd draw
./attention.py
to test the attention windowing code on some image. It will open three windows: A window displaying the original input image, a window displaying some extracted, downsampled content (testing the read-operation), and a window showing the upsampled content (matching the input size) after the write operation.
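The read and write operations exercised by this test can be sketched in NumPy as follows. This is a simplified illustration of the Gaussian filterbank attention described in the DRAW paper, not the code in attention.py: the function names are hypothetical, and the window centre, stride and width are passed in directly rather than being predicted by the decoder.

```python
import numpy as np


def filterbank(center, delta, sigma, n, size):
    """Return an (n, size) matrix of normalised 1-D Gaussian filters."""
    # Filter means are spaced `delta` pixels apart and centred on `center`.
    mu = center + (np.arange(n) - n / 2.0 + 0.5) * delta
    pixels = np.arange(size)
    f = np.exp(-((pixels[None, :] - mu[:, None]) ** 2) / (2.0 * sigma ** 2))
    # Normalise each filter to sum to one, guarding against all-zero rows.
    return f / np.maximum(f.sum(axis=1, keepdims=True), 1e-8)


def read(image, fy, fx, gamma=1.0):
    """Extract a downsampled n x n glimpse from a (height, width) image."""
    return gamma * fy @ image @ fx.T


def write(patch, fy, fx, gamma=1.0):
    """Project an n x n patch back onto a (height, width) canvas."""
    return (1.0 / gamma) * fy.T @ patch @ fx
```

For a 28x28 image, fy = fx = filterbank(14, 5, 1, 5, 28) gives a 5x5 window; read(image, fy, fx) returns the downsampled glimpse and write(glimpse, fy, fx) returns an array matching the input size, mirroring the second and third windows described above.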