Knowledge Distillation Toolkit
Knowledge Distillation Training
Code for "Shrinking Bigfoot: Reducing wav2vec 2.0 footprint"
This toolkit allows you to compress a machine learning model using knowledge distillation. To use it, you need to provide a teacher model, a student model, data loaders for training and validation, and an inference pipeline. The toolkit is built on PyTorch and PyTorch Lightning, so the teacher and student models must be PyTorch neural network modules (`torch.nn.Module`), and the data loaders must be PyTorch data loaders (`torch.utils.data.DataLoader`).
Demo
We provide two demos that use this toolkit to compress machine learning models. These demos show how to create the student and teacher models, the inference pipeline, and the training and validation data loaders, and how to pass them into the knowledge distillation toolkit.
Compress wav2vec 2.0: this notebook
Usage
Define inference pipeline
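```python
class InferencePipeline:
    def __init__(self):
        # Set up anything the pipeline needs (optional)
        pass

    def run_inference_pipeline(self, model, data_loader):
        # Evaluate `model` on the validation set behind `data_loader`
        accuracy = ...  # placeholder: compute the model's accuracy here
        return {"inference_result": accuracy}
```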
See here for a concrete example.
Define student model
```python
import torch.nn as nn

class StudentModel(nn.Module):
    def forward(self, x):
        # Forward pass of the student model
        raise NotImplementedError
```
See here for a concrete example.
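As a minimal illustration, assuming a toy classification task, a student could be a small MLP whose forward pass returns a dictionary. The `"prob"` key mirrors the output read by the example inference pipeline later in this README; the class name `TinyStudent` and all layer sizes are hypothetical, and the exact outputs the toolkit needs for its losses should be checked against the concrete example.

```python
import torch
import torch.nn as nn

class TinyStudent(nn.Module):
    """Hypothetical small student network for a toy 4-class task."""

    def __init__(self, in_features=16, num_classes=4, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        logits = self.net(x)
        # Class probabilities under "prob", the key the example
        # inference pipeline in this README reads
        return {"prob": torch.softmax(logits, dim=-1)}

student = TinyStudent()
out = student(torch.randn(8, 16))
print(out["prob"].shape)  # torch.Size([8, 4])
```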
Define teacher model
```python
import torch.nn as nn

class TeacherModel(nn.Module):
    def forward(self, x):
        # Forward pass of the teacher model
        raise NotImplementedError
```
See here for a concrete example.
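A teacher is usually a larger, already-trained network. The sketch below (hypothetical name `TinyTeacher`) is a wider MLP with the same output format as the student sketch. Putting the teacher in eval mode and freezing its parameters is a common distillation practice; whether this toolkit also does this internally is an assumption we have not verified.

```python
import torch
import torch.nn as nn

class TinyTeacher(nn.Module):
    """Hypothetical, larger teacher network for the same toy 4-class task."""

    def __init__(self, in_features=16, num_classes=4, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        return {"prob": torch.softmax(self.net(x), dim=-1)}

teacher = TinyTeacher()
# In practice, load pretrained weights here, e.g.
# teacher.load_state_dict(torch.load(path_to_weights))
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)  # the teacher is not updated during distillation
```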
Put everything together and start knowledge distillation
```python
import torch

# KnowledgeDistillationTraining must be imported from this toolkit, and
# InferencePipeline, StudentModel and TeacherModel are the classes defined above.
# train_dataset and val_dataset are your own torch.utils.data.Dataset objects.
inference_pipeline = InferencePipeline()
student_model = StudentModel()
teacher_model = TeacherModel()
train_data_loader = torch.utils.data.DataLoader(train_dataset)
val_data_loaders = {"a validation dataset": torch.utils.data.DataLoader(val_dataset)}
KD_example = KnowledgeDistillationTraining(
    train_data_loader=train_data_loader,
    val_data_loaders=val_data_loaders,
    inference_pipeline=inference_pipeline,
    student_model=student_model,
    teacher_model=teacher_model,
)
KD_example.start_kd_training()
```
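The snippet above assumes `train_dataset` and `val_dataset` already exist. For a quick smoke test, the sketch below builds them from purely synthetic random tensors (shapes chosen arbitrarily) and shows how the training loader and the dictionary of named validation loaders can be assembled; the loader names `val_clean` and `val_noisy` are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic data: 16-dimensional inputs with labels from 4 classes
train_dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))
val_clean = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
val_noisy = TensorDataset(torch.randn(64, 16) * 3.0, torch.randint(0, 4, (64,)))

train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Each key names a validation set; each value must be a DataLoader
val_data_loaders = {
    "val_clean": DataLoader(val_clean, batch_size=32),
    "val_noisy": DataLoader(val_noisy, batch_size=32),
}

for name, loader in val_data_loaders.items():
    print(name, len(loader.dataset), "samples")
```

Using several named validation loaders lets you track the student on different conditions (here, clean versus noisier inputs) within one run.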
Start knowledge distillation training!
To start knowledge distillation training, first instantiate the `KnowledgeDistillationTraining` class, then call its `start_kd_training` method.
The table below lists the arguments accepted by the `KnowledgeDistillationTraining` constructor.
Argument Name | Type | Explanation | Default |
---|---|---|---|
`teacher_model` | `torch.nn.Module` | A teacher model. | `None` |
`student_model` | `torch.nn.Module` | A student model. | `None` |
`train_data_loader` | `torch.utils.data.DataLoader` | Data loader for the training dataset. | `None` |
`val_data_loaders` | `dict` | A dictionary that can contain multiple validation data loaders. Each key is a data loader's name and each value is a data loader, which must be an instance of `torch.utils.data.DataLoader`. | `None` |
`inference_pipeline` | `object` | An object whose `run_inference_pipeline` method returns the validation result. See below for more information on this class. | `None` |
`num_gpu_used` | `int` | Number of GPUs used for training. | `0` |
`max_epoch` | `int` | Number of training epochs. | `10` |
`optimize_method` | `str` | Optimization method used to train the student model. One of `["adam", "sgd", "adam_wav2vec2.0", "adam_distilBert", "adamW_distilBert"]`. | `"adam"` |
`scheduler_method` | `str` | Learning rate scheduler. One of `["", "linear_decay_with_warm_up", "cosine_anneal"]`. Set to `""` for no learning rate scheduling. | `""` |
`learning_rate` | `float` | Learning rate for knowledge distillation training. | `0.0001` |
`num_lr_warm_up_epoch` | `int` | Number of epochs over which the learning rate is warmed up (increased). Set to `0` to disable warm-up. | `0` |
`final_loss_coeff_dict` | `dict` | A dictionary of coefficients by which the individual losses are multiplied. See below for more information. | `{"kd_loss": 1}` |
`log_to_comet` | `bool` | Set to `True` to log experiment results to comet.ml. Set to `False` when debugging. | `False` |
`comet_info_path` | `str` | Path to a .txt file containing the comet.ml API key, project name and workspace. | `""` |
`comet_exp_name` | `str` | Experiment name on comet.ml. | `""` |
`temperature` | `int` | Temperature used to calculate the knowledge distillation loss. | `1` |
`seed` | `int` | Seed value for the experiment. | `32` |
`track_grad_norm` | `int` | Order of the norm used when tracking gradient norms. | `2` |
`accumulate_grad_batches` | `int` | Number of gradient accumulation steps. | `1` |
`accelerator` | `str` / `None` | Accelerator for PyTorch Lightning. See here for details. | `None` |
`num_nodes` | `int` | Number of compute nodes. | `1` |
`precision` | `int` | 16-bit or 32-bit training. See here for details. | `16` |
`deterministic` | `bool` | The `deterministic` flag of PyTorch Lightning. | `True` |
`resume_from_checkpoint` | `str` | Path to a previous checkpoint from which the current experiment should resume. | `""` |
`logging_param` | `dict` | A dictionary of parameters to be saved to comet.ml. | `None` |
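To illustrate how the optional arguments can be combined, here is a hypothetical set of keyword arguments. The values are examples, not recommendations, and the `"student_loss"` key in `final_loss_coeff_dict` is an assumption that only makes sense if your student model returns such a loss. The small checker (a hypothetical helper, not part of the toolkit) only uses option lists copied from the table.

```python
OPTIMIZE_METHODS = ["adam", "sgd", "adam_wav2vec2.0", "adam_distilBert", "adamW_distilBert"]
SCHEDULER_METHODS = ["", "linear_decay_with_warm_up", "cosine_anneal"]

# Example keyword arguments, to be passed as
# KnowledgeDistillationTraining(..., **kd_kwargs)
kd_kwargs = {
    "num_gpu_used": 1,
    "max_epoch": 20,
    "optimize_method": "adam",
    "scheduler_method": "linear_decay_with_warm_up",
    "learning_rate": 1e-4,
    "num_lr_warm_up_epoch": 2,
    # "student_loss" is a hypothetical key for a supervised loss
    "final_loss_coeff_dict": {"kd_loss": 1.0, "student_loss": 0.5},
    "temperature": 2,
    "accumulate_grad_batches": 4,
    "precision": 16,
    "log_to_comet": False,
}

def check_kd_kwargs(kwargs):
    """Hypothetical helper: check a few arguments against the table's options."""
    if kwargs.get("optimize_method", "adam") not in OPTIMIZE_METHODS:
        raise ValueError(f"unknown optimize_method: {kwargs['optimize_method']!r}")
    if kwargs.get("scheduler_method", "") not in SCHEDULER_METHODS:
        raise ValueError(f"unknown scheduler_method: {kwargs['scheduler_method']!r}")
    if "kd_loss" not in kwargs.get("final_loss_coeff_dict", {"kd_loss": 1}):
        raise ValueError('final_loss_coeff_dict must contain a "kd_loss" key')
    if kwargs.get("precision", 16) not in (16, 32):
        raise ValueError("precision must be 16 or 32")

check_kd_kwargs(kd_kwargs)
```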
How does the inference pipeline work?
This toolkit uses the inference pipeline to test the student model. The object you pass as `inference_pipeline` needs to implement a method `run_inference_pipeline`, whose purpose is to measure the performance of the student model on a validation dataset.
The code below shows how we created an inference pipeline. `run_inference_pipeline` receives two arguments: `model`, which is the student model, and `data_loader`, which is a validation data loader. You already have both when using this toolkit, because you need them to instantiate the `KnowledgeDistillationTraining` class. Inside `run_inference_pipeline`, we take every batch from `data_loader` and pass it to `model`. For each batch, we count how many of the student model's predictions match the ground truth. Finally, we divide the number of correct predictions by the size of the validation dataset to get the overall accuracy, and return it in a dictionary under the key `inference_result`.
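```python
import torch

class InferencePipeline:
    def __init__(self, device="cpu"):
        # The constructor is optional; here it only stores the evaluation device.
        # The model passed in should already live on this device.
        self.device = torch.device(device)

    def run_inference_pipeline(self, model, data_loader):
        num_correct = 0
        model.eval()
        with torch.no_grad():
            for data in data_loader:
                X, y = data[0].to(self.device), data[1].to(self.device)
                outputs = model(X)
                # Predicted class = index of the largest probability
                predicted = torch.max(outputs["prob"], 1)[1]
                num_correct += predicted.eq(y.view_as(predicted)).sum().item()
        # Overall accuracy = correct predictions / number of validation samples
        accuracy = num_correct / len(data_loader.dataset)
        return {"inference_result": accuracy}
```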
The code above is just an example; you can create the inference pipeline in whatever way you want. Just remember two rules:
- The `inference_pipeline` object only needs to implement `run_inference_pipeline`, which tests a student model on a validation dataset.
- `run_inference_pipeline` should return a dictionary, e.g. `{"inference_result": <a numerical value that measures the performance of the student model on the validation dataset>}`.
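Following these two rules, a pipeline can be even smaller than the one above. The sketch below (the names `MinimalPipeline` and `IdentityClassifier` are hypothetical) has no constructor, takes the device from the model's own parameters, and is exercised on a tiny synthetic dataset whose labels are built so that the expected accuracy is known in advance.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class MinimalPipeline:
    # No __init__ needed: only run_inference_pipeline is required
    def run_inference_pipeline(self, model, data_loader):
        device = next(model.parameters()).device  # assumes the model has parameters
        model.eval()
        correct = 0
        with torch.no_grad():
            for X, y in data_loader:
                prob = model(X.to(device))["prob"]
                correct += (prob.argmax(dim=1) == y.to(device)).sum().item()
        return {"inference_result": correct / len(data_loader.dataset)}

class IdentityClassifier(nn.Module):
    """Toy model: treats the 3 input features as class scores."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # gives the model a device

    def forward(self, x):
        return {"prob": torch.softmax(x * self.scale, dim=1)}

X = torch.eye(3).repeat(4, 1)   # 12 one-hot inputs
y = X.argmax(dim=1)             # labels equal each row's argmax
y[:3] = (y[:3] + 1) % 3         # corrupt 3 labels, so 9 of 12 are correct
loader = DataLoader(TensorDataset(X, y), batch_size=5)
result = MinimalPipeline().run_inference_pipeline(IdentityClassifier(), loader)
print(result)  # {'inference_result': 0.75}
```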
How does the loss function work?
Knowledge distillation training requires a loss function. It always includes a knowledge distillation loss, which is the KL divergence between the teacher's and the student's probability distributions. It is also common to add a supervised training loss; for example, if the student model is an image classification network, the supervised loss could be a cross-entropy loss. The final loss combines these losses, and the student model is updated using the gradient of this final loss (PyTorch computes the gradient automatically, so there is no need to do it by hand).
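The knowledge distillation loss can be sketched with the common temperature-scaled formulation $\mathcal{L}_{\mathrm{KD}} = T^2 \cdot \mathrm{KL}\big(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\big)$, where $z_t$ and $z_s$ are the teacher and student logits and $T$ is the `temperature` argument. Taking logits as input and scaling by $T^2$ are assumptions; this toolkit's internal implementation may differ in such details.

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) on temperature-softened distributions (common formulation)."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target;
    # "batchmean" divides the summed KL divergence by the batch size
    kl = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
    # Scaling by T^2 keeps gradient magnitudes comparable across temperatures
    return kl * temperature ** 2

torch.manual_seed(0)
print(kd_loss(torch.randn(4, 3), torch.randn(4, 3), temperature=2.0))
```

When the student's logits equal the teacher's, the loss is zero up to floating-point error; it grows as the two distributions diverge.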
How does this toolkit combine different losses (e.g. the knowledge distillation loss and the supervised training loss) into the final loss? It multiplies each loss in `final_loss_components` by the matching coefficient in `final_loss_coeff_dict`. Both are dictionaries, and their keys must match so that each coefficient can be paired with its loss. For instance, `final_loss_components` always contains `"kd_loss"` (the key for the knowledge distillation loss), so `final_loss_coeff_dict` must also contain a `"kd_loss"` key. Users provide `final_loss_coeff_dict` when using this toolkit, but not `final_loss_components`, which is formed internally. If you also want to add a supervised training loss (the student loss) to the final loss, return it at the end of the student model's forward pass. See here for an example.
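The coefficient-matching step can be sketched as a weighted sum over matching keys. Summing the weighted terms, the key name `"student_loss"`, the helper name `combine_losses`, and raising an error on mismatched keys are our assumptions for illustration, not necessarily what the toolkit does internally.

```python
import torch

def combine_losses(final_loss_components, final_loss_coeff_dict):
    """Hypothetical helper: weighted sum of losses whose keys match the coefficients."""
    if set(final_loss_components) != set(final_loss_coeff_dict):
        raise KeyError(
            f"loss keys {sorted(final_loss_components)} do not match "
            f"coefficient keys {sorted(final_loss_coeff_dict)}"
        )
    return sum(final_loss_coeff_dict[k] * loss for k, loss in final_loss_components.items())

components = {
    "kd_loss": torch.tensor(0.8, requires_grad=True),
    "student_loss": torch.tensor(1.2, requires_grad=True),  # hypothetical supervised loss
}
coeffs = {"kd_loss": 1.0, "student_loss": 0.5}
final_loss = combine_losses(components, coeffs)
final_loss.backward()  # each component's gradient equals its coefficient
print(final_loss.item())
```

Because autograd tracks the weighted sum, a single `backward()` call propagates gradients through every loss term, which is why combining the losses first is sufficient to update the student.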