Implementations of losses for unbalanced NLP tasks, such as focal loss, dice loss, DSC loss, and GHM loss, along with adversarial training methods such as FGM, FGSM, PGD, and FreeAT.
## Loss Summary
This repository collects implementations of losses for unbalanced data:
| Loss Name | Paper | Notes |
| --- | --- | --- |
| Weighted CE Loss | UNet Architectures in Multiplanar Volumetric Segmentation -- Validated on Three Knee MRI Cohorts | |
| Focal Loss | Focal Loss for Dense Object Detection | |
| Dice Loss | V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation | |
| DSC Loss | Dice Loss for Data-imbalanced NLP Tasks | |
| GHM Loss | Gradient Harmonized Single-stage Detector | |
| Label Smoothing | When Does Label Smoothing Help? | |
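For reference, focal loss scales the cross-entropy term by a factor of (1 - p_t)^gamma so that well-classified examples contribute less to the total loss. Below is a minimal, self-contained sketch of that idea in plain PyTorch; it is illustrative only and not necessarily identical to this repo's MultiFocalLoss (the function name and the alpha handling here are assumptions):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=None, reduction='mean'):
    # log-probability of the true class for each sample: shape (batch,)
    log_probs = F.log_softmax(logits, dim=-1)
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # down-weight easy examples: (1 - p_t)^gamma * cross-entropy
    loss = -(1.0 - pt) ** gamma * log_pt
    if alpha is not None:  # optional per-class weights, tensor of shape (num_class,)
        loss = alpha.gather(0, targets) * loss
    return loss.mean() if reduction == 'mean' else loss.sum()
```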
### How to use?
You can find usage examples for all the losses in test_loss.py.
Here is a simple demo:
```python
import torch
from unbalanced_loss.focal_loss import MultiFocalLoss

batch_size, num_class = 64, 10

loss_func = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean')

logits = torch.rand(batch_size, num_class, requires_grad=True)  # (batch_size, num_class)
targets = torch.randint(0, num_class, size=(batch_size,))       # (batch_size,)

loss = loss_func(logits, targets)
loss.backward()
```
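The other losses follow the same pattern: instantiate a loss class from unbalanced_loss and call it with (logits, targets); see test_loss.py for the exact class names. As a reference for what the dice-style losses compute, here is a minimal soft dice sketch in plain PyTorch (illustrative only, not the repo's class; the smoothing term eps is an assumption):

```python
import torch

def soft_dice_loss(probs, targets, eps=1.0):
    # probs: (batch,) predicted probabilities; targets: (batch,) binary {0, 1} labels
    intersection = (probs * targets).sum()
    # 1 - Dice coefficient; eps smooths the ratio and avoids division by zero
    return 1.0 - (2.0 * intersection + eps) / (probs.sum() + targets.sum() + eps)
```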
## Adversarial Training Summary
Here is a summary of the adversarial training implementations.
You can find more details in adversarial_training/README.md.
| Adversarial Training | Paper | Notes |
| --- | --- | --- |
| FGM | Fast Gradient Method | |
| FGSM | Fast Gradient Sign Method | |
| PGD | Towards Deep Learning Models Resistant to Adversarial Attacks | |
| FreeAT | Free Adversarial Training | |
| FreeLB | Free Large Batch Adversarial Training | |
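As background on what these wrappers do internally, FGM perturbs the word-embedding weights along the gradient direction (r = epsilon * g / ||g||), runs a second forward/backward pass on the perturbed input to accumulate the adversarial gradient, and then restores the embeddings before the optimizer step. A minimal sketch (illustrative; the repo's class interface may differ, and emb_name is an assumption about the embedding parameter's name):

```python
import torch

class FGMSketch:
    def __init__(self, model, epsilon=1.0, emb_name='word_embeddings'):
        self.model, self.epsilon, self.emb_name = model, epsilon, emb_name
        self.backup = {}

    def attack(self):
        # add r = epsilon * g / ||g|| to the embedding weights
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self):
        # undo the perturbation after the adversarial backward pass
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
```

A typical step is then: `loss.backward()`, `fgm.attack()`, `loss_adv.backward()`, `fgm.restore()`, `optimizer.step()`.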
### How to use?
You can find a simple demo of BERT classification in test_bert.py.
You just need to rewrite the train function in PGD.py to match your model's inputs; then you can use adversarial training as below.
```python
import transformers
from model import bert_classification
from adversarial_training.PGD import PGD

batch_size, num_class = 64, 10

# model = your_model()
model = bert_classification()
AT_Model = PGD(model)
optimizer = transformers.AdamW(model.parameters(), lr=0.001)

# rewrite the train function in PGD.py for your own inputs, then train on a
# batch (token, segment, mask, label are your batch tensors)
outputs, loss = AT_Model.train_bert(token, segment, mask, label, optimizer)
```
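For context, a K-step PGD training step usually backs up the clean gradients, takes K small attack steps on the embeddings (each projected back onto an epsilon-ball around the originals), accumulates the adversarial gradient on the last step, and finally restores the embeddings. A sketch of such a loop, assuming a wrapper with attack/restore/backup_grad/restore_grad methods and a hypothetical compute_loss helper (not necessarily this repo's exact interface):

```python
K = 3
loss = compute_loss(model, batch)  # hypothetical forward pass returning a scalar loss
loss.backward()                    # gradients on the clean batch
AT_Model.backup_grad()             # save the clean gradients
for t in range(K):
    # perturb the embeddings, projecting back onto the epsilon-ball
    AT_Model.attack(is_first_attack=(t == 0))
    if t != K - 1:
        model.zero_grad()          # intermediate steps only shape the attack
    else:
        AT_Model.restore_grad()    # last step: add adversarial grad to clean grad
    loss_adv = compute_loss(model, batch)
    loss_adv.backward()
AT_Model.restore()                 # restore the original embeddings
optimizer.step()
model.zero_grad()
```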
## Adversarial Training Results
Here are some results on the THNews classification task with a BERT-based model.
You can run the code as below:

```shell
cd scripts
sh run_at.sh
```
| Adversarial Training | Time Cost (s/epoch) | Best Acc |
| --- | --- | --- |
| Normal (no attack) | 23.77 | 0.773 |
| FGSM | 45.95 | 0.7936 |
| FGM | 47.28 | 0.8008 |
| PGD (k=3) | 87.50 | 0.7963 |
| FreeAT (k=3) | 93.26 | 0.7896 |
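In this run, the single-step methods offer the better trade-off: FGM reaches the highest accuracy (0.8008) at roughly twice the per-epoch cost of normal training, while the multi-step PGD and FreeAT are nearly 4x as expensive without improving on it.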