FMix
This repository contains the official implementation of the paper 'FMix: Enhancing Mixed Sampled Data Augmentation'
ArXiv • Papers With Code • About • Experiments • Implementations • Pre-trained Models
Dive in with our example notebook in Colab!
About
FMix is a variant of MixUp, CutMix, etc., introduced in our paper 'FMix: Enhancing Mixed Sampled Data Augmentation'. It uses masks sampled from Fourier space to mix training examples. Take a look at our example notebook in Colab, which shows how you can generate masks in two dimensions and in three.
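To make the idea of "masks sampled from Fourier space" concrete, here is a minimal NumPy sketch. It is an illustration written for this README, not the code in fmix.py: the function name `sample_low_freq_mask`, the `decay_power` parameter, and the exact thresholding rule are assumptions. It draws complex Gaussian noise, attenuates high frequencies, transforms back to image space, and sets the top `lam` fraction of pixels to 1.

```python
import numpy as np


def sample_low_freq_mask(shape, lam, decay_power=3.0, rng=None):
    """Sample a binary (H, W) mask whose mean is approximately `lam`.

    Illustrative sketch only; names and defaults are assumptions.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = shape

    # Frequency magnitude of every bin on a 2D FFT grid.
    fy = np.fft.fftfreq(h)[:, None]
    fx = np.fft.fftfreq(w)[None, :]
    freq = np.sqrt(fx ** 2 + fy ** 2)

    # Clamp to the smallest non-zero frequency so the DC bin does not divide by zero.
    freq = np.maximum(freq, 1.0 / max(h, w))

    # Complex Gaussian noise, attenuated so that low frequencies dominate.
    spectrum = rng.standard_normal((h, w)) + 1j * rng.standard_normal((h, w))
    spectrum = spectrum / freq ** decay_power

    # Back to image space; the real part is a smooth grey-scale image.
    grey = np.real(np.fft.ifft2(spectrum))

    # Set the top `lam` fraction of pixels to 1 and the rest to 0.
    num_ones = int(round(lam * h * w))
    order = np.argsort(grey, axis=None)[::-1]
    mask = np.zeros(h * w, dtype=np.float32)
    mask[order[:num_ones]] = 1.0
    return mask.reshape(h, w)
```

A typical call would draw the mixing proportion first, for example `lam = np.random.default_rng().beta(1.0, 1.0)`, and then request a mask with `sample_low_freq_mask((32, 32), lam)`. The Beta prior here is also an assumption for the sketch; see the paper and fmix.py for the settings actually used.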
Experiments
Core Experiments
Shell scripts for our core experiments can be found in the experiments folder. For example,
bash cifar_experiment cifar10 resnet fmix ./data
will train a PreAct-ResNet18 on CIFAR-10 with FMix. More information can be found at the start of each of the shell files.
Additional Experiments
All additional classification experiments can be run via trainer.py.
Analyses
For Grad-CAM, take a look at the Grad-CAM notebook in Colab.
For the other analyses, have a look in the analysis folder.
Implementations
The core implementation of FMix uses numpy and can be found in fmix.py. We provide bindings for this in PyTorch (with Torchbearer or PyTorch-Lightning) and TensorFlow.
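As a rough picture of what a framework binding has to do with a mask, the sketch below shows the generic mixed-sample recipe in NumPy: combine each example with a randomly chosen partner from the same batch, then weight the loss for the two label sets by `lam` and `1 - lam`. The helper names `mix_batch` and `mixed_cross_entropy` are hypothetical and are not part of this repository; the provided bindings may differ in detail.

```python
import numpy as np


def mix_batch(x, mask, rng=None):
    """Mix every example in `x` (N, C, H, W) with a random partner using `mask` (H, W)."""
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(x.shape[0])
    # The (H, W) mask broadcasts over the batch and channel axes.
    mixed = mask * x + (1.0 - mask) * x[perm]
    return mixed, perm


def mixed_cross_entropy(log_probs, y, perm, lam):
    """Cross entropy against the original and partner labels, weighted by lam and 1 - lam."""
    idx = np.arange(len(y))
    loss_a = -log_probs[idx, y]        # loss against the original labels
    loss_b = -log_probs[idx, y[perm]]  # loss against the partner labels
    return float(np.mean(lam * loss_a + (1.0 - lam) * loss_b))
```

Here `lam` is the fraction of mask pixels equal to 1, so the label weighting matches the share of each image that survives in the mixed input.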
Torchbearer
The FMix callback in torchbearer_implementation.py can be added directly to your torchbearer code:
from torchbearer import Trial
from implementations.torchbearer_implementation import FMix

# `model` and `optimiser` are your own network and optimiser
fmix = FMix()
trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix])
See an example in test_torchbearer.py.
PyTorch-Lightning
For PyTorch-Lightning, we provide a class, FMix, in lightning.py that can be used in your LightningModule:
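import pytorch_lightning as pl
from implementations.lightning import FMix


class CoolSystem(pl.LightningModule):
    def __init__(self):
        ...
        self.fmix = FMix()

    def training_step(self, batch, batch_nb):
        x, y = batch
        x = self.fmix(x)             # mix the inputs with a sampled mask
        x = self.forward(x)          # model predictions on the mixed batch
        loss = self.fmix.loss(x, y)  # loss against the mixed targets
        return {'loss': loss}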
See an example in test_lightning.py.
TensorFlow
For TensorFlow, we provide a class, FMix, in tensorflow_implementation.py that can be used in your TensorFlow code:
import tensorflow as tf
from implementations.tensorflow_implementation import FMix

fmix = FMix()


def loss(model, x, y, training=True):
    x = fmix(x)                       # mix the inputs with a sampled mask
    y_ = model(x, training=training)  # predictions on the mixed batch
    return tf.reduce_mean(fmix.loss(y_, y))
See an example in test_tensorflow.py.
Pre-trained Models
We provide pre-trained models via torch.hub (more coming soon). To use them, run
import torch
model = torch.hub.load('ecs-vlc/FMix:master', ARCHITECTURE, pretrained=True)
where ARCHITECTURE is one of the following:
CIFAR-10
PreAct-ResNet-18
Configuration | ARCHITECTURE | Accuracy |
---|---|---|
Baseline | 'preact_resnet18_cifar10_baseline' | -------- |
+ MixUp | 'preact_resnet18_cifar10_mixup' | -------- |
+ FMix | 'preact_resnet18_cifar10_fmix' | -------- |
+ MixUp + FMix | 'preact_resnet18_cifar10_fmixplusmixup' | -------- |
PyramidNet-200
Configuration | ARCHITECTURE | Accuracy |
---|---|---|
Baseline | 'pyramidnet_cifar10_baseline' | 98.31 |
+ MixUp | 'pyramidnet_cifar10_mixup' | 97.92 |
+ FMix | 'pyramidnet_cifar10_fmix' | 98.64 |
ImageNet
ResNet-101
Configuration | ARCHITECTURE | Accuracy (Top-1) |
---|---|---|
Baseline | 'renset101_imagenet_baseline' | 76.51 |
+ MixUp | 'renset101_imagenet_mixup' | 76.27 |
+ FMix | 'renset101_imagenet_fmix' | 76.72 |
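For convenience, the names listed in the tables above can be wrapped in a small helper that checks the requested name before calling torch.hub. This is a sketch: `KNOWN_ARCHITECTURES` and `load_pretrained` are hypothetical names introduced here, not part of the repository, and the entry-point strings are copied verbatim from the tables.

```python
# Entry-point names copied from the tables above (spelling kept exactly as listed).
KNOWN_ARCHITECTURES = (
    'preact_resnet18_cifar10_baseline',
    'preact_resnet18_cifar10_mixup',
    'preact_resnet18_cifar10_fmix',
    'preact_resnet18_cifar10_fmixplusmixup',
    'pyramidnet_cifar10_baseline',
    'pyramidnet_cifar10_mixup',
    'pyramidnet_cifar10_fmix',
    'renset101_imagenet_baseline',
    'renset101_imagenet_mixup',
    'renset101_imagenet_fmix',
)


def load_pretrained(architecture):
    """Load one of the listed pre-trained models and switch it to evaluation mode."""
    if architecture not in KNOWN_ARCHITECTURES:
        raise ValueError(
            f'Unknown architecture {architecture!r}; expected one of {KNOWN_ARCHITECTURES}'
        )
    # Imported lazily so that name validation works even without PyTorch installed.
    import torch

    model = torch.hub.load('ecs-vlc/FMix:master', architecture, pretrained=True)
    model.eval()  # disable dropout and use running batch-norm statistics
    return model
```

For example, `model = load_pretrained('pyramidnet_cifar10_fmix')` downloads the weights on first use (this requires network access) and returns a model ready for inference.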