• This repository has been archived on 24/Sep/2023
  • Stars: 468
  • Rank: 93,154 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created almost 7 years ago
  • Updated over 1 year ago

Repository Details

A PyTorch Implementation of "Recurrent Models of Visual Attention"

Recurrent Visual Attention

This is a PyTorch implementation of Recurrent Models of Visual Attention by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.

The Recurrent Attention Model (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.

Model Description

In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
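As a rough illustration of that decision loop, the sketch below wires the steps together in plain Python. The function `run_episode` and every callable passed to it (`sense`, `update_state`, `choose_location`, `classify`) are hypothetical stand-ins rather than names from this repository, and the toy lambdas only exercise the control flow, not a trained model.

```python
def run_episode(image, num_glimpses, sense, update_state,
                choose_location, classify, init_state, init_location):
    """Generic RAM-style loop; all callables are hypothetical stand-ins."""
    state, location = init_state, init_location
    visited = [location]
    for t in range(num_glimpses):
        glimpse = sense(image, location)                # process the sensor data
        state = update_state(state, glimpse, location)  # integrate over time
        if t < num_glimpses - 1:
            location = choose_location(state)           # where to look next
            visited.append(location)
    # The action (here a classification) is taken after the last glimpse.
    return classify(state), visited


# Toy usage: a 2x2 "image", integer (row, col) locations, and a running sum as state.
toy_image = [[0, 1], [2, 3]]
prediction, visited = run_episode(
    toy_image,
    num_glimpses=2,
    sense=lambda img, loc: img[loc[0]][loc[1]],
    update_state=lambda state, glimpse, loc: state + glimpse,
    choose_location=lambda state: (1, 1),
    classify=lambda state: state,
    init_state=0,
    init_location=(0, 0),
)
print(prediction, visited)
```

In a real model the running sum would be replaced by the recurrent state h_t and the lambdas by learned networks.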

  • glimpse sensor: a retina that extracts a foveated glimpse phi around location l from an image x. It encodes the region around l at high resolution but uses progressively lower resolution for pixels farther from l, resulting in a compressed representation of the original image x.
  • glimpse network: a network that combines the "what" (phi) and the "where" (l) into a glimpse feature vector g_t.
  • core network: an RNN that maintains an internal state that integrates information extracted from the history of past observations. It encodes the agent's knowledge of the environment through a state vector h_t that gets updated at every time step t.
  • location network: uses the internal state h_t of the core network to produce the location coordinates l_t for the next time step.
  • action network: after a fixed number of time steps, uses the internal state h_t of the core network to produce the final output classification y.
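To make the glimpse sensor concrete, here is a minimal pure-Python sketch of foveated extraction: it crops square patches of growing size around l and average-pools each one back to the base size. The helper names (`denormalize`, `extract_patch`, `average_pool`, `glimpse`), the zero padding at the borders, the integer scale factor, and the convention that l = (x, y) lies in [-1, 1] with (-1, -1) at the top-left corner are all assumptions made for illustration, not the repository's actual code.

```python
def denormalize(coord, size):
    """Map a coordinate in [-1, 1] to a pixel index in [0, size - 1]."""
    return min(size - 1, int(0.5 * (coord + 1.0) * size))


def extract_patch(image, row, col, size):
    """Crop a size x size patch centered near (row, col), zero-padded outside."""
    half = size // 2
    height, width = len(image), len(image[0])
    patch = []
    for r in range(row - half, row - half + size):
        line = []
        for c in range(col - half, col - half + size):
            inside = 0 <= r < height and 0 <= c < width
            line.append(image[r][c] if inside else 0.0)
        patch.append(line)
    return patch


def average_pool(patch, factor):
    """Downsample a square patch by averaging factor x factor blocks."""
    n = len(patch) // factor
    return [
        [
            sum(patch[i * factor + di][j * factor + dj]
                for di in range(factor) for dj in range(factor)) / factor ** 2
            for j in range(n)
        ]
        for i in range(n)
    ]


def glimpse(image, loc, base_size, num_patches, scale):
    """Return num_patches patches, each base_size x base_size, around loc."""
    col = denormalize(loc[0], len(image[0]))  # x selects the column
    row = denormalize(loc[1], len(image))     # y selects the row
    patches = []
    for k in range(num_patches):
        size = base_size * scale ** k         # each patch covers a wider area
        patch = extract_patch(image, row, col, size)
        patches.append(average_pool(patch, size // base_size))
    return patches


# Toy usage on a 4x4 image holding the values 1..16.
toy = [[r * 4 + c + 1 for c in range(4)] for r in range(4)]
phi = glimpse(toy, (0.0, 0.0), base_size=2, num_patches=2, scale=2)
print(phi)
```

With a single patch and a scale factor of 1, as in the MNIST setup described in the Results section, the function reduces to a plain crop around l.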

Results

I decided to tackle the 28x28 MNIST task with a RAM model that uses 6 glimpses of size 8x8 with a scale factor of 1.

Model               Validation Error (%)   Test Error (%)
6 glimpses, 8x8     1.1                    1.21

I haven't run a random search over the policy standard deviation, so I expect that tuning it could bring the test error below 1%. I'll update the table above with results for 60x60 Translated MNIST, 60x60 Cluttered Translated MNIST and the new Fashion MNIST dataset when I get the time.
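For context on why the policy standard deviation matters: in the paper, the next location is drawn from a Gaussian with fixed variance centered on the location network's output, and the log-probability of the sampled location feeds the REINFORCE update. The sketch below shows that sampling step in plain Python. The function names `sample_location` and `gaussian_log_prob` are hypothetical, and clamping the sample to [-1, 1] is omitted for brevity.

```python
import math
import random


def sample_location(mean, std, rng):
    """Draw each coordinate from N(mean_i, std^2)."""
    return [rng.gauss(m, std) for m in mean]


def gaussian_log_prob(sample, mean, std):
    """Log-density of an isotropic Gaussian, summed over coordinates."""
    return sum(
        -0.5 * ((s - m) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)
        for s, m in zip(sample, mean)
    )


# A larger std spreads the sampled fixations further from the proposed mean.
rng = random.Random(0)
mean = [0.2, -0.4]  # hypothetical output of the location network
for std in (0.05, 0.2):
    loc = sample_location(mean, std, rng)
    print(std, loc, gaussian_log_prob(loc, mean, std))
```

A small standard deviation keeps fixations close to what the network proposes (little exploration), while a large one explores more but makes the gradient estimate noisier, which is why it is a natural candidate for a hyperparameter search.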

Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.

[Animation: glimpses extracted by the network on a random batch at epoch 23]

With the Adam optimizer, the accuracy reported in the paper can be reached in ~160 epochs.

Usage

The easiest way to start training your RAM variant is to edit the parameters in config.py and run the following command:

python main.py

To resume training, run:

python main.py --resume=True

Finally, to test a checkpoint of your model that has achieved the best validation accuracy, run the following command:

python main.py --is_train=False
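Flags such as --resume=True pass booleans as text, so the parser has to convert the string itself; with argparse, type=bool would treat any non-empty string (including "False") as True. The sketch below shows one way a config module could handle this. The `str2bool` helper and `build_parser` function are assumptions for illustration and are not confirmed to match what config.py actually does.

```python
import argparse


def str2bool(value):
    """Interpret common textual spellings of a boolean flag value."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Hypothetical parser exposing only the two flags used above."""
    parser = argparse.ArgumentParser(description="RAM config sketch")
    parser.add_argument("--is_train", type=str2bool, default=True,
                        help="train the model (True) or test a checkpoint (False)")
    parser.add_argument("--resume", type=str2bool, default=False,
                        help="resume training from the latest checkpoint")
    return parser


# Equivalent to running: python main.py --resume=True
args = build_parser().parse_args(["--resume=True"])
print(args.is_train, args.resume)
```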
