  • Stars: 2,244
  • Rank: 20,522 (Top 0.5%)
  • Language: Python
  • License: MIT License
  • Created over 6 years ago
  • Updated 6 months ago

Repository Details

A PyTorch imbalanced-dataset sampler for oversampling low-frequency classes and undersampling high-frequency ones.

Imbalanced Dataset Sampler

Introduction

In many machine learning applications, we come across datasets in which some classes appear far more often than others. Take rare-disease identification, for example: there are usually many more normal samples than disease samples. In such cases, we need to make sure that the trained model is not biased toward the class with more data. As an example, consider a dataset with 5 disease images and 20 normal images. A model that predicts every image to be normal reaches an accuracy of 80% (20 of 25), and its F1-score, with 'normal' as the positive class, is about 0.89, even though it never detects a single disease case. These respectable-looking numbers hide the fact that the model is completely biased toward the 'normal' class.
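The numbers in this example can be checked directly. The sketch below assumes the F1-score is computed with 'normal' as the positive class, as in the text; with 'disease' as the positive class, the same model scores an F1 of 0 because it never predicts 'disease'. The helper f1_score here is a hypothetical stand-alone function, not part of torchsampler.

```python
# Toy example from the text: 5 disease images, 20 normal images,
# and a model that predicts "normal" for every image.
y_true = ["disease"] * 5 + ["normal"] * 20
y_pred = ["normal"] * len(y_true)


def f1_score(y_true, y_pred, positive):
    """Binary F1-score for the given positive class (0.0 when there are no true positives)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
f1_normal = f1_score(y_true, y_pred, positive="normal")
f1_disease = f1_score(y_true, y_pred, positive="disease")

print(f"accuracy = {accuracy:.2f}")        # accuracy = 0.80
print(f"F1 (normal) = {f1_normal:.3f}")    # F1 (normal) = 0.889
print(f"F1 (disease) = {f1_disease:.3f}")  # F1 (disease) = 0.000
```

For 'normal', precision is 20/25 = 0.8 and recall is 1.0, so F1 = 2 · 0.8 · 1.0 / 1.8 ≈ 0.89.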

A widely adopted technique for this problem is resampling: removing samples from the majority class (under-sampling) and/or adding more samples from the minority class (over-sampling). Despite the advantage of balancing classes, these techniques have their own weaknesses (there is no free lunch). The simplest form of over-sampling duplicates random records from the minority class, which can cause overfitting. The simplest form of under-sampling removes random records from the majority class, which can cause loss of information.
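The two simplest strategies can be sketched on a plain list of class labels. The functions random_oversample and random_undersample below are hypothetical helpers for illustration, not part of torchsampler; they return indices into the original data, and they balance every class to the size of the largest (over-sampling) or smallest (under-sampling) class, which is one common choice among several.

```python
import random
from collections import Counter


def _indices_by_class(labels):
    """Group sample indices by their class label."""
    groups = {}
    for idx, label in enumerate(labels):
        groups.setdefault(label, []).append(idx)
    return groups


def random_oversample(labels, seed=0):
    """Duplicate random minority-class indices until every class matches the largest one."""
    rng = random.Random(seed)
    groups = _indices_by_class(labels)
    target = max(len(idxs) for idxs in groups.values())
    indices = []
    for idxs in groups.values():
        indices.extend(idxs)
        # Exact duplicates of existing records: the source of overfitting risk.
        indices.extend(rng.choices(idxs, k=target - len(idxs)))
    return indices


def random_undersample(labels, seed=0):
    """Keep a random subset of each class so every class matches the smallest one."""
    rng = random.Random(seed)
    groups = _indices_by_class(labels)
    target = min(len(idxs) for idxs in groups.values())
    indices = []
    for idxs in groups.values():
        # Records that are not kept are never seen: the source of information loss.
        indices.extend(rng.sample(idxs, target))
    return indices


labels = ["disease"] * 5 + ["normal"] * 20
over = random_oversample(labels)
under = random_undersample(labels)
print(sorted(Counter(labels[i] for i in over).items()))   # [('disease', 20), ('normal', 20)]
print(sorted(Counter(labels[i] for i in under).items()))  # [('disease', 5), ('normal', 5)]
```

Over-sampling grows the 25-sample dataset to 40 by repeating disease samples, while under-sampling shrinks it to 10 by discarding 15 normal samples.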

(Figure: resampling)

In this repo, we implement an easy-to-use PyTorch sampler, ImbalancedDatasetSampler, that can:

  • rebalance the class distributions when sampling from the imbalanced dataset
  • estimate the sampling weights automatically
  • avoid creating a new balanced dataset
  • mitigate overfitting when it is used in conjunction with data augmentation techniques

Usage

To get started, install the package with pip:

pip install torchsampler

Then pass an ImbalancedDatasetSampler as the sampler argument when creating a DataLoader. For example:

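import torch
from torchsampler import ImbalancedDatasetSampler

# train_dataset, args.batch_size and kwargs come from your own training script.
# Do not also pass shuffle=True: DataLoader rejects shuffle together with a sampler.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=ImbalancedDatasetSampler(train_dataset),
    batch_size=args.batch_size,
    **kwargs
)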

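In each epoch, the loader then draws as many indices as there are samples in the dataset, with replacement, weighting each sample inversely to the frequency of its class. Rare classes are therefore drawn about as often as common ones.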
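A minimal sketch of this weighting, using only the standard library. It is not torchsampler's source code: inverse_frequency_weights and weighted_epoch are hypothetical names, and random.choices stands in for whatever weighted draw the library actually performs.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each sample by 1 / (number of samples in its class)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


def weighted_epoch(labels, seed=0):
    """Draw len(labels) indices with replacement, proportionally to the weights."""
    rng = random.Random(seed)
    weights = inverse_frequency_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=len(labels))


labels = ["disease"] * 5 + ["normal"] * 20
weights = inverse_frequency_weights(labels)

# Every class carries the same total weight (1.0), so each class is expected
# to fill about half of the indices drawn in an epoch.
for cls in ("disease", "normal"):
    print(cls, round(sum(w for w, l in zip(weights, labels) if l == cls), 6))  # 1.0

epoch = weighted_epoch(labels)
print(Counter(labels[i] for i in epoch))  # roughly balanced; exact counts depend on the seed
```

Because only indices are drawn, no balanced copy of the dataset is ever built. PyTorch's own torch.utils.data.WeightedRandomSampler accepts per-sample weights of this kind together with num_samples and replacement arguments.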

Example: Imbalanced MNIST Dataset

Distribution of classes in the imbalanced dataset:

With Imbalanced Dataset Sampler:

(left: test accuracy per epoch; right: confusion matrix)

Without Imbalanced Dataset Sampler:

(left: test accuracy per epoch; right: confusion matrix)

Note the significant improvements for minority classes such as 2, 6 and 9, while the accuracy of the other classes is preserved.
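Per-class accuracy is read off a confusion matrix by dividing each diagonal entry by its row total, which is how changes on individual classes become visible even when overall accuracy barely moves. The sketch assumes rows are true classes and columns are predictions; the 3-class matrix is invented for illustration and is not the MNIST result shown above.

```python
def per_class_accuracy(confusion):
    """Per-class accuracy (recall): diagonal entry divided by its row total.

    confusion[i][j] counts samples whose true class is i and predicted class is j.
    """
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(confusion)]


# Made-up matrix for illustration only: class 2 is rare and often mistaken for class 0.
toy = [
    [90, 5, 5],
    [4, 92, 4],
    [6, 2, 12],
]

for cls, acc in enumerate(per_class_accuracy(toy)):
    print(f"class {cls}: {acc:.2f}")  # class 0: 0.90, class 1: 0.92, class 2: 0.60

overall = sum(toy[i][i] for i in range(len(toy))) / sum(map(sum, toy))
print(f"overall: {overall:.2f}")  # overall: 0.88
```

The overall accuracy of 0.88 looks healthy while the rare class sits at 0.60, which is why the confusion matrices are the more informative panel when comparing runs.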

Contributing

We appreciate all contributions. If you plan to contribute bug fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.

Licensing

MIT licensed.

More Repositories

1. deepo (Python, 6,321 stars): Setup and customize deep learning environment in seconds.
2. cpp-core-guidelines-cheatsheet (550 stars): Cheatsheet for the C++ core guidelines, including a set of tried-and-true guidelines, rules, and best practices about coding in C++.
3. cropman (Python, 367 stars): Face-aware image cropping.
4. recursive-bf (C++, 345 stars): A lightweight C++ library for recursive bilateral filtering [Yang, Qingxiong. "Recursive bilateral filtering". European Conference on Computer Vision, 2012].
5. wavelet-rasterization (Python, 83 stars): Python implementation of [Manson, Josiah, and Scott Schaefer. "Wavelet rasterization." Computer Graphics Forum. Vol. 30. No. 2. Blackwell Publishing Ltd, 2011].
6. html5-svg-viewer (18 stars): Zoomable and pannable SVG viewer.
7. Caffe-mini (C++, 8 stars): Zero-dependency Caffe for the testing phase.
8. simgan (Python, 5 stars)
9. industry-map (Python, 4 stars): Map of China's national economic industry classification, parsed and extracted from the GB/T 4754-2017 standard document, together with the automated parsing and extraction scripts.
10. svg-view (C, 3 stars): Lightweight SVG viewer.
11. Twitch (C++, 3 stars): Realtime image stitching with two cameras.
12. iVec (C++, 2 stars): Interactive image vectorization.
13. google-reader-reader (Python, 2 stars): Reader of Google Reader.
14. A4Pose (C++, 2 stars): Pose estimation for A4 paper.
15. pure-cornucopia (C, 2 stars): Ready-to-compile version of Ilya's Cornucopia.
16. WarpMan (C++, 2 stars): Interactive image warping.
17. pyGIST (Python, 2 stars): GIST implementation, written completely in Python.
18. forestry-demo (HTML, 2 stars)
19. image-deduplicator (Python, 2 stars): Near-duplicate image detection.
20. super-resolution (C++, 1 star)
21. GithubPageTest (1 star): Only for testing.
22. dip-final-evaluation (Python, 1 star)
23. OpenCapture (C++, 1 star): Capture engine for OpenCV that uses the Windows Media Foundation CaptureEngine instead, for better compatibility with various cameras.
24. algorithm (C++, 1 star)
25. doubanbook (Python, 1 star)
26. agpy (Python, 1 star): Automatically exported from code.google.com/p/agpy
27. dataset (1 star)
28. mnn2mem (C++, 1 star): Converts an MNN model into a C++ header.