• Stars: 321
  • Rank: 130,752 (Top 3%)
  • Language: Python
  • License: Other
  • Created: over 7 years ago
  • Updated: about 3 years ago

Repository Details

Code for experiments regarding importance sampling for training neural networks

Importance Sampling

This Python package provides a library that accelerates the training of arbitrary neural networks created with Keras using importance sampling.

# Import the importance sampling training wrapper
from importance_sampling.training import ImportanceTraining

x_train, y_train, x_val, y_val = load_data()
model = create_keras_model()
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

ImportanceTraining(model).fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    verbose=1,
    validation_data=(x_val, y_val)
)

model.evaluate(x_val, y_val)

Importance sampling for deep learning is an active research field, and this library is under ongoing development, so your mileage may vary.
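The estimator the library builds on can be sketched in a few lines of NumPy. This is a toy illustration of importance-sampled estimation with the standard 1/(N p_i) reweighting, not the library's internals; the per-sample scores here are simple stand-ins for the per-sample gradient norms used in the papers below:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy per-sample "gradients" (scalars); their magnitudes act as
# importance scores, standing in for per-sample gradient norms.
grads = rng.normal(1.0, 1.0, size=1000)
scores = np.abs(grads)

# Sampling distribution proportional to the scores.
p = scores / scores.sum()

# Draw a large batch according to p and reweight each draw by
# 1 / (N * p_i) so the weighted mean remains an unbiased estimate
# of the plain (uniform) mean.
idx = rng.choice(grads.size, size=50_000, p=p)
weights = 1.0 / (grads.size * p[idx])

uniform_mean = grads.mean()
is_estimate = np.mean(weights * grads[idx])
print(uniform_mean, is_estimate)
```

Sampling proportionally to a quantity correlated with the gradient norm keeps the estimate unbiased while concentrating computation on the most informative samples, which is the mechanism behind the reported speedups.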

Relevant Research

Ours

  • Not All Samples Are Created Equal: Deep Learning with Importance Sampling [preprint]
  • Biased Importance Sampling for Deep Neural Network Training [preprint]

By others

  • Stochastic optimization with importance sampling for regularized loss minimization [pdf]
  • Variance reduction in SGD by distributed importance sampling [pdf]

Dependencies & Installation

If you already have a functional Keras installation, you normally just need to pip install keras-importance-sampling.

  • Keras > 2
  • A Keras backend: TensorFlow, Theano, or CNTK
  • blinker
  • numpy
  • matplotlib, seaborn, scikit-learn are optional (used by the plot scripts)

Documentation

The module has a dedicated documentation site, but you can also read the source code and the examples to get an idea of how the library should be used and extended.

Examples

In the examples folder you can find some Keras examples that have been edited to use importance sampling.

Code examples

In this section we will showcase part of the API that can be used to train neural networks with importance sampling.

# Import what is needed to build the Keras model
from keras import backend as K
from keras.layers import Dense, Activation, Flatten
from keras.models import Sequential

# Import a toy dataset and the importance training
from importance_sampling.datasets import MNIST
from importance_sampling.training import ImportanceTraining


def create_nn():
    """Build a simple fully connected NN"""
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(40, activation="tanh"),
        Dense(40, activation="tanh"),
        Dense(10),
        Activation("softmax") # Needs to be separate to automatically
                              # get the preactivation outputs
    ])

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model


if __name__ == "__main__":
    # Load the data
    dataset = MNIST()
    x_train, y_train = dataset.train_data[:]
    x_test, y_test = dataset.test_data[:]

    # Create the NN and keep the initial weights
    model = create_nn()
    weights = model.get_weights()

    # Train with uniform sampling
    K.set_value(model.optimizer.lr, 0.01)
    model.fit(
        x_train, y_train,
        batch_size=64, epochs=10,
        validation_data=(x_test, y_test)
    )

    # Train with importance sampling
    model.set_weights(weights)
    K.set_value(model.optimizer.lr, 0.01)
    ImportanceTraining(model).fit(
        x_train, y_train,
        batch_size=64, epochs=2,
        validation_data=(x_test, y_test)
    )

Using the script

The following terminal commands train a small VGG-like network to roughly 0.65% error on MNIST (timings are from a CPU):

$ # Train a small cnn with mnist for 500 mini-batches using importance
$ # sampling with bias to achieve ~ 0.65% error (on the CPU).
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-gnorm \
>   model \
>   predicted \
>   mnist \
>   /tmp/is \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 500 --validate_every 500
real    1m41.985s
user    8m14.400s
sys     0m35.900s
$
$ # And with uniform sampling to achieve ~ 0.9% error.
$ time ./importance_sampling.py \
>   small_cnn \
>   oracle-loss \
>   uniform \
>   unweighted \
>   mnist \
>   /tmp/uniform \
>   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
>   --train_for 3000 --validate_every 3000
real    9m23.971s
user    47m32.600s
sys     3m4.188s
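The --hyperparams string appears to use a single-character type prefix on each value (i for integers, f for floats, I for integer lists); that reading is an assumption inferred from the examples above, not documented behavior. A minimal parser sketch under that assumption, with a hypothetical helper name:

```python
def parse_hyperparams(spec):
    """Parse a 'key=TYPEvalue;...' string into a dict.

    Assumed convention: the first character after '=' encodes the type
    (i = int, f = float, I = comma-separated list of ints).
    """
    params = {}
    for item in spec.split(";"):
        key, value = item.split("=")
        kind, raw = value[0], value[1:]
        if kind == "i":
            params[key] = int(raw)
        elif kind == "f":
            params[key] = float(raw)
        elif kind == "I":
            params[key] = [int(x) for x in raw.split(",")]
        else:
            raise ValueError(f"unknown type code {kind!r}")
    return params

print(parse_hyperparams("batch_size=i128;lr=f0.003;lr_reductions=I10000"))
```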

More Repositories

  1. fast-transformers: Pytorch library for fast transformer implementations (Python, 1,622 stars)
  2. bob: Bob is a free signal-processing and machine learning toolbox originally developed by the Biometrics group at Idiap Research Institute, in Switzerland (266 stars)
  3. ESLAM (Python, 202 stars)
  4. fullgrad-saliency: Full-gradient saliency maps (Python, 201 stars)
  5. multicamera-calibration: Multi-Camera Calibration Suite (Python, 179 stars)
  6. GeoNeRF: Generalizing NeRF with Geometry Priors (Python, 117 stars)
  7. attention-sampling: This Python package enables the training and inference of deep learning models for very large data, such as megapixel images, using attention sampling (Python, 98 stars)
  8. acoustic-simulator: Implementation of audio degradation processes (Python, 95 stars)
  9. mser: Linear-time Maximally Stable Extremal Regions implementation (C++, 95 stars)
  10. kaldi-ivector: Extension to Kaldi implementing the standard i-vector hyperparameter estimation and i-vector extraction procedure (C++, 88 stars)
  11. mhan: Multilingual hierarchical attention networks toolkit (Python, 78 stars)
  12. pkwrap: A PyTorch wrapper for LF-MMI training and parallel training in Kaldi (Python, 73 stars)
  13. gafro: An efficient C++ library targeting robotics applications using geometric algebra (C++, 69 stars)
  14. HAN_NMT: Document-Level Neural Machine Translation with Hierarchical Attention Networks (JavaScript, 68 stars)
  15. g2g-transformer: PyTorch implementation of "Recursive Non-Autoregressive Graph-to-Graph Transformer for Dependency Parsing with Iterative Refinement" (Python, 61 stars)
  16. juicer: A Weighted Finite State Transducer (WFST) based decoder for Automatic Speech Recognition (ASR) (C++, 60 stars)
  17. facereclib: Compare your face recognition algorithm to baseline algorithms (57 stars)
  18. sigma-gpt: σ-GPT, a new approach to autoregressive models (Python, 53 stars)
  19. model-uncertainty-for-adaptation: Code for the paper "Uncertainty Reduction for Model Adaptation in Semantic Segmentation", CVPR 2021 (Python, 49 stars)
  20. eakmeans: Implementation of fast exact k-means algorithms (C++, 47 stars)
  21. atco2-corpus: A corpus for research on robust automatic speech recognition and natural language understanding of air traffic control communications (Python, 45 stars)
  22. ssp: Speech Signal Processing, a small collection of routines in Python to do signal processing (Python, 44 stars)
  23. psfestimation: PyTorch implementation of "Spatially-Variant CNN-based Point Spread Function Estimation for Blind Deconvolution and Depth Estimation in Optical Microscopy", IEEE Transactions on Image Processing, 2020 (Python, 34 stars)
  24. w2v2-air-traffic (Python, 34 stars)
  25. potr (Python, 32 stars)
  26. residual_pose: Residual Pose, a decoupled approach for depth-based 3D human pose estimation (Python, 32 stars)
  27. CNN_QbE_STD: Implementation of the work presented in "CNN based Query by Example Spoken Term Detection" (Python, 31 stars)
  28. nnsslm: Neural network based sound source localization models (Python, 30 stars)
  29. semiblindpsfdeconv: Code for "Semi-Blind Spatially-Variant Deconvolution in Optical Microscopy with Local Point Spread Function Estimation By Use Of Convolutional Neural Networks", ICIP 2018 (Python, 26 stars)
  30. IBDiarization: C++ implementation of the Information Bottleneck System (C++, 23 stars)
  31. gile: A generalized input-label embedding for text classification (Python, 23 stars)
  32. IdiapTTS: A Python-based modular toolbox for building deep neural network models (using PyTorch) for statistical parametric speech synthesis (Python, 23 stars)
  33. HMMGradients.jl: Enables computing the gradient of the parameters of Hidden Markov Models (HMMs) (Julia, 22 stars)
  34. inv-tn: Scripts exploiting several tools to perform inverse text normalization (ITN) (Shell, 21 stars)
  35. deepfocus: PyTorch implementation of "DeepFocus: a Few-Shot Microscope Slide Auto-Focus using a Sample Invariant CNN-based Sharpness Function" (Python, 20 stars)
  36. multimodal_gaze_target_prediction: Training and testing code for "A Modular Multimodal Architecture for Gaze Target Prediction: Application to Privacy-Sensitive Settings", GAZE workshop at CVPR 2022 (Python, 20 stars)
  37. hypermixing: PyTorch implementation of HyperMixing, a linear-time token-mixing technique used in the HyperMixer architecture (Python, 19 stars)
  38. sparch: PyTorch-based toolkit for developing spiking neural networks (SNNs) by training and testing them on speech command recognition tasks (Python, 18 stars)
  39. zff_vad: Unsupervised voice activity detection by modeling source and system information using zero frequency filtering (Python, 18 stars)
  40. contextual-biasing-on-gpus: Implementation of contextual biasing for ASR decoding on GPUs without lattice generation; supports a submission to Interspeech 2023 (C++, 18 stars)
  41. icassp-oov-recognition: Data and code related to the ICASSP submission "A comparison of methods for OOV-word recognition" (C++, 17 stars)
  42. asrt: Scripts that facilitate the preparation of automatic speech recognition resources (Python, 16 stars)
  43. phonvoc: Phonetic and phonological vocoding platform (Shell, 16 stars)
  44. fast_pose_machines: Efficient pose machine for multi-person pose estimation (Python, 16 stars)
  45. ttgo: A PyTorch implementation of the TTGO algorithm and the applications presented in "Tensor Train for Global Optimization Problems in Robotics" (Jupyter Notebook, 14 stars)
  46. apam: A toolkit built on PyTorch providing recipes to adapt pretrained acoustic models with a variety of sequence discriminative training criteria (Python, 14 stars)
  47. torgo_asr: A Kaldi recipe for training automatic speech recognition systems on the Torgo corpus of dysarthric speech (Shell, 14 stars)
  48. libssp: Speech Signal Processing, a C++ port of a subset of the Python library SSP (C++, 13 stars)
  49. cbrec: Content-based recommendation generator (Python, 13 stars)
  50. wmil-sgd: Weighted multiple-instance learning algorithm based on stochastic gradient descent (Python, 12 stars)
  51. bert-text-diarization-atc (Python, 12 stars)
  52. DepthInSpace: A PyTorch-based program which estimates 3D depth maps from an active structured-light sensor's multiple video frames (Python, 11 stars)
  53. iss: Scripts for speech processing (Shell, 11 stars)
  54. rgbd (Python, 10 stars)
  55. tracter: A data flow framework (C++, 10 stars)
  56. pddetection-reps-learning: Supervised speech representation learning for Parkinson's disease classification (Python, 10 stars)
  57. drill: Deep residual output layers for neural language generation (Python, 10 stars)
  58. nvib_transformers (Python, 9 stars)
  59. Node_weighted_GCN_for_depression_detection: Node-weighted graph convolutional network for depression detection in transcribed clinical interviews (HTML, 9 stars)
  60. depth_human_synthesis: DepthHuman, a tool for depth image synthesis for human pose estimation (Python, 9 stars)
  61. ilqr_planner: A C++ iLQR library that solves the iLQR optimization problem for any robot, given a URDF file describing its kinematic chain (Jupyter Notebook, 9 stars)
  62. gafar: Geometry-aware face reconstruction (Python, 9 stars)
  63. zentas: Partitional data clustering around centers (C++, 8 stars)
  64. linear-transformer-experiments: Experiments using fast linear transformers (Python, 8 stars)
  65. emorec: Emotion-based recommendation generator (OpenEdge ABL, 8 stars)
  66. hallucination-detection (Python, 8 stars)
  67. DocRec: Keyword extraction and document recommendation in conversations (MATLAB, 8 stars)
  68. abroad-re: Towards an end-to-end relation extraction system for the natural product literature: datasets, strategies and models (Jupyter Notebook, 8 stars)
  69. nvib (Python, 7 stars)
  70. cnn-for-voice-antispoofing: CNNs for voice antispoofing detection (MATLAB, 7 stars)
  71. wav2vec-lfmmi: Recipes for fine-tuning a pre-trained wav2vec 2.0 model using the Espresso toolkit (Python, 7 stars)
  72. pydhn (Python, 7 stars)
  73. APT: A reference-based metric to evaluate the accuracy of pronoun translation (Python, 6 stars)
  74. iss-dicts: ISS scripts for handling pronunciation dictionaries (Python, 6 stars)
  75. sentence-planner (Python, 6 stars)
  76. slog: Similarity Learning on Graph (SLOG) MATLAB code (MATLAB, 6 stars)
  77. cncsharedtask (Jupyter Notebook, 6 stars)
  78. inference-from-real-world-sparse-measurements: Implementation of the Multi-Layer Self-Attention model designed for wind nowcasting tasks (Python, 6 stars)
  79. ssl-caller-detection: Source code for the paper "Can Self-Supervised Neural Representations Pre-Trained on Human Speech distinguish Animal Callers?" by E. Sarkar and M. Magimai Doss (2023) (Python, 6 stars)
  80. ExVo-2022: Extracting pre-trained self-supervised embeddings for the ICML ExVo 2022 challenge (Python, 5 stars)
  81. vfoa: Methods to estimate the visual focus of attention (Python, 5 stars)
  82. buslr: BuSLR, a build system for speech and language research (CMake, 5 stars)
  83. dhgen: A Python module for generating district heating network layouts (Python, 5 stars)
  84. bayesian-recurrence: A Bayesian interpretation of recurrence in neural networks (Python, 5 stars)
  85. TactileErgodicExploration: A Python package for ergodic control on point clouds using diffusion; supplementary material for "Tactile Ergodic Control Using Diffusion and Geometric Algebra" (Jupyter Notebook, 5 stars)
  86. ML3: ML3 classifier (Multiclass Latent Locally Linear Support Vector Machines) (C++, 5 stars)
  87. sense_aware_NMT: Sense-aware neural machine translation (Python, 5 stars)
  88. php-geremo: PHP generic registration module [GPLv3] (PHP, 4 stars)
  89. idiap.github.com: Main page for idiap@github (CSS, 4 stars)
  90. tinyurdfparser: A lightweight URDF parser library, based on TinyXML2, that converts a URDF file into a KDL object (C++, 4 stars)
  91. TIDIGITSRecipe.jl: A Julia recipe for training an ASR system using the TIDIGITS database (Julia, 4 stars)
  92. hpca (C, 4 stars)
  93. rethinking-saliency: Reference implementation of the ICLR 2021 paper "Rethinking the Role of Gradient-Based Attribution Methods for Model Interpretability" (Python, 4 stars)
  94. DiscoConn-Classifier: Classifier models and feature extractors for discourse relations (Perl, 4 stars)
  95. pygafro: A geometric algebra library targeted towards robotics applications (Python, 4 stars)
  96. anonymization: A Python library for anonymizing sensitive information in text data, focused on Swiss French banking data (Python, 4 stars)
  97. unsupervised_gaze_calibration: Calibrates a gaze estimator in an unsupervised fashion by automatically collecting calibration samples using task-related priors (Python, 4 stars)
  98. Attentive_Residual_Connections_NMT: Implementation and output data of "Global-Context Neural Machine Translation through Target-Side Attentive Residual Connections" (JavaScript, 4 stars)
  99. FiniteStateTransducers.jl: Play with Weighted Finite State Transducers (WFST) in the Julia language (Julia, 3 stars)
  100. iss-wsj: ISS scripts for the Wall Street Journal task (Shell, 3 stars)