
Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask

Authors

Hattie Zhou, Janice Lan, Rosanne Liu, Jason Yosinski

Introduction

This codebase implements the experiments in Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask. The paper performs a series of ablation studies to shed light on the Lottery Ticket (LT) phenomenon observed by Frankle & Carbin in The Lottery Ticket Hypothesis: Finding Small, Trainable Neural Networks.

@inproceedings{zhou_2019_dlt,
  title={Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask},
  author={Zhou, Hattie and Lan, Janice and Liu, Rosanne and Yosinski, Jason},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}

For more on this project, see the Uber Eng Blog post.

Codebase structure

  • data/download_mnist.py and data/download_cifar10.py download the MNIST and CIFAR-10 data, split each dataset into train, val, and test sets, and save them in the data folder as h5 files
  • get_weight_init.py computes various mask criteria
  • masked_layers.py defines new layer classes with masking options
  • masked_networks.py defines new layers and networks used in training Supermasks
  • network_builders.py defines the four network architectures evaluated in the paper (FC, Conv2, Conv4, Conv6)
  • train.py trains original unmasked networks
  • train_lottery.py reads in initial and final weights from a previously trained model, calculates the mask, and trains a lottery-style network
  • train_supermask.py trains a supermask directly using Bernoulli sampling
  • get_init_loss_train_lottery.py derives masks and calculates the initial accuracy of the masked network for various pruning percentages and mask criteria. Note that this uses a one-shot approach rather than an iterative approach.
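
To illustrate what a one-shot mask criterion does, here is a minimal sketch of a magnitude-based criterion in the spirit of large_final: keep the weights whose final (trained) magnitude is largest and apply the resulting mask to the initial weights. The function name, the flat-list representation of a layer, and the prune_fraction parameter are assumptions made for illustration; this is not the code in get_weight_init.py.

```python
def large_final_mask(final_weights, prune_fraction):
    """Return a 0/1 mask that keeps the weights with the largest final magnitude.

    Hypothetical helper for illustration; a layer is represented as a flat list.
    """
    n = len(final_weights)
    n_keep = n - int(round(prune_fraction * n))
    # Rank weight indices by |w_final|, largest first.
    ranked = sorted(range(n), key=lambda i: abs(final_weights[i]), reverse=True)
    keep = set(ranked[:n_keep])
    return [1 if i in keep else 0 for i in range(n)]


# Toy layer with five weights (made-up values).
w_init = [0.5, -0.1, 0.3, -0.7, 0.05]
w_final = [0.9, -0.05, 0.1, -1.2, 0.4]

mask = large_final_mask(w_final, prune_fraction=0.4)
# A lottery-style network starts from the masked initial weights.
masked_init = [m * w for m, w in zip(mask, w_init)]
```

An iterative approach would repeat this step after each round of retraining, pruning a fraction of the remaining weights each time.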

This codebase uses the GitResultsManager package to keep track of experiments. See: https://github.com/yosinski/GitResultsManager

Example commands for running experiments

The following commands provide examples for running experiments in Deconstructing Lottery Tickets.

Train the original, unpruned network

  • Train an FC network (300-100-10) on MNIST: ./print_train_command.sh iter fc test 0 t

Alternative mask criteria experiments (using FC on MNIST and large final as an example)

  • Perform iterative LT training for an FC network on MNIST using the large final mask criterion: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask none t

Mask-1 experiments

  • Randomly reinitialize weights prior to each round of iterative retraining: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reinit t

  • Randomly reshuffle the initial values of remaining weights prior to each round of iterative retraining: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reshuffle t

  • Convert the initial values of weights to a signed constant, then randomly reshuffle the initial values of the remaining weights prior to each round of iterative retraining: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask rand_signed_constant t

  • For versions that maintain the same sign, see signed_reinit, signed_reshuffle, and signed_constant.
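
The treatments above can be sketched on the remaining (unpruned) weights of a single layer. This is a rough illustration under stated assumptions, not the repository's implementation: the function bodies, the flat-list representation, the constant alpha, and the exact way signed_reshuffle preserves signs are all one possible reading of the option names.

```python
import random


def sign(x):
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def random_reshuffle(kept_init, rng):
    """Permute the initial values of the remaining weights within a layer."""
    shuffled = list(kept_init)
    rng.shuffle(shuffled)
    return shuffled


def signed_constant(kept_init, alpha):
    """Replace each initial value by +alpha or -alpha, keeping its original sign."""
    return [sign(w) * alpha for w in kept_init]


def rand_signed_constant(kept_init, alpha, rng):
    """Convert to a signed constant, then reshuffle (signs move with the values)."""
    return random_reshuffle(signed_constant(kept_init, alpha), rng)


def signed_reshuffle(kept_init, rng):
    """Reshuffle magnitudes only, so every position keeps its original sign.

    Assumption: this is one way to "maintain the same sign".
    """
    magnitudes = random_reshuffle([abs(w) for w in kept_init], rng)
    return [sign(w) * m for w, m in zip(kept_init, magnitudes)]


rng = random.Random(0)                # fixed seed for reproducibility
kept_init = [0.5, -0.1, 0.3, -0.7]    # made-up initial values
alpha = 0.2                           # hypothetical constant

print(random_reshuffle(kept_init, rng))
print(signed_constant(kept_init, alpha))
print(rand_signed_constant(kept_init, alpha, rng))
print(signed_reshuffle(kept_init, rng))
```

The sign-preserving variants differ from the random ones only in whether the sign at each weight position still matches the sign of the original initialization.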

Mask-0 experiments

  • Freeze pruned weights at initial values: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init none t

  • Freeze pruned weights that increased in magnitude at initial values: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_mask none t

  • Initialize weights that decreased in magnitude at 0, and freeze pruned weights at initial value: ./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_all none t
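
The three variants above differ in the value each weight starts from, while only unpruned weights are trained in every case. The sketch below is one reading of the bullet descriptions, not the repository's code: the function name starting_weights, the flat-list layout, and the example mask are hypothetical.

```python
def starting_weights(w_init, w_final, mask, variant):
    """Return the starting value of every weight for a Mask-0 variant.

    mask[i] == 1 means the weight is kept and trainable;
    mask[i] == 0 means it is pruned and frozen at the returned value.
    """
    start = []
    for wi, wf, m in zip(w_init, w_final, mask):
        grew = abs(wf) > abs(wi)  # magnitude increased during training
        if variant == "freeze_init":
            # Pruned weights are frozen at their initial values.
            value = wi
        elif variant == "freeze_init_zero_mask":
            # Pruned weights that grew are frozen at init; other pruned weights are zero.
            value = wi if (m == 1 or grew) else 0.0
        elif variant == "freeze_init_zero_all":
            # Any weight that shrank starts at zero; pruned weights that grew stay at init.
            value = wi if grew else 0.0
        else:
            raise ValueError("unknown variant: " + variant)
        start.append(value)
    return start


# Made-up values; the mask here is arbitrary and chosen to exercise every case.
w_init = [0.5, -0.1, 0.3, -0.7]
w_final = [0.2, -0.3, 0.1, -1.2]
mask = [1, 0, 0, 1]

for name in ("freeze_init", "freeze_init_zero_mask", "freeze_init_zero_all"):
    print(name, starting_weights(w_init, w_final, mask, name))
```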

Supermask experiments

  • Evaluate the initial test accuracy of all alternative mask criteria: python get_init_loss_train_lottery.py --output_dir ./results/iter_lot_fc_orig/test_seed_0/ --train_h5 ./data/mnist_train.h5 --test_h5 ./data/mnist_test.h5 --arch fc_lot --seed 0 --opt adam --lr 0.0012 --exp none --layer_cutoff 4,6 --prune_base 0.8,0.9 --prune_power 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24

  • Train a Supermask directly: python train_supermask.py --output_dir ./results/iter_lot_fc_orig/learned_supermasks/run1/ --train_h5 ./data/mnist_train.h5 --test_h5 ./data/mnist_test.h5 --arch fc_mask --opt sgd --lr 100 --num_epochs 2000 --print_every 220 --eval_every 220 --log_every 220 --save_weights --save_every 22000
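
As a rough picture of what "training a Supermask directly using Bernoulli sampling" involves, the sketch below shows only the forward sampling step: each weight has a real-valued mask parameter, a sigmoid turns it into a keep probability, and a 0/1 mask is sampled from that probability and applied to the frozen weights. The names mask_logits, sample_supermask, and masked_forward are assumptions for illustration; the gradient-based update of the mask parameters done by train_supermask.py is not shown.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sample_supermask(mask_logits, rng):
    """Sample a 0/1 mask with keep probability sigmoid(logit) per weight."""
    return [1 if rng.random() < sigmoid(m) else 0 for m in mask_logits]


def masked_forward(x, weights, mask):
    """Output of a single linear unit whose weights are gated by the mask."""
    return sum(xi * w * m for xi, w, m in zip(x, weights, mask))


rng = random.Random(0)
weights = [0.5, -0.25, 0.1]       # frozen, untrained weights (made-up values)
mask_logits = [2.0, -2.0, 0.0]    # hypothetical learned mask parameters

mask = sample_supermask(mask_logits, rng)
print(mask, masked_forward([1.0, 2.0, 3.0], weights, mask))
```

In this picture only the mask parameters would be learned, while the underlying weights stay at their initial values.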
