Implementing Stand-Alone Self-Attention in Vision Models using PyTorch (13 Jun 2019)
- Stand-Alone Self-Attention in Vision Models paper
- Authors:
- Prajit Ramachandran (Google Research, Brain Team)
- Niki Parmar (Google Research, Brain Team)
- Ashish Vaswani (Google Research, Brain Team)
- Irwan Bello (Google Research, Brain Team)
- Anselm Levskaya (Google Research, Brain Team)
- Jonathon Shlens (Google Research, Brain Team)
- Awesome :)
Method
- Attention Layer
- Relative Position Embedding
- Replacing Spatial Convolutions
- This work applies the transform to the ResNet family of architectures: each 3 × 3 spatial convolution is swapped for a self-attention layer as defined in Equation 3 of the paper.
- Whenever spatial downsampling is required, a 2 × 2 average pooling operation with stride 2 follows the attention layer.
- Replacing the Convolutional Stem
- The initial layers of a CNN, sometimes referred to as the stem, play a critical role in learning local features such as edges, which later layers use to identify global objects.
- The stem performs self-attention within each 4 × 4 spatial block of the original image, followed by batch normalization and a 4 × 4 max pool operation.
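The attention layer, relative position embedding and spatial-convolution replacement listed above can be sketched as one PyTorch module. This is a minimal sketch under stated assumptions, not the repository's code: the class name `LocalSelfAttention2d`, the random initialization of the offset embeddings, and leaving zero-padded neighbours unmasked are all assumptions. Each output pixel attends over its k × k neighbourhood, and the logits add a query-offset term q·r to the content term q·k, with the row and column offsets each using half of the output channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocalSelfAttention2d(nn.Module):
    """Hypothetical sketch: multi-head self-attention over a k x k neighbourhood
    with relative (row, column) position embeddings."""

    def __init__(self, in_channels, out_channels, kernel_size=7, heads=8):
        super().__init__()
        assert kernel_size % 2 == 1, "an odd kernel keeps the spatial size unchanged"
        assert out_channels % heads == 0 and out_channels % 2 == 0
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.heads = heads
        # Per-pixel projections (W_Q, W_K, W_V) implemented as 1x1 convolutions.
        self.query = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.key = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.value = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        # Row-offset and column-offset embeddings, each with out_channels // 2 dims.
        self.rel_row = nn.Parameter(torch.randn(out_channels // 2, kernel_size, 1))
        self.rel_col = nn.Parameter(torch.randn(out_channels // 2, 1, kernel_size))

    def forward(self, x):
        b, _, h, w = x.shape
        k, heads, c = self.kernel_size, self.heads, self.out_channels
        pad = k // 2
        q = self.query(x)
        # Zero-pad keys/values so every pixel has a full k x k window
        # (padded positions are not masked out; an assumption of this sketch).
        keys = F.pad(self.key(x), [pad, pad, pad, pad])
        vals = F.pad(self.value(x), [pad, pad, pad, pad])
        # Sliding windows: (b, c, h, w, k, k), last two dims = (row, column) offset.
        keys = keys.unfold(2, k, 1).unfold(3, k, 1)
        vals = vals.unfold(2, k, 1).unfold(3, k, 1)
        # Relative embedding: row half concatenated with column half -> (c, k, k).
        rel = torch.cat([self.rel_row.expand(-1, k, k),
                         self.rel_col.expand(-1, k, k)], dim=0)
        # q . (k + r) = q . k + q . r: content logits plus position logits.
        keys = keys + rel.view(1, c, 1, 1, k, k)
        # Split channels into heads and flatten each window.
        q = q.view(b, heads, c // heads, h, w, 1)
        keys = keys.reshape(b, heads, c // heads, h, w, k * k)
        vals = vals.reshape(b, heads, c // heads, h, w, k * k)
        # Softmax over the k*k neighbours, then a weighted sum of the values.
        attn = F.softmax((q * keys).sum(dim=2), dim=-1)
        out = (attn.unsqueeze(2) * vals).sum(dim=-1)
        return out.reshape(b, c, h, w)


# Usage: swap a 3x3 convolution for attention, then downsample with 2x2 average pooling.
layer = nn.Sequential(LocalSelfAttention2d(64, 64, kernel_size=7, heads=8),
                      nn.AvgPool2d(kernel_size=2, stride=2))
print(layer(torch.randn(2, 64, 16, 16)).shape)  # torch.Size([2, 64, 8, 8])
```

This sketch applies no 1/√d scaling to the logits; adding one is a common variation.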
Experiments
Setup
- Spatial extent: 7
- Attention heads: 8
- Layers:
- ResNet 26: [1, 2, 4, 1]
- ResNet 38: [2, 3, 5, 2]
- ResNet 50: [3, 4, 6, 3]
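The layer lists above give the number of blocks per stage. Assuming the usual bottleneck ResNet convention (three weight layers per block, plus one stem layer and one final fully connected layer), these counts account for the model names; `resnet_depth` is a hypothetical helper, not part of the repository.

```python
# Blocks per stage, copied from the setup above.
LAYERS = {
    "ResNet 26": [1, 2, 4, 1],
    "ResNet 38": [2, 3, 5, 2],
    "ResNet 50": [3, 4, 6, 3],
}


def resnet_depth(blocks_per_stage):
    # Each bottleneck block holds three weight layers (1x1, spatial, 1x1);
    # add one for the stem and one for the final fully connected layer.
    return 3 * sum(blocks_per_stage) + 2


for name, blocks in LAYERS.items():
    print(name, resnet_depth(blocks))
```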
Dataset | Model | Accuracy | Parameters (my model, paper model) |
---|---|---|---|
CIFAR-10 | ResNet 26 | 90.94% | 8.30M, - |
CIFAR-10 | Naive ResNet 26 | 94.29% | 8.74M, - |
CIFAR-10 | ResNet 26 + stem | 90.22% | 8.30M, - |
CIFAR-10 | ResNet 38 (work in progress) | 89.46% | 12.1M, - |
CIFAR-10 | Naive ResNet 38 | 94.93% | 15.0M, - |
CIFAR-10 | ResNet 50 (work in progress) | - | 16.0M, - |
ImageNet | ResNet 26 (work in progress) | - | 10.3M, 10.3M |
ImageNet | ResNet 38 (work in progress) | - | 14.1M, 14.1M |
ImageNet | ResNet 50 (work in progress) | - | 18.0M, 18.0M |
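A rough per-layer count hints at why the attention models use fewer parameters than the "Naive" rows, assuming those rows are the convolutional baselines. Further assumptions: neither layer has a bias, the attention layer uses 1 × 1 query/key/value projections plus row and column offset embeddings of size k × d_out/2 each (k = 7 as in the setup), and normalization layers are ignored. The helper names are hypothetical.

```python
def conv_params(d_in, d_out, k=3):
    # Weights of a k x k spatial convolution without bias.
    return k * k * d_in * d_out


def attention_params(d_in, d_out, k=7):
    # Three 1x1 projections (query, key, value) plus row and column
    # offset embeddings, each holding k vectors of d_out // 2 dims.
    return 3 * d_in * d_out + 2 * k * (d_out // 2)


for width in (64, 128, 256, 512):
    print(width, conv_params(width, width), attention_params(width, width))
# first line printed: 64 36864 12736
```

Because only the spatial layer inside each bottleneck block is replaced, the whole-model savings in the table are much smaller than this roughly 3× per-layer ratio.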
Usage
Requirements
- torch==1.0.1
Todo
- Experiments
- IMAGENET
- Review relative position embedding, attention stem
- Code Refactoring
Reference
- ResNet PyTorch CIFAR
- ResNet PyTorch
- Thank you :)