• Stars: 113
• Rank: 310,115 (Top 7%)
• Language: Python
• License: MIT License
• Created: almost 4 years ago
• Updated: almost 3 years ago


Repository Details

[CVPR 2021] Code for "Augmentation Strategies for Learning with Noisy Labels".

Augmentation-for-LNL

Code for Augmentation Strategies for Learning with Noisy Labels (CVPR 2021).

Authors: Kento Nishi*, Yi Ding*, Alex Rich, Tobias Höllerer [*: equal contribution]

Abstract

Imperfect labels are ubiquitous in real-world datasets. Several recent successful methods for training deep neural networks (DNNs) robust to label noise have used two primary techniques: filtering samples based on loss during a warm-up phase to curate an initial set of cleanly labeled samples, and using the output of a network as a pseudo-label for subsequent loss calculations. In this paper, we evaluate different augmentation strategies for algorithms tackling the "learning with noisy labels" problem. We propose and examine multiple augmentation strategies and evaluate them using synthetic datasets based on CIFAR-10 and CIFAR-100, as well as on the real-world dataset Clothing1M. Due to several commonalities in these algorithms, we find that using one set of augmentations for loss modeling tasks and another set for learning is the most effective, improving results on the state-of-the-art and other previous methods. Furthermore, we find that applying augmentation during the warm-up period can negatively impact the loss convergence behavior of correctly versus incorrectly labeled samples. We introduce this augmentation strategy to the state-of-the-art technique and demonstrate that we can improve performance across all evaluated noise levels. In particular, we improve accuracy on the CIFAR-10 benchmark at 90% symmetric noise by more than 15% in absolute accuracy, and we also improve performance on the real-world dataset Clothing1M.
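The paper's central idea, using one augmentation pipeline for loss modeling and a different one for learning, can be sketched in miniature. The functions below are illustrative stand-ins (the real pipelines use image crops/flips and RandAugment, not numeric jitter):

```python
import random

# Illustrative stand-ins for weak/strong augmentation pipelines.
def weak_augment(x, rng):
    # Mild perturbation: small additive jitter.
    return [v + rng.uniform(-0.05, 0.05) for v in x]

def strong_augment(x, rng):
    # Aggressive perturbation: larger jitter plus random dropout.
    return [0.0 if rng.random() < 0.3 else v + rng.uniform(-0.5, 0.5)
            for v in x]

def augdesc_ws_step(x, rng):
    """One AugDesc-WS-style step: the weakly augmented view drives the
    loss-modeling / pseudo-labeling pass, while the strongly augmented
    view drives the gradient (learning) pass."""
    x_weak = weak_augment(x, rng)      # used for loss modeling
    x_strong = strong_augment(x, rng)  # used for backpropagation
    return x_weak, x_strong

rng = random.Random(0)
w, s = augdesc_ws_step([1.0, 2.0, 3.0], rng)
```

The key design point is that the two views come from the *same* sample in the *same* step; only the augmentation strength differs between the two roles.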


View on arXiv / View PDF / Download Paper Source / Download Source Code

Watch CVPR Video

Benchmarks

All Benchmarks

Key

| Annotation | Meaning |
| --- | --- |
| Small | Worse than or equivalent to previous state-of-the-art |
| Normal | Better than previous state-of-the-art |
| Bold | Best in task/category |

CIFAR-10

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym | 40% asym |
| --- | --- | --- | --- | --- | --- | --- |
| Runtime-W (Vanilla DivideMix) | Highest | 96.100% | 94.600% | 93.200% | 76.000% | 93.400% |
| | Last 10 | 95.700% | 94.400% | 92.900% | 75.400% | 92.100% |
| Raw | Highest | 85.940% | | | 27.580% | |
| | Last 10 | 83.230% | | | 23.915% | |
| Expansion.Weak | Highest | 90.860% | | | 31.220% | |
| | Last 10 | 89.948% | | | 10.000% | |
| Expansion.Strong | Highest | 90.560% | | | 35.100% | |
| | Last 10 | 89.514% | | | 34.228% | |
| AugDesc-WW | Highest | 96.270% | | | 36.050% | |
| | Last 10 | 96.084% | | | 23.503% | |
| Runtime-S | Highest | 96.540% | | | 70.470% | |
| | Last 10 | 96.327% | | | 70.223% | |
| AugDesc-SS | Highest | 96.470% | | | 81.770% | |
| | Last 10 | 96.193% | | | 81.540% | |
| AugDesc-WS.RandAug.n1m6 | Highest | 96.280% | | | 89.750% | |
| | Last 10 | 96.006% | | | 89.629% | |
| AugDesc-WS.SAW | Highest | 96.350% | 95.640% | 93.720% | 35.330% | 94.390% |
| | Last 10 | 96.138% | 95.417% | 93.563% | 10.000% | 94.078% |
| AugDesc-WS (WAW) | Highest | 96.330% | 95.360% | 93.770% | 91.880% | 94.640% |
| | Last 10 | 96.168% | 95.134% | 93.641% | 91.760% | 94.258% |

CIFAR-100

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym |
| --- | --- | --- | --- | --- | --- |
| Runtime-W (Vanilla DivideMix) | Highest | 77.300% | 74.600% | 60.200% | 31.500% |
| | Last 10 | 76.900% | 74.200% | 59.600% | 31.000% |
| Raw | Highest | 52.240% | | | 7.990% |
| | Last 10 | 39.176% | | | 2.979% |
| Expansion.Weak | Highest | 57.110% | | | 7.300% |
| | Last 10 | 53.288% | | | 2.223% |
| Expansion.Strong | Highest | 55.150% | | | 7.540% |
| | Last 10 | 54.369% | | | 3.242% |
| AugDesc-WW | Highest | 78.900% | | | 30.330% |
| | Last 10 | 78.437% | | | 29.876% |
| Runtime-S | Highest | 79.890% | | | 40.520% |
| | Last 10 | 79.395% | | | 40.343% |
| AugDesc-SS | Highest | 79.790% | | | 38.850% |
| | Last 10 | 79.511% | | | 38.553% |
| AugDesc-WS.RandAug.n1m6 | Highest | 78.060% | | | 36.890% |
| | Last 10 | 77.826% | | | 36.672% |
| AugDesc-WS.SAW | Highest | 79.610% | 77.640% | 61.830% | 17.570% |
| | Last 10 | 79.464% | 77.522% | 61.632% | 15.050% |
| AugDesc-WS (WAW) | Highest | 79.500% | 77.240% | 66.360% | 41.200% |
| | Last 10 | 79.216% | 77.010% | 66.046% | 40.895% |

Clothing1M

| Model | Accuracy |
| --- | --- |
| Runtime-W (Vanilla DivideMix) | 74.760% |
| AugDesc-WS (WAW) | 74.720% |
| AugDesc-WS.SAW | 75.109% |

Summary Metrics

CIFAR-10

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym | 40% asym |
| --- | --- | --- | --- | --- | --- | --- |
| SOTA | Highest | 96.100% | 94.600% | 93.200% | 76.000% | 93.400% |
| | Last 10 | 95.700% | 94.400% | 92.900% | 75.400% | 92.100% |
| Ours | Highest | 96.540% | 95.640% | 93.770% | 91.880% | 94.640% |
| | Last 10 | 96.327% | 95.417% | 93.641% | 91.760% | 94.258% |

CIFAR-100

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym |
| --- | --- | --- | --- | --- | --- |
| SOTA | Highest | 77.300% | 74.600% | 60.200% | 31.500% |
| | Last 10 | 76.900% | 74.200% | 59.600% | 31.000% |
| Ours | Highest | 79.890% | 77.640% | 66.360% | 41.200% |
| | Last 10 | 79.511% | 77.522% | 66.046% | 40.895% |

Clothing1M

| Model | Accuracy |
| --- | --- |
| SOTA | 74.760% |
| Ours | 75.109% |
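The tables above report two metrics per configuration. As I read them, "Highest" is the single best test accuracy observed over training and "Last 10" is the mean over the final 10 epochs; a minimal sketch under that assumption:

```python
def highest(acc_history):
    """Best test accuracy observed over all epochs."""
    return max(acc_history)

def last_10(acc_history):
    """Mean test accuracy over the final 10 epochs."""
    tail = acc_history[-10:]
    return sum(tail) / len(tail)

# Hypothetical per-epoch accuracy history for illustration.
accs = [90.0, 94.0, 96.1, 95.7, 95.8, 95.9,
        95.6, 95.7, 95.8, 95.9, 95.7, 95.6]
print(highest(accs))  # 96.1
```

Reporting both guards against cherry-picking: a method whose "Highest" far exceeds its "Last 10" peaked briefly and then degraded, whereas close values indicate stable convergence.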

Training Locally

The source code relies heavily on CUDA. Please make sure you have the latest version of PyTorch and a compatible version of CUDA installed; older versions may exhibit inconsistent performance.

Download PyTorch / Download CUDA

All other requirements are listed in requirements.txt.

Reproducibility

At particularly high noise ratios (e.g. 90% on CIFAR-10), results may vary across training runs. We are aware of this issue and are exploring ways to produce more consistent results. We will publish any findings (consistently performant configurations, improved procedures, etc.) both in this repository and in continuations of this work.

All training configurations and parameters are controlled via the presets.json file. Configurations can be nested to arbitrary depth, and settings specified in a subconfiguration always override those inherited from its parent.

To train locally, first add your local machine to the presets.json:

{
    // ... inside the root scope
    "machines": { // list of machines
        "localPC": { // name for your local PC, can be anything
            "checkpoint_path": "./localPC_checkpoints"
        }
    },
    "configs": {
        "c10": { // cifar-10 dataset
            "machines": { // list of machines
                "localPC": { // local PC name
                    "data_path": "/path/to/your/dataset"
                    // path to dataset (python) downloaded from:
                    // https://www.cs.toronto.edu/~kriz/cifar.html
                }
                // ... keep all other machines unchanged
            }
            // ... keep all other config values unchanged
        }
        // ... keep all other configs unchanged
    }
    // ... keep all other global values unchanged
}

A "preset" is a specific configuration branch. For example, if you would like to run train_cifar.py with the preset root -> c100 -> 90sym -> AugDesc-WS on your machine named localPC, you can run the following command:

python train_cifar.py --preset c100.90sym.AugDesc-WS --machine localPC

The script will begin training the preset specified by the --preset argument, and progress will be saved in the appropriate directory under your specified checkpoint_path. Additionally, if the --machine flag is omitted, the training script will look for the dataset in the data_path inherited from parent configurations.
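The preset mechanism described above (a dotted path walked from the root, with each subconfiguration overriding its parent) can be sketched as follows. This is a hypothetical illustration, not the repository's actual loader; only the nested "configs" layout is taken from the snippet above:

```python
# Hypothetical resolver for a dotted preset path like
# "c100.90sym.AugDesc-WS": walk the config tree from the root,
# letting each subconfiguration override inherited settings.
def resolve_preset(root, preset_path):
    # Start with the root-level settings (everything but the children).
    settings = {k: v for k, v in root.items() if k != "configs"}
    node = root
    for name in preset_path.split("."):
        node = node["configs"][name]
        for k, v in node.items():
            if k != "configs":
                settings[k] = v  # child overrides parent
    return settings

# Toy tree mirroring the nesting structure (values are made up).
presets = {
    "lr": 0.02,
    "configs": {
        "c100": {
            "dataset": "cifar100",
            "configs": {
                "90sym": {
                    "noise_ratio": 0.9,
                    "configs": {"AugDesc-WS": {"warmup_aug": "weak"}},
                },
            },
        },
    },
}
cfg = resolve_preset(presets, "c100.90sym.AugDesc-WS")
```

Here `cfg` ends up with the root's `lr`, the dataset from `c100`, the noise ratio from `90sym`, and the leaf's augmentation setting, matching the "settings in subconfigurations always override the parent" rule.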

Here are some abbreviations used in our presets.json:

| Abbreviation | Meaning |
| --- | --- |
| c10 | CIFAR-10 |
| c100 | CIFAR-100 |
| c1m | Clothing1M |
| sym | Symmetric Noise |
| asym | Asymmetric Noise |
| SAW | Strongly Augmented Warmup |
| WAW | Weakly Augmented Warmup |
| RandAug | RandAugment |
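For readers unfamiliar with the two synthetic noise models (sym/asym above), here is a minimal sketch of how label noise is commonly injected. Conventions vary between papers (e.g. whether symmetric noise can re-draw the true class); this is an illustration, not this repository's exact procedure:

```python
import random

def inject_symmetric_noise(labels, noise_ratio, num_classes, rng):
    """Symmetric noise: with probability noise_ratio, replace a label
    with a class drawn uniformly over all classes (one common
    convention; some papers exclude the true class)."""
    return [rng.randrange(num_classes) if rng.random() < noise_ratio else y
            for y in labels]

def inject_asymmetric_noise(labels, noise_ratio, flip_map, rng):
    """Asymmetric noise: flip a label to a specific, semantically
    similar class (e.g. truck -> automobile on CIFAR-10) with
    probability noise_ratio."""
    return [flip_map.get(y, y) if rng.random() < noise_ratio else y
            for y in labels]

rng = random.Random(0)
noisy = inject_symmetric_noise([0] * 1000, 0.5, 10, rng)
```

Note that asymmetric noise is generally harder to filter by loss alone, since the wrong labels are class-dependent and plausible rather than uniformly random.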

Citations

Please cite the following:

@InProceedings{Nishi_2021_CVPR,
    author    = {Nishi, Kento and Ding, Yi and Rich, Alex and H{\"o}llerer, Tobias},
    title     = {Augmentation Strategies for Learning With Noisy Labels},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {8022-8031}
}

Extras

Extra bits of unsanitized code for plotting, training, etc. can be found in the Aug-for-LNL-Extras repository.

Additional Info

This repository is a fork of the official DivideMix implementation.
