
PyTorch implementation of Wide Residual Networks with 1-bit weights by McDonnell (ICLR 2018)

1-bit Wide ResNet

PyTorch implementation of training 1-bit Wide ResNets from this paper:

"Training wide residual networks for deployment using a single bit for each weight", Mark D. McDonnell, ICLR 2018

https://openreview.net/forum?id=rytNfI1AZ

https://arxiv.org/abs/1802.08530

The idea is very simple yet surprisingly effective for training ResNets with binary weights. Here is the proposed weight parameterization, implemented as a PyTorch autograd function:

import math
import torch

class ForwardSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w):
        # Binarize the weights and rescale by the He-init constant
        # sqrt(2 / fan_in), where fan_in = in_channels * kH * kW.
        return math.sqrt(2. / (w.shape[1] * w.shape[2] * w.shape[3])) * w.sign()

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: pass the gradient through unchanged.
        return g

On the forward pass, we take the sign of the weights and scale it by the He-init constant. On the backward pass, we propagate the gradient unchanged. WRN-20-10 trained with this parameterization is only slightly behind its full-precision variant; here is what I got myself with this code on CIFAR-100:

network          accuracy, % (mean ± std over 5 runs)   checkpoint size
WRN-20-10        80.5 ± 0.24                            205 MB
WRN-20-10-1bit   80.0 ± 0.26                            3.5 MB
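
In the network, ForwardSign is applied to the convolution weights on every forward pass, while the optimizer keeps updating the underlying full-precision weights. A minimal illustrative sketch of that wiring (the Conv2dBinary module below is hypothetical, not the repository's actual class):

import torch.nn as nn
import torch.nn.functional as F

class Conv2dBinary(nn.Conv2d):
    def forward(self, x):
        # Hypothetical wrapper: binarize the latent full-precision weights on
        # each forward pass; gradients reach self.weight through the
        # straight-through estimator defined in ForwardSign.backward.
        w = ForwardSign.apply(self.weight)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)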

Details

Here are the differences from the original WRN code (https://github.com/szagoruyko/wide-residual-networks):

  • BatchNorm has no affine weight and bias parameters
  • First layer has 16 * width channels
  • Last fc layer is removed in favor of 1x1 conv + F.avg_pool2d
  • Downsampling is done by F.avg_pool2d + torch.cat instead of a strided conv (see the sketch after this list)
  • SGD with cosine annealing and warm restarts
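
For the downsampling shortcut, one way to realize F.avg_pool2d + torch.cat is to pool spatially and pad the missing channels with zeros. This is a hedged sketch of that idea under those assumptions, not necessarily the exact code in this repository:

import torch
import torch.nn.functional as F

def shortcut_downsample(x, out_channels):
    # Illustrative only: halve the spatial resolution with average pooling,
    # then grow the channel dimension by concatenating zeros instead of
    # using a strided 1x1 convolution.
    x = F.avg_pool2d(x, kernel_size=2, stride=2)
    pad = out_channels - x.size(1)
    zeros = x.new_zeros(x.size(0), pad, x.size(2), x.size(3))
    return torch.cat([x, zeros], dim=1)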

I used PyTorch 0.4.1 and Python 3.6 to run the code.

Reproduce WRN-20-10 with 1-bit training on CIFAR-100:

python main.py --binarize --save ./logs/WRN-20-10-1bit_$RANDOM --width 10 --dataset CIFAR100

Convergence plot (train error shown dashed):


I've also put up a 3.5 MB checkpoint with the binary weights packed using np.packbits, and a very short script to evaluate it:

python evaluate_packed.py --checkpoint wrn20-10-1bit-packed.pth.tar --width 10 --dataset CIFAR100

S3 URL to the checkpoint: https://s3.amazonaws.com/modelzoo-networks/wrn20-10-1bit-packed.pth.tar
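
Packing works because each weight only needs its sign bit. A rough sketch of the idea, with assumed helper names (this is not the repository's evaluate_packed.py):

import numpy as np
import torch

def pack_signs(w):
    # Store only the sign bit of each weight; the tensor shape is kept
    # separately so the weights can be reconstructed later.
    bits = (w.sign() > 0).to(torch.uint8).cpu().numpy().ravel()
    return np.packbits(bits), tuple(w.shape)

def unpack_signs(packed, shape):
    # Recover {-1, +1} values from the packed bits.
    n = int(np.prod(shape))
    bits = np.unpackbits(packed)[:n].astype(np.float32)
    return torch.from_numpy(bits * 2 - 1).reshape(shape)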
