• Stars: 978
• Rank: 46,823 (Top 1.0%)
• Language: Python
• License: MIT License
• Created: almost 8 years ago
• Updated: over 6 years ago


Repository Details

A Tensorflow implementation of Spatial Transformer Networks.

Spatial Transformer Networks

This is a Tensorflow implementation of Spatial Transformer Networks by Max Jaderberg, Karen Simonyan, Andrew Zisserman and Koray Kavukcuoglu, accompanied by a two-part blog tutorial series.

The Spatial Transformer Network (STN) is a differentiable module that can be inserted anywhere in a ConvNet architecture to increase its geometric invariance. It effectively gives the network the ability to spatially transform feature maps at no extra data or supervision cost.

Installation

Install the stn package using:

pip3 install stn

Then, you can call the STN layer as follows:

from stn import spatial_transformer_network as transformer

out = transformer(input_feature_map, theta, out_dims)

Parameters

  • input_feature_map: the output of the layer preceding the localization network. If the STN layer is the first layer of the network, then this corresponds to the input images. Shape should be (B, H, W, C).
  • theta: the output of the localization network. Shape should be (B, 6).
  • out_dims: desired (H, W) of the output feature map. Useful for upsampling or downsampling. If not specified, then output dimensions will be equal to input_feature_map dimensions.
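
For concreteness, here is a minimal sketch of a call with explicit shapes. It mirrors the TF1-style placeholder used in the Usage Note further down; the batch size, spatial dimensions, and the 20 x 20 output size are illustrative assumptions, not values required by the package.

import numpy as np
import tensorflow as tf

from stn import spatial_transformer_network as transformer

# illustrative shapes
B, H, W, C = 1, 40, 40, 1
input_fmap = tf.placeholder(tf.float32, [B, H, W, C])

# identity transform [[1, 0, 0], [0, 1, 0]] for every sample in the batch
theta = np.tile([1., 0., 0., 0., 1., 0.], (B, 1)).astype('float32')  # (B, 6)

out = transformer(input_fmap, theta)                   # same (H, W) as the input
out_small = transformer(input_fmap, theta, (20, 20))   # resampled to 20 x 20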

Background Information

The STN is composed of three components; a small NumPy sketch of the grid generator and bilinear sampler follows the list.

  • localization network: takes the feature map as input and outputs the parameters of the affine transformation that should be applied to that feature map.

  • grid generator: uses the predicted affine parameters to generate a grid of (x, y) coordinates in the input feature map, i.e. the points at which the input should be sampled to produce the transformed output feature map.

  • bilinear sampler: takes as input the input feature map and the grid generated by the grid generator and produces the output feature map using bilinear interpolation.
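
To make the grid generator and bilinear sampler concrete, here is a small NumPy sketch of the same idea (not the package's internal code): it builds a sampling grid in normalized [-1, 1] coordinates from a 2 x 3 affine matrix, then bilinearly samples a single-channel image at those locations.

import numpy as np

def affine_grid(theta, out_h, out_w):
    # target coordinates, normalized to [-1, 1] as in the STN paper
    xs, ys = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
    grid = np.stack([xs.ravel(), ys.ravel(), np.ones(out_h * out_w)])  # (3, H*W)
    # source coordinates = A_theta @ [x_t, y_t, 1]^T
    src = theta.reshape(2, 3) @ grid                                   # (2, H*W)
    return src[0].reshape(out_h, out_w), src[1].reshape(out_h, out_w)

def bilinear_sample(img, xs, ys):
    h, w = img.shape
    # map normalized coordinates back to pixel indices
    x = (xs + 1) * (w - 1) / 2
    y = (ys + 1) * (h - 1) / 2
    x0, y0 = np.floor(x).astype(int), np.floor(y).astype(int)
    x0, y0 = np.clip(x0, 0, w - 2), np.clip(y0, 0, h - 2)
    dx, dy = x - x0, y - y0
    # weighted sum of the four neighbouring pixels
    return (img[y0, x0] * (1 - dx) * (1 - dy) + img[y0, x0 + 1] * dx * (1 - dy)
            + img[y0 + 1, x0] * (1 - dx) * dy + img[y0 + 1, x0 + 1] * dx * dy)

theta = np.array([1., 0., 0., 0., 1., 0.])   # identity transform
img = np.arange(16, dtype=float).reshape(4, 4)
xs, ys = affine_grid(theta, 4, 4)
assert np.allclose(bilinear_sample(img, xs, ys), img)  # identity reproduces the input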

The affine transformation is specified through the 2 x 3 transformation matrix A_theta:

    A_theta = [ theta_11  theta_12  theta_13 ]
              [ theta_21  theta_22  theta_23 ]

It can be constrained to an attention mechanism by writing it in the form

    A_theta = [ s  0  t_x ]
              [ 0  s  t_y ]

where the parameters s, t_x and t_y can be regressed to allow cropping, translation, and isotropic scaling.

For a more in-depth explanation of STNs, read the two-part blog post: part 1 and part 2.

Explore

Run the Sanity Check to get a feel for how the spatial transformer can be plugged into any existing code. For example, here's the result of a 45 degree rotation:

(Images: the input and its 45 degree rotation.)
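
If you want to produce a rotation like this yourself, you can fill theta with the 2 x 3 matrix of a 45 degree rotation (in the normalized coordinates described above). The snippet below is a sketch in the spirit of the Sanity Check, not a copy of it; the shapes are illustrative.

import numpy as np
import tensorflow as tf

from stn import spatial_transformer_network as transformer

B, H, W, C = 1, 40, 40, 1
input_fmap = tf.placeholder(tf.float32, [B, H, W, C])

# 45 degree rotation about the centre, no translation or scaling:
# [[cos, -sin, 0], [sin, cos, 0]]
rad = np.deg2rad(45)
theta = np.tile([np.cos(rad), -np.sin(rad), 0.,
                 np.sin(rad),  np.cos(rad), 0.], (B, 1)).astype('float32')  # (B, 6)

rotated = transformer(input_fmap, theta)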

Usage Note

You must define a localization network before using this layer. The localization network is usually a ConvNet or a fully-connected network with 6 output nodes (the 6 parameters of the affine transformation).

It is good practice to initialize the localization network to the identity transform before starting the training process. Here's a small code sample for illustration purposes.

import numpy as np
import tensorflow as tf

from stn import spatial_transformer_network as transformer

# params
n_fc = 6
B, H, W, C = (2, 200, 200, 3)

# identity transform (used as the bias of the final FC layer)
initial = np.array([[1., 0, 0], [0, 1., 0]])
initial = initial.astype('float32').flatten()

# input placeholder (TF1-style API)
x = tf.placeholder(tf.float32, [B, H, W, C])

# localization network: a single FC layer whose weights are zero and whose
# bias is the identity transform, so theta starts out as the identity
W_fc1 = tf.Variable(tf.zeros([H*W*C, n_fc]), name='W_fc1')
b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
h_fc1 = tf.matmul(tf.reshape(x, [B, H*W*C]), W_fc1) + b_fc1  # (B, 6)

# spatial transformer layer
h_trans = transformer(x, h_fc1)


More Repositories

 1. pytorch-goodies - PyTorch Boilerplate For Research (Python, 601 stars)
 2. recurrent-visual-attention - A PyTorch Implementation of "Recurrent Models of Visual Attention" (Python, 468 stars)
 3. torchnca - A PyTorch implementation of Neighbourhood Components Analysis (Python, 400 stars)
 4. mjctrl - Minimal, clean, single-file implementations of common robotics controllers in MuJoCo (Python, 204 stars)
 5. mink - Python inverse kinematics based on MuJoCo (Python, 184 stars)
 6. obj2mjcf - A CLI for processing composite Wavefront OBJ files for use in MuJoCo (Python, 155 stars)
 7. torchkit - Research boilerplate for PyTorch (Python, 150 stars)
 8. mujoco_scanned_objects - MuJoCo Models for Google's Scanned Objects Dataset (145 stars)
 9. clip_playground - An ever-growing playground of notebooks showcasing CLIP's impressive zero-shot capabilities (Jupyter Notebook, 144 stars)
10. tsne-viz - Python Wrapper for t-SNE Visualization (Python, 126 stars)
11. ibc - A PyTorch implementation of Implicit Behavioral Cloning (Python, 93 stars)
12. form2fit - [ICRA 2020] Train generalizable policies for kit assembly with self-supervised dense correspondence learning (Python, 82 stars)
13. blog-code - My blog's code repository (Jupyter Notebook, 76 stars)
14. learn-linalg - Learning some numerical linear algebra (Python, 70 stars)
15. dexterity - Software and tasks for dexterous multi-fingered hand manipulation, powered by MuJoCo (Python, 59 stars)
16. x-magical - [CoRL 2021] A robotics benchmark for cross-embodiment imitation (Python, 58 stars)
17. mjc_viewer - A browser-based 3D viewer for MuJoCo (Python, 55 stars)
18. torchsdf-fusion - Benchmarking PyTorch variants of TSDF fusion (Python, 47 stars)
19. robopianist-rl - RL code for training piano-playing policies for RoboPianist (Python, 42 stars)
20. mujoco_tips_and_tricks (32 stars)
21. walle - My robotics research toolkit (Python, 22 stars)
22. mujoco_cube - A 3x3x3 puzzle cube model for MuJoCo (Python, 21 stars)
23. coffee - Infrastructure for PyBullet research (Python, 20 stars)
24. robopianist-demo (C, 20 stars)
25. learn-ransac - Learning about RANSAC (Python, 19 stars)
26. dm_env_wrappers - Standalone library of frequently-used wrappers for dm_env environments (Python, 18 stars)
27. root-locus - Python implementation of the Root Locus method (Python, 17 stars)
28. nanorl - A tiny reinforcement learning codebase for continuous control, built on top of JAX (Python, 12 stars)
29. software - My open-source software contributions (9 stars)
30. kinetics - Python script to mine the Kinetics dataset (Python, 6 stars)
31. cloneformer - BC with Transformers (Python, 5 stars)
32. mujoco_utils (Python, 5 stars)
33. learn-blur - Learning about various image blurring techniques (Python, 3 stars)
34. pymenagerie - Composer classes for MuJoCo Menagerie models (Python, 3 stars)
35. learn-volumetric-fusion - Learning about volumetric fusion (Python, 2 stars)