Recurrent Visual Attention
This is a PyTorch implementation of Recurrent Models of Visual Attention by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.
The Recurrent Attention Model (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.
Model Description
In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
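Following the paper's formulation, one time step of this loop can be summarized as below:

- $\rho$ is the glimpse sensor, and $f_g$, $f_h$, $f_l$, $f_a$ are the glimpse, core, location and action networks described next.
- $\sigma$ is the fixed standard deviation of the Gaussian location policy.
- $T$ is the number of glimpses.
- $l_0$ is the initial location.

$$
\begin{aligned}
\phi_t &= \rho(x, l_{t-1}) && \text{(glimpse sensor)} \\
g_t &= f_g(\phi_t, l_{t-1}) && \text{(glimpse network)} \\
h_t &= f_h(h_{t-1}, g_t) && \text{(core network)} \\
l_t &\sim \mathcal{N}\!\left(f_l(h_t),\, \sigma^2 I\right) && \text{(location network)} \\
y &= \arg\max_c \, f_a(h_T)_c && \text{(action network, after } T \text{ steps)}
\end{aligned}
$$

Sampling $l_t$ is not differentiable. In the paper, the location network is therefore trained with REINFORCE, using a reward of 1 for a correct classification and 0 otherwise. The action, core and glimpse networks can be trained with ordinary backpropagation of the classification loss.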
- glimpse sensor: a retina that extracts a foveated glimpse $\phi$ around location $l$ from an image $x$. It encodes the region around $l$ at high resolution but uses a progressively lower resolution for pixels farther from $l$, resulting in a compressed representation of the original image $x$.
- glimpse network: a network that combines the "what" ($\phi$) and the "where" ($l$) into a glimpse feature vector $g_t$.
- core network: an RNN that maintains an internal state integrating the information extracted from the history of past observations. It encodes the agent's knowledge of the environment through a state vector $h_t$ that is updated at every time step $t$.
- location network: uses the internal state $h_t$ of the core network to produce the location coordinates $l_t$ for the next time step.
- action network: after a fixed number of time steps, uses the internal state $h_t$ of the core network to produce the final output classification $y$.
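The glimpse sensor can be sketched in plain Python as below. This is a minimal illustration under stated assumptions, not the repository's PyTorch code:

- The names `retina`, `extract_patch`, `avg_pool` and `denormalize` are hypothetical.
- Locations are assumed to be normalized to [-1, 1], with (0, 0) at the image centre.
- Pixels outside the image are zero-padded.
- The i-th of `k` concentric patches is assumed to be `g * s**i` pixels wide and average-pooled back down to `g x g`.

```python
from typing import List, Tuple

Image = List[List[float]]


def denormalize(coord: float, size: int) -> int:
    """Map a coordinate in [-1, 1] to a pixel index in [0, size - 1] (assumed convention)."""
    idx = int(0.5 * (coord + 1.0) * size)
    return min(max(idx, 0), size - 1)


def extract_patch(img: Image, center: Tuple[int, int], size: int) -> Image:
    """Crop a size x size patch around `center`, zero-padding pixels outside the image."""
    h, w = len(img), len(img[0])
    top, left = center[0] - size // 2, center[1] - size // 2
    return [
        [img[r][c] if 0 <= r < h and 0 <= c < w else 0.0
         for c in range(left, left + size)]
        for r in range(top, top + size)
    ]


def avg_pool(patch: Image, factor: int) -> Image:
    """Downsample a square patch by averaging non-overlapping factor x factor blocks."""
    n = len(patch) // factor
    return [
        [sum(patch[r * factor + i][c * factor + j]
             for i in range(factor) for j in range(factor)) / factor ** 2
         for c in range(n)]
        for r in range(n)
    ]


def retina(img: Image, loc: Tuple[float, float], g: int, k: int, s: int) -> List[Image]:
    """Foveated glimpse: k concentric patches, the i-th g * s**i wide, each pooled to g x g."""
    center = (denormalize(loc[0], len(img)), denormalize(loc[1], len(img[0])))
    glimpse = []
    for i in range(k):
        size = g * s ** i
        # Wider patches cover more context but are pooled harder (lower resolution).
        glimpse.append(avg_pool(extract_patch(img, center, size), s ** i))
    return glimpse
```

With `k=1` the sensor reduces to a single `g x g` crop around $l$, for example one 8x8 patch on a 28x28 MNIST digit. Larger `k` trades resolution for context.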
Results
I decided to tackle the 28x28 MNIST task with the RAM model containing 6 glimpses of size 8x8, with a scale factor of 1.
| Model | Validation Error (%) | Test Error (%) |
|---|---|---|
| 6 glimpses, 8x8 | 1.1 | 1.21 |
I haven't yet tuned the policy standard deviation with a random search, so I expect the test error can be brought below 1%. I'll be updating the table above with results for 60x60 Translated MNIST, 60x60 Cluttered Translated MNIST and the new Fashion MNIST dataset when I get the time.
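In the paper, the location network outputs the mean of a Gaussian with fixed variance. The next location is sampled from that Gaussian, and the location network is trained with REINFORCE, so the policy standard deviation is a hyperparameter. The plain-Python sketch below illustrates four pieces:

- sampling a location;
- the Gaussian log-likelihood;
- its REINFORCE gradient with respect to the mean;
- a log-uniform random search over the standard deviation.

All function names are hypothetical. `evaluate` stands in for a full training-plus-validation run, which is not shown. The search bounds are up to you.

```python
import math
import random
from typing import Callable, List, Optional, Tuple


def sample_location(mu: List[float], std: float, rng: random.Random) -> List[float]:
    """Draw l_t ~ N(mu, std^2 I) with independent coordinates."""
    return [rng.gauss(m, std) for m in mu]


def gaussian_log_prob(l: List[float], mu: List[float], std: float) -> float:
    """log N(l; mu, std^2 I), summed over the coordinates."""
    return sum(
        -((x - m) ** 2) / (2 * std ** 2) - math.log(std) - 0.5 * math.log(2 * math.pi)
        for x, m in zip(l, mu)
    )


def reinforce_grad_mu(l: List[float], mu: List[float], std: float,
                      reward: float, baseline: float) -> List[float]:
    """Gradient of (R - b) * log pi(l | mu) w.r.t. mu; ascend it to raise expected reward.

    Uses d/dmu log N(l; mu, std^2) = (l - mu) / std^2.
    """
    advantage = reward - baseline
    return [advantage * (x - m) / std ** 2 for x, m in zip(l, mu)]


def random_search_std(evaluate: Callable[[float], float], n_trials: int,
                      low: float, high: float,
                      seed: int = 0) -> Tuple[Optional[float], float]:
    """Sample std log-uniformly in [low, high]; keep the value with the lowest error."""
    rng = random.Random(seed)
    best_std, best_err = None, float("inf")
    for _ in range(n_trials):
        std = math.exp(rng.uniform(math.log(low), math.log(high)))
        # Hypothetical: train the model with this std and return its validation error.
        err = evaluate(std)
        if err < best_err:
            best_std, best_err = std, err
    return best_std, best_err
```

Sampling on a log scale spreads trials evenly across orders of magnitude, which suits a scale parameter like a standard deviation. A smaller `std` makes each sampled location's gradient `(l - mu) / std**2` larger per unit of deviation.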
Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.
With the Adam optimizer, the accuracy reported in the paper can be reached in about 160 epochs.
Usage
The easiest way to start training your RAM variant is to edit the parameters in config.py and run the following command:
python main.py
To resume training, run:
python main.py --resume=True
Finally, to test a checkpoint of your model that has achieved the best validation accuracy, run the following command:
python main.py --is_train=False
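The commands above pass booleans as `--flag=True` / `--flag=False` strings. If you write your own entry point, watch out for argparse's `type=bool`: it would turn any non-empty string, including `"False"`, into `True`. A small converter function avoids this.

The parser below is a hypothetical sketch, not the repository's config.py:

- `str2bool` and `build_parser` are illustrative names.
- It declares only the two flags used above.
- Its defaults are assumptions.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings; type=bool would map 'False' to True."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser with only the flags shown above; defaults are assumptions.
    parser = argparse.ArgumentParser(description="RAM training/testing (sketch)")
    parser.add_argument("--is_train", type=str2bool, default=True)
    parser.add_argument("--resume", type=str2bool, default=False)
    return parser
```

For example, `build_parser().parse_args(["--is_train=False"])` yields `is_train=False` and leaves `resume` at its default of `False`.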