  • Stars: 593
  • Rank: 75,443 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created: over 4 years ago
  • Updated: over 1 year ago

Repository Details

A PyTorch Knowledge Distillation library for benchmarking and extending works in the domains of Knowledge Distillation, Pruning, and Quantization.

KD-Lib

A PyTorch model compression library containing easy-to-use methods for knowledge distillation, pruning, and quantization

Installation

From source

git clone https://github.com/SforAiDl/KD_Lib.git
cd KD_Lib
python setup.py install

From PyPI

pip install KD-Lib
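
After installing from either source, a quick way to confirm the install (a minimal check, not part of the official docs) is to import the two distillers used in the examples below:

# Minimal sanity check; VanillaKD and DML are the classes used in the examples that follow
from KD_Lib.KD import VanillaKD, DML
print("KD-Lib imported successfully")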

Example usage

To implement the most basic version of knowledge distillation from Distilling the Knowledge in a Neural Network and plot loss curves:

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import VanillaKD

# This part is where you define your datasets, dataloaders, models and optimizers

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

# Stand-in fully-connected networks so the example runs end to end;
# replace these with your own teacher and student architectures
teacher_model = torch.nn.Sequential(
    torch.nn.Flatten(), torch.nn.Linear(784, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10)
)
student_model = torch.nn.Sequential(
    torch.nn.Flatten(), torch.nn.Linear(784, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10)
)

teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
student_optimizer = optim.SGD(student_model.parameters(), 0.01)

# Now, this is where KD_Lib comes into the picture

distiller = VanillaKD(teacher_model, student_model, train_loader, test_loader,
                      teacher_optimizer, student_optimizer)
distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)    # Train the teacher network
distiller.train_student(epochs=5, plot_losses=True, save_model=True)    # Train the student network
distiller.evaluate(teacher=False)                                       # Evaluate the student network
distiller.get_parameters()                                              # A utility function to get the number of
                                                                        # parameters in the teacher and the student networks
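
Under the hood, VanillaKD trains the student on the softened-softmax loss from Distilling the Knowledge in a Neural Network. The standalone sketch below illustrates that loss; the function name, temperature, and weighting factor alpha are illustrative choices, not KD-Lib's exact defaults or API.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=10.0, alpha=0.5):
    # Supervised cross-entropy on the hard labels
    hard_loss = F.cross_entropy(student_logits, labels)
    # KL divergence between the softened student and teacher distributions,
    # scaled by T^2 as in Hinton et al. (2015)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    ) * (temperature ** 2)
    return alpha * hard_loss + (1 - alpha) * soft_loss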

To train a cohort of two student models in an online fashion using the framework in Deep Mutual Learning and log training details to TensorBoard:

import torch
import torch.optim as optim
from torchvision import datasets, transforms
from KD_Lib.KD import DML
from KD_Lib.models import ResNet18, ResNet50          # To use models packaged in KD_Lib

# Define your datasets, dataloaders, models and optimizers

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "mnist_data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=32,
    shuffle=True,
)

student_params = [4, 4, 4, 4, 4]
student_model_1 = ResNet50(student_params, 1, 10)
student_model_2 = ResNet18(student_params, 1, 10)

student_cohort = [student_model_1, student_model_2]

student_optimizer_1 = optim.SGD(student_model_1.parameters(), 0.01)
student_optimizer_2 = optim.SGD(student_model_2.parameters(), 0.01)

student_optimizers = [student_optimizer_1, student_optimizer_2]

# Now, this is where KD_Lib comes into the picture 

distiller = DML(student_cohort, train_loader, test_loader, student_optimizers, log=True, logdir="./logs")

distiller.train_students(epochs=5)
distiller.evaluate()
distiller.get_parameters()
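
Because log=True writes TensorBoard summaries to logdir, the training curves can be inspected afterwards with tensorboard --logdir ./logs. For reference, the core idea behind Deep Mutual Learning is that every student in the cohort is optimized on its own supervised loss plus a KL term pulling it towards its peers' predictions. The sketch below illustrates one such update step under those assumptions (dml_update is a hypothetical helper); it is not KD-Lib's internal training loop.

import torch.nn.functional as F

def dml_update(models, optimizers, data, labels):
    # One mutual-learning step: every student minimizes its own cross-entropy
    # plus the average KL divergence to its peers' (fixed) predictions
    logits = [model(data) for model in models]
    for i, optimizer in enumerate(optimizers):
        peers = [j for j in range(len(models)) if j != i]
        loss = F.cross_entropy(logits[i], labels)
        for j in peers:
            loss = loss + F.kl_div(
                F.log_softmax(logits[i], dim=1),
                F.softmax(logits[j].detach(), dim=1),
                reduction="batchmean",
            ) / len(peers)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

With the objects defined above, dml_update(student_cohort, student_optimizers, data, labels) would perform one such step for a single batch.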

Methods Implemented

Some benchmark results can be found in the logs file.

Paper / Method | Link | Repository (KD_Lib/)
Distilling the Knowledge in a Neural Network | https://arxiv.org/abs/1503.02531 | KD/vision/vanilla
Improved Knowledge Distillation via Teacher Assistant | https://arxiv.org/abs/1902.03393 | KD/vision/TAKD
Relational Knowledge Distillation | https://arxiv.org/abs/1904.05068 | KD/vision/RKD
Distilling Knowledge from Noisy Teachers | https://arxiv.org/abs/1610.09650 | KD/vision/noisy
Paying More Attention To The Attention | https://arxiv.org/abs/1612.03928 | KD/vision/attention
Revisit Knowledge Distillation: a Teacher-free Framework | https://arxiv.org/abs/1909.11723 | KD/vision/teacher_free
Mean Teachers are Better Role Models | https://arxiv.org/abs/1703.01780 | KD/vision/mean_teacher
Knowledge Distillation via Route Constrained Optimization | https://arxiv.org/abs/1904.09149 | KD/vision/RCO
Born Again Neural Networks | https://arxiv.org/abs/1805.04770 | KD/vision/BANN
Preparing Lessons: Improve Knowledge Distillation with Better Supervision | https://arxiv.org/abs/1911.07471 | KD/vision/KA
Improving Generalization Robustness with Noisy Collaboration in Knowledge Distillation | https://arxiv.org/abs/1910.05057 | KD/vision/noisy
Distilling Task-Specific Knowledge from BERT into Simple Neural Networks | https://arxiv.org/abs/1903.12136 | KD/text/BERT2LSTM
Deep Mutual Learning | https://arxiv.org/abs/1706.00384 | KD/vision/DML
The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks | https://arxiv.org/abs/1803.03635 | Pruning/lottery_tickets
Regularizing Class-wise Predictions via Self-knowledge Distillation | https://arxiv.org/abs/2003.13964 | KD/vision/CSDK

Please cite our pre-print if you find KD-Lib useful in any way :)

@misc{shah2020kdlib,
  title={KD-Lib: A PyTorch library for Knowledge Distillation, Pruning and Quantization}, 
  author={Het Shah and Avishree Khare and Neelay Shah and Khizir Siddiqui},
  year={2020},
  eprint={2011.14691},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

More Repositories

1. Neural-Voice-Cloning-With-Few-Samples - Implementation of "Neural Voice Cloning With Few Samples" (Python, 428 stars)
2. genrl - A PyTorch reinforcement learning library for generalizable and reproducible algorithm implementations with an aim to improve accessibility in RL (Python, 402 stars)
3. vformer - A modular PyTorch library for vision transformer models (Python, 163 stars)
4. Deep-Learning-TIP (Jupyter Notebook, 26 stars)
5. Summer-Induction-Assignment-2021 - Repository for SAiDL Summer 2021 Induction Assignment (21 stars)
6. paper-reading-group - Notes for papers presented during our paper reading sessions (20 stars)
7. SAiDL-Summer-2023-Induction-Assignment (18 stars)
8. Playground - A Python library consisting of pipelines for visual analysis of different sports using Computer Vision and Deep Learning (Python, 18 stars)
9. CountCLIP (Jupyter Notebook, 16 stars)
10. jeta - A JAX-based meta-learning library (Python, 16 stars)
11. decepticonlp - Python library for robustness monitoring and adversarial debugging of NLP models (Python, 15 stars)
12. Summer-Induction-Assignment-2020 - Repository for SAiDL Summer Assignment 2020 (Python, 14 stars)
13. SAiDL-Spring-2024-Induction-Assignment (13 stars)
14. SAiDL-Spring-2022-Induction-Assignment - Repository for SAiDL Spring 2022 Induction Assignment (12 stars)
15. neuroscience-ai-reading-course - Notes for the Neuroscience & AI Reading Course (SEM-I 2020-21) at BITS Pilani Goa Campus (12 stars)
16. saliency_estimation - Python library to estimate saliency (Python, 8 stars)
17. SAiDL-Season-of-Code (6 stars)
18. twitter-sanity - A Python tool to recommend relevant and important tweets from your Twitter feed (Python, 6 stars)
19. evis - A utility Python library for event-based vision (Python, 5 stars)
20. NeurIPS2020 (5 stars)
21. Winter-Assignment-2019 - Winter Assignment 2019 (2 stars)
22. Bootcamp - Python + ML Bootcamp (Jupyter Notebook, 2 stars)
23. Winter-Assignment-2018 (1 star)
24. blogs (1 star)