
Lottery Ticket Hypothesis in Pytorch


This repository contains a PyTorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" by Jonathan Frankle and Michael Carbin that can be easily adapted to any model/dataset.
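In outline, the method trains a network, prunes the smallest-magnitude weights, rewinds the surviving weights to their original initialization, and repeats. A rough, self-contained sketch of that loop follows; this is an illustration of the technique, not the repository's actual code, and `prune_by_percentile`, `train_fn`, and all hyperparameters here are placeholders:

```python
import copy
import torch
import torch.nn as nn

def prune_by_percentile(model, masks, percent):
    """Zero out the smallest `percent`% of the surviving weights in each layer."""
    for name, param in model.named_parameters():
        if "weight" not in name:
            continue                      # biases are usually left unpruned
        alive = param.data[masks[name] == 1].abs()
        if alive.numel() == 0:
            continue
        threshold = torch.quantile(alive, percent / 100.0)
        masks[name] = torch.where(param.data.abs() <= threshold,
                                  torch.zeros_like(masks[name]), masks[name])
        param.data.mul_(masks[name])

def lottery_ticket(model, train_fn, prune_percent=10, prune_iterations=3):
    """Iterative magnitude pruning with rewinding to the initial weights."""
    initial_state = copy.deepcopy(model.state_dict())
    masks = {name: torch.ones_like(p) for name, p in model.named_parameters()
             if "weight" in name}
    for _ in range(prune_iterations):
        train_fn(model, masks)            # caller trains with masks applied
        prune_by_percentile(model, masks, prune_percent)
        rewound = copy.deepcopy(initial_state)
        for name in masks:
            rewound[name] *= masks[name]  # pruned weights stay at zero
        model.load_state_dict(rewound)
    return model, masks
```

The key difference between the two prune_type options is the rewinding step: lt restores the surviving weights to their original initialization, while reinit would draw fresh random values instead.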

Requirements

pip3 install -r requirements.txt

How to run the code ?

Using datasets/architectures included with this repository :

python3 main.py --prune_type=lt --arch_type=fc1 --dataset=mnist --prune_percent=10 --prune_iterations=35
  • --prune_type : Type of pruning
    • Options : lt - Lottery Ticket Hypothesis, reinit - Random reinitialization
    • Default : lt
  • --arch_type : Type of architecture
    • Options : fc1 - Simple fully connected network, lenet5 - LeNet5, AlexNet - AlexNet, resnet18 - Resnet18, vgg16 - VGG16
    • Default : fc1
  • --dataset : Choice of dataset
    • Options : mnist, fashionmnist, cifar10, cifar100
    • Default : mnist
  • --prune_percent : Percentage of weights to be pruned after each cycle.
    • Default : 10
  • --prune_iterations : Number of pruning cycles to be done.
    • Default : 35
  • --lr : Learning rate
    • Default : 1.2e-3
  • --batch_size : Batch size
    • Default : 60
  • --end_iter : Number of Epochs
    • Default : 100
  • --print_freq : Frequency for printing accuracy and loss
    • Default : 1
  • --valid_freq : Frequency of validation
    • Default : 1
  • --gpu : Which GPU the program should use
    • Default : 0
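Flags like these are typically wired up with argparse. A sketch of what such a parser might look like, with option names and defaults taken from the list above; the choices lists and everything else are assumptions, not the repository's exact code:

```python
import argparse

def build_parser():
    # Defaults mirror the flag list above; `choices` are illustrative guesses.
    parser = argparse.ArgumentParser(description="Lottery Ticket Hypothesis")
    parser.add_argument("--prune_type", default="lt", choices=["lt", "reinit"],
                        help="lt: Lottery Ticket Hypothesis, reinit: random reinit")
    parser.add_argument("--arch_type", default="fc1")
    parser.add_argument("--dataset", default="mnist",
                        choices=["mnist", "fashionmnist", "cifar10", "cifar100"])
    parser.add_argument("--prune_percent", default=10, type=int)
    parser.add_argument("--prune_iterations", default=35, type=int)
    parser.add_argument("--lr", default=1.2e-3, type=float)
    parser.add_argument("--batch_size", default=60, type=int)
    parser.add_argument("--end_iter", default=100, type=int)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--valid_freq", default=1, type=int)
    parser.add_argument("--gpu", default="0")
    return parser

# Parse the example command from above (sans program name):
args = build_parser().parse_args(
    ["--prune_type=lt", "--arch_type=fc1", "--dataset=mnist",
     "--prune_percent=10", "--prune_iterations=35"])
```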

Using datasets/architectures that are not included with this repository :

  • Adding a new architecture :
    • For example, to add an architecture named new_model with mnist dataset compatibility :
      • Go to the /archs/mnist/ directory and create a file new_model.py.
      • Paste your PyTorch-compatible model inside new_model.py.
      • IMPORTANT : Make sure the input size, number of classes, number of channels and batch size in your new_model.py match the dataset you are adding (in this case, mnist).
      • Now open main.py, go to line 36 and look for the comment # Data Loader. Find your corresponding dataset (in this case, mnist) and add new_model at the end of the line from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet.
      • Now go to line 82 and add the following to it :
         elif args.arch_type == "new_model":
         	model = new_model.new_model_name().to(device)
        
        Here, new_model_name() is the name of the model class that you defined inside new_model.py.
  • Adding a new dataset :
    • For example, to add a dataset named new_dataset with fc1 architecture compatibility :
      • Go to /archs and create a directory named new_dataset.
      • Now go to /archs/new_dataset/ and add a file named fc1.py, or copy-paste one from an existing dataset folder.
      • IMPORTANT : Make sure the input size, number of classes, number of channels and batch size in your fc1.py match the dataset you are adding (in this case, new_dataset).
      • Now open main.py, go to line 58 and add the following to it :
         elif args.dataset == "new_dataset":
         	traindataset = datasets.new_dataset('../data', train=True, download=True, transform=transform)
         	testdataset = datasets.new_dataset('../data', train=False, transform=transform)
        
        Also change the architecture import for this dataset to from archs.new_dataset import fc1.
        Note that as of now, you can only add datasets that are natively available in PyTorch.
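Put together, the registration steps above amount to one new model file plus a dispatch branch in main.py. A minimal illustrative sketch follows; NewModel, its layer sizes, and build_model are placeholders (not code from this repository), sized here for MNIST's 1×28×28 inputs and 10 classes:

```python
import torch
import torch.nn as nn

# Contents of a hypothetical archs/mnist/new_model.py
class NewModel(nn.Module):
    """Toy fully connected net; input size must match the target dataset."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 300),
            nn.ReLU(),
            nn.Linear(300, num_classes),
        )

    def forward(self, x):
        return self.classifier(x)

# The dispatch branch added to main.py, mirrored here as a helper
def build_model(arch_type, device="cpu"):
    if arch_type == "new_model":
        return NewModel().to(device)
    raise ValueError(f"Unknown architecture: {arch_type}")

model = build_model("new_model")
out = model(torch.zeros(2, 1, 28, 28))   # a batch of two dummy MNIST images
```

If the input size, channel count, or number of classes does not match the dataset, the first Linear layer (or the loss computation) will fail at runtime, which is why the IMPORTANT notes above stress checking them.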

How to combine the plots of various prune_type options ?

  • Go to combine_plots.py and add/remove the datasets/archs whose combined plots you want to generate (assuming that you have already executed the main.py code for those datasets/archs and produced the weights).
  • Run python3 combine_plots.py.
  • Go to /plots/lt/combined_plots/ to see the graphs.

Kindly raise an issue if you have any problem with the instructions.

Datasets and Architectures that were already tested

              fc1   LeNet5   AlexNet   VGG16   Resnet18
MNIST         ✔️     ✔️       ✔️        ✔️       ✔️
CIFAR10       ✔️     ✔️       ✔️        ✔️       ✔️
FashionMNIST  ✔️     ✔️       ✔️        ✔️       ✔️
CIFAR100      ✔️     ✔️       ✔️        ✔️       ✔️

Repository Structure

Lottery-Ticket-Hypothesis-in-Pytorch
├── archs
│   ├── cifar10
│   │   ├── AlexNet.py
│   │   ├── densenet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── cifar100
│   │   ├── AlexNet.py
│   │   ├── fc1.py
│   │   ├── LeNet5.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   └── mnist
│       ├── AlexNet.py
│       ├── fc1.py
│       ├── LeNet5.py
│       ├── resnet.py
│       └── vgg.py
├── combine_plots.py
├── dumps
├── main.py
├── plots
├── README.md
├── requirements.txt
├── saves
└── utils.py

Interesting papers related to the Lottery Ticket Hypothesis that I enjoyed

Acknowledgement

Parts of code were borrowed from ktkth5.

Issue / Want to Contribute ? :

Open a new issue or submit a pull request in case you are facing any difficulty with the code base or if you want to contribute to it.

