Meta Pseudo Labels
This is an unofficial PyTorch implementation of Meta Pseudo Labels. The official implementation, written in TensorFlow, is available from Google Research.
Results
| Top-1 acc. (%) | CIFAR-10-4K | SVHN-1K | ImageNet-10% |
|---|---|---|---|
| Paper (w/ finetune) | 96.11 ± 0.07 | 98.01 ± 0.07 | 73.89 |
| This code (w/o finetune) | 96.01 | - | - |
| This code (w/ finetune) | 96.08 | - | - |
| Acc. curve | w/o finetune, w/ finetune | - | - |
- February 2022: retested.
Usage
Train the model with 4,000 labeled examples from the CIFAR-10 dataset:
python main.py \
--seed 2 \
--name cifar10-4K.2 \
--expand-labels \
--dataset cifar10 \
--num-classes 10 \
--num-labeled 4000 \
--total-steps 300000 \
--eval-step 1000 \
--randaug 2 16 \
--batch-size 128 \
--teacher_lr 0.05 \
--student_lr 0.05 \
--weight-decay 5e-4 \
--ema 0.995 \
--nesterov \
--mu 7 \
--label-smoothing 0.15 \
--temperature 0.7 \
--threshold 0.6 \
--lambda-u 8 \
--warmup-steps 5000 \
--uda-steps 5000 \
--student-wait-steps 3000 \
--teacher-dropout 0.2 \
--student-dropout 0.2 \
--finetune-epochs 625 \
--finetune-batch-size 512 \
--finetune-lr 3e-5 \
--finetune-weight-decay 0 \
--finetune-momentum 0.9 \
--amp
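The `--temperature`, `--threshold`, `--lambda-u` and `--uda-steps` flags point to a UDA-style consistency loss on unlabeled data. The following is a minimal pure-Python sketch of such a term, assuming these flags play their usual UDA roles: soft pseudo-labels are taken from a weakly augmented view and sharpened with `temperature`, examples whose most likely class falls below `threshold` are masked out, and the loss weight ramps linearly up to `lambda_u` over `uda_steps`. It also assumes the predictions being trained come from a strongly (RandAugment) augmented view. The function names are hypothetical, and this is not the repository's code. It leaves out the supervised loss, label smoothing, and the Meta Pseudo Labels feedback signal.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of temperature-scaled logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    """Numerically stable log-softmax."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]

def uda_consistency_loss(weak_logits, strong_logits, temperature=0.7, threshold=0.6):
    """Masked soft cross-entropy between sharpened weak-view targets and strong-view predictions."""
    total = 0.0
    for weak, strong in zip(weak_logits, strong_logits):
        target = softmax(weak, temperature)               # sharpened soft pseudo-label
        mask = 1.0 if max(target) >= threshold else 0.0   # drop low-confidence examples
        cross_entropy = -sum(p * lp for p, lp in zip(target, log_softmax(strong)))
        total += mask * cross_entropy
    return total / len(weak_logits)  # masked examples still count in the denominator

def uda_weight(step, lambda_u=8.0, uda_steps=5000):
    """Unlabeled-loss weight, ramped linearly up to lambda_u over uda_steps."""
    return lambda_u * min(1.0, (step + 1) / uda_steps)

# Toy batch of two unlabeled images and three classes; only the first is confident enough.
weak = [[2.0, 0.0, 0.0], [0.1, 0.0, 0.0]]
strong = [[1.0, 0.5, 0.0], [0.0, 0.0, 0.0]]
loss = uda_weight(step=999) * uda_consistency_loss(weak, strong)
print(f"weighted unlabeled loss: {loss:.4f}")
```

Averaging over the whole batch, masked examples included, means that a batch with few confident pseudo-labels contributes a smaller unlabeled loss. The ramp-up keeps noisy early pseudo-labels from dominating the first `--uda-steps` steps.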
Train the model with 10,000 labeled examples from the CIFAR-100 dataset, using DistributedDataParallel across 4 GPUs:
python -m torch.distributed.launch --nproc_per_node 4 main.py \
--seed 2 \
--name cifar100-10K.2 \
--dataset cifar100 \
--num-classes 100 \
--num-labeled 10000 \
--expand-labels \
--total-steps 300000 \
--eval-step 1000 \
--randaug 2 16 \
--batch-size 128 \
--teacher_lr 0.05 \
--student_lr 0.05 \
--weight-decay 5e-4 \
--ema 0.995 \
--nesterov \
--mu 7 \
--label-smoothing 0.15 \
--temperature 0.7 \
--threshold 0.6 \
--lambda-u 8 \
--warmup-steps 5000 \
--uda-steps 5000 \
--student-wait-steps 3000 \
--teacher-dropout 0.2 \
--student-dropout 0.2 \
--finetune-epochs 250 \
--finetune-batch-size 512 \
--finetune-lr 3e-5 \
--finetune-weight-decay 0 \
--finetune-momentum 0.9 \
--amp
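Both commands set `--warmup-steps`, `--total-steps` and `--student-wait-steps`. The following is a minimal sketch of a learning-rate multiplier built from these values. It assumes a linear warmup followed by cosine decay to zero. It also assumes that `--student-wait-steps` holds the student's learning rate at zero for that many initial steps before the student's own warmup begins. `lr_multiplier` is a hypothetical helper, not the repository's scheduler.

```python
import math

def lr_multiplier(step, warmup_steps, total_steps, wait_steps=0):
    """Hypothetical LR multiplier: 0 while waiting, linear warmup, then cosine decay to 0."""
    if step < wait_steps:
        return 0.0  # assumed meaning of --student-wait-steps: no updates yet
    step -= wait_steps
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to 1
    decay_steps = max(1, total_steps - wait_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay from 1 to 0

# Values taken from the commands above (--teacher_lr and --student_lr are both 0.05).
TOTAL, WARMUP, STUDENT_WAIT, BASE_LR = 300000, 5000, 3000, 0.05

for step in (0, 2500, 5000, 154000, 300000):
    teacher_lr = BASE_LR * lr_multiplier(step, WARMUP, TOTAL)
    student_lr = BASE_LR * lr_multiplier(step, WARMUP, TOTAL, STUDENT_WAIT)
    print(f"step {step:>6}: teacher_lr={teacher_lr:.5f} student_lr={student_lr:.5f}")
```

In PyTorch, a function like this could be attached to an optimizer with `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: lr_multiplier(step, WARMUP, TOTAL))`, which multiplies the optimizer's base learning rate by the returned value at each scheduler step.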
Monitoring training progress
Use TensorBoard:
tensorboard --logdir results
or use wandb.
Requirements
- python 3.6+
- torch 1.7+
- torchvision 0.8+
- tensorboard
- wandb
- numpy
- tqdm
Citations
@misc{jd2021mpl,
author = {Jungdae Kim},
title = {PyTorch implementation of Meta Pseudo Labels},
year = {2021},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/kekmodel/MPL-pytorch}}
}