  • Language: Python
  • License: MIT License

Implementation of One-Cycle Learning rate policy (adapted from Fast.ai lib)

One Cycle Learning Rate Policy for Keras

Implementation of One-Cycle Learning rate policy from the papers by Leslie N. Smith.

Contains two Keras callbacks, LRFinder and OneCycleLR, which are ported from the PyTorch Fast.ai library.

What is One Cycle Learning Rate

It combines a gradually increasing learning rate (and, optionally, a gradually decreasing momentum) during the first half of the cycle with a gradually decreasing learning rate (and, optionally, a gradually increasing momentum) during the second half of the cycle.

Finally, during a set percentage at the end of the cycle, the learning rate is sharply reduced every epoch.

[Figure: the learning rate schedule over one cycle]

[Figure: the optional momentum schedule over one cycle]
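As a rough illustration of the shape described above, the sketch below computes a piecewise-linear one-cycle learning rate and momentum for each training iteration. It is a simplified sketch, not the code in clr.py: the function names are hypothetical, the 1/1000th floor follows the end_percentage description in the OneCycleLR section, and the exact formulas used by OneCycleLR may differ.

```python
def one_cycle_lr(iteration, num_iterations, max_lr,
                 end_percentage=0.1, scale_percentage=None):
    """Learning rate at a given iteration (simplified one-cycle sketch)."""
    # As described for OneCycleLR, scale_percentage falls back to end_percentage
    scale = end_percentage if scale_percentage is None else scale_percentage
    start_lr = max_lr * scale      # learning rate at the start of the cycle
    final_lr = max_lr / 1000.0     # floor reached at the very end of training
    # Both halves of the cycle share the first (1 - end_percentage) of training
    mid = max(1, int(num_iterations * (1.0 - end_percentage) / 2))

    if iteration <= mid:
        # First half: increase linearly from start_lr to max_lr
        return start_lr + (iteration / mid) * (max_lr - start_lr)
    if iteration <= 2 * mid:
        # Second half: decrease linearly from max_lr back to start_lr
        return max_lr - ((iteration - mid) / mid) * (max_lr - start_lr)
    # Last end_percentage of training: anneal sharply from start_lr to final_lr
    frac = (iteration - 2 * mid) / (num_iterations - 2 * mid)
    return start_lr - frac * (start_lr - final_lr)


def one_cycle_momentum(iteration, num_iterations, end_percentage=0.1,
                       maximum_momentum=0.95, minimum_momentum=0.85):
    """Momentum mirrors the learning rate: it dips while the lr rises."""
    mid = max(1, int(num_iterations * (1.0 - end_percentage) / 2))
    span = maximum_momentum - minimum_momentum
    if iteration <= mid:
        return maximum_momentum - (iteration / mid) * span
    if iteration <= 2 * mid:
        return minimum_momentum + ((iteration - mid) / mid) * span
    # Momentum stays at its maximum during the final annealing phase
    return maximum_momentum


# Example: 1000 training iterations with a maximum learning rate of 0.02
lrs = [one_cycle_lr(i, 1000, 0.02) for i in range(1001)]
moms = [one_cycle_momentum(i, 1000) for i in range(1001)]
```

The learning rate peaks at the same iteration where the momentum bottoms out, which is the inverse coupling the schedules above describe.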

Usage

Finding a good learning rate

Use LRFinder to obtain a loss plot, and visually inspect it to determine a suitable maximum learning rate. Provided below is an example, obtained for the MiniMobileNetV2 model.

An example script is provided in find_lr_schedule.py inside models/mobilenet/.

Essentially,

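```python
from clr import LRFinder

# num_samples: number of training samples; batch_size: the batch size passed to fit()
# minimum_lr / maximum_lr: the range of learning rates to sweep over
lr_callback = LRFinder(num_samples, batch_size,
                       minimum_lr, maximum_lr,
                       # validation_data=(X_val, Y_val),
                       lr_scale='exp', save_dir='path/to/save/directory')

# Ensure that number of epochs = 1 when calling fit()
model.fit(X, Y, epochs=1, batch_size=batch_size, callbacks=[lr_callback])
```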

The above callback takes a few important arguments:

  • You must supply the number of samples in the dataset (here, 50k from CIFAR-10) and the batch size that will be used during training.
  • lr_scale is set to exp, which is useful when searching over a large range of learning rates. Set it to linear to search a smaller range.
  • save_dir: the results of LRFinder are automatically saved to the specified directory path. This is highly encouraged.
  • validation_data: provide the validation data as a tuple to compute the loss plot on it instead of on the training batch loss. Since the validation set can be very large, k batches (k * batch_size samples) are randomly sampled from it to provide a quick estimate of the validation loss. The default value of k can be changed via validation_sample_rate.

Note: When using validation_data, be careful when setting the learning rate, momentum and weight decay schedules. The loss plots will be more erratic because the validation set is sampled.

Note 2:

  • It is faster to find the learning rate without validation_data, and then find the weight decay and momentum for that learning rate with validation_data.
  • You can also use LRFinder to find the optimal weight decay and momentum values, as shown in the example scripts find_momentum_schedule.py and find_weight_decay_schedule.py inside the models/mobilenet/ folder.
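To see why lr_scale matters, the sketch below lists the learning rates a range test would visit, assuming one learning rate step per training batch over a single epoch. lr_sweep is a hypothetical helper, and the range 1e-5 to 10 and the batch size of 128 are illustrative values, not necessarily LRFinder's defaults.

```python
import math


def lr_sweep(minimum_lr, maximum_lr, num_steps, lr_scale="exp"):
    """Learning rates visited by a range test that takes `num_steps` steps."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    if lr_scale == "exp":
        # Evenly spaced in log10 space: each step multiplies the lr by a constant factor
        log_min, log_max = math.log10(minimum_lr), math.log10(maximum_lr)
        step = (log_max - log_min) / (num_steps - 1)
        return [10.0 ** (log_min + i * step) for i in range(num_steps)]
    if lr_scale == "linear":
        # Evenly spaced in lr space: each step adds a constant amount
        step = (maximum_lr - minimum_lr) / (num_steps - 1)
        return [minimum_lr + i * step for i in range(num_steps)]
    raise ValueError("lr_scale must be 'exp' or 'linear'")


# One epoch over CIFAR-10 (50k samples) with an assumed batch size of 128
num_batches = math.ceil(50_000 / 128)  # 391 batches, one lr step each
sweep = lr_sweep(1e-5, 10.0, num_batches, lr_scale="exp")
```

With exp, every decade of the range receives the same number of batches. With linear over the same 1e-5 to 10 range, about 90% of the batches would land between 1 and 10, which is why linear suits a narrower search.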

There are two ways to visualize the plot:

  • Use lr_callback.plot_schedule() after the fit() call. This uses the current training session results.
  • Use the class method LRFinder.plot_schedule_from_dir('path/to/save/directory') to visualize the plot separately from the training session. This only works if you used the save_dir argument to save the results of the search to some location.

Finding the optimal Momentum

Use the find_momentum_schedule.py script inside models/mobilenet/ for an example.

Some notes:

  • Use a grid search over a few possible momentum values, such as [0.8, 0.85, 0.9, 0.95, 0.99]. Use linear as the lr_scale argument value.

  • Set the momentum value manually on the SGD optimizer before compiling the model.

  • Plot the curves at the end and visually check which momentum value yields the least noisy / lowest losses overall. The absolute value of the loss is less important than the shape of the curve.

  • It is better to supply validation_data here.

  • The plot will be very noisy, so you may wish to use a larger value of loss_smoothing_beta (such as 0.99 or 0.995).

  • The actual curve values don't matter as much as the overall curve movement. Choose the value whose curve is steadier and still reaches the lowest loss even at large learning rates.
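A larger loss_smoothing_beta helps because it controls how heavily the loss curve is averaged. Fast.ai's learning rate finder uses a bias-corrected exponential moving average for this; the sketch below assumes LRFinder's loss_smoothing_beta plays the same role. smooth_losses is a hypothetical helper, and 0.98 is only an illustrative default.

```python
def smooth_losses(losses, beta=0.98):
    """Bias-corrected exponential moving average of a sequence of losses."""
    running = 0.0
    smoothed = []
    for step, loss in enumerate(losses, start=1):
        running = beta * running + (1.0 - beta) * loss
        # Dividing by (1 - beta**step) removes the bias towards the 0.0 start value
        smoothed.append(running / (1.0 - beta ** step))
    return smoothed


# A noisy loss that alternates between 1.0 and 3.0 (mean 2.0)
noisy = [1.0, 3.0] * 200
light = smooth_losses(noisy, beta=0.9)
heavy = smooth_losses(noisy, beta=0.995)
```

With beta=0.995 the smoothed curve barely oscillates around 2.0, while beta=0.9 still follows much of the batch-to-batch swing. The cost of heavier smoothing is that the curve reacts later to a real rise in loss.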

Finding the optimal Weight Decay

Use the find_weight_decay_schedule.py script inside models/mobilenet/ for an example.

Some notes:

  • Use a grid search over a few weight decay values, such as [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]. Call this the "coarse search" and use linear for the lr_scale argument.

  • Then use a grid search over a select few weight decay values, such as [3e-7, 1e-7, 3e-6]. Call this the "fine search" and again use linear for the lr_scale argument.

  • Set the weight decay value manually on the model when building it.

  • Plot the curves at the end and visually check which weight decay value yields the least noisy / lowest losses overall. The absolute value of the loss is less important than the shape of the curve.

  • It is better to supply validation_data here.

  • The plot will be very noisy, so you may wish to use a larger value of loss_smoothing_beta (such as 0.99 or 0.995).

  • The actual curve values don't matter as much as the overall curve movement. Choose the value whose curve is steadier and still reaches the lowest loss even at large learning rates.

Interpreting the plot

Learning Rate

Consider the plot above, obtained by using LRFinder on the MiniMobileNetV2 model. There are a few regions in it that need careful interpretation.

Note: The values are in log10 scale (since exp was used for lr_scale); all values discussed refer to the x-axis (learning rate):

  • After the -1.5 point on the graph, the loss becomes erratic.
  • After the 0.5 point on the graph, the loss is noisy but doesn't decrease any further.
  • -1.7 is the last relatively smooth portion before the -1.5 region. To be safe, we can move a little further left, closer to -1.8, but this will reduce performance.
  • It is usually important to visualize the first 2-3 epochs of OneCycleLR training with values close to these edges to determine which one is best.
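Because the x-axis is in log10 scale, each candidate point has to be mapped back with 10^x before it can be passed as max_lr. A small sketch with hypothetical names, using the candidate points listed above:

```python
def log10_to_lr(x):
    """Undo the log10 scale that lr_scale='exp' uses on the plot's x-axis."""
    return 10.0 ** x


# Candidate points read off the x-axis of the MiniMobileNetV2 LRFinder plot
candidates = {
    "safer, further left": -1.8,
    "last relatively smooth point": -1.7,
    "loss becomes erratic": -1.5,
}

candidate_lrs = {label: log10_to_lr(x) for label, x in candidates.items()}
for label, lr in candidate_lrs.items():
    print(f"{label}: {lr:.4f}")
```

10^-1.7 is about 0.01995, so rounding it gives the max_lr of 0.02 used in the Results section; -1.8 and -1.5 correspond to roughly 0.0158 and 0.0316.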

Momentum

Using the above learning rate, next calculate the optimal momentum (find_momentum_schedule.py).

See the notes in the Finding the optimal momentum section on how to interpret the plot.

Weight Decay

Similarly, it is possible to use the above learning rate and momentum values to calculate the optimal weight decay (find_weight_decay_schedule.py).

Note: Because large learning rates act as a strong regularizer, other regularization techniques like weight decay and dropout should be decreased significantly to train the model properly.

It is best to first search regularization strengths between 1e-3 and 1e-7, and then fine-search the region that produced the best overall plot.
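The coarse-then-fine search can be laid out as two grids, each value of which would get its own LRFinder run. The helpers below are hypothetical, and the best coarse value of 1e-6 is an illustrative assumption, not a result from this repository.

```python
def coarse_grid(largest_exponent=3, smallest_exponent=7):
    """Powers of ten from 1e-{largest_exponent} down to 1e-{smallest_exponent}."""
    return [10.0 ** -k for k in range(largest_exponent, smallest_exponent + 1)]


def fine_grid(best, factors=(0.3, 1.0, 3.0)):
    """A few values around the best coarse value, roughly half a decade apart."""
    return [best * f for f in factors]


coarse = coarse_grid()   # 1e-3, 1e-4, 1e-5, 1e-6, 1e-7
fine = fine_grid(1e-6)   # about 3e-7, 1e-6, 3e-6, assuming 1e-6 won the coarse search
```

Keeping the fine grid to about three values keeps the number of LRFinder runs small, since each run costs a full epoch.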

See the notes in the Finding the optimal weight decay section on how to interpret the plot.

Training with OneCycleLR

Once we have found the maximum learning rate, we can move on to using the OneCycleLR callback with SGD to train our model.

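```python
from clr import OneCycleLR

# num_epoch should match the number of epochs passed to fit(),
# so that the cycle spans the whole training run
lr_manager = OneCycleLR(num_samples, num_epoch, batch_size, max_lr,
                        end_percentage=0.1, scale_percentage=None,
                        maximum_momentum=0.95, minimum_momentum=0.85)

# model_checkpoint stands for any other callback defined elsewhere
model.fit(X, Y, epochs=num_epoch, batch_size=batch_size, callbacks=[model_checkpoint, lr_manager],
          ...)
```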

There are many parameters, but a few important ones are:

  • You must provide the core training information: number of samples, number of epochs, batch size and maximum learning rate.
  • end_percentage determines what percentage of the training epochs is used for the steep reduction of the learning rate. At its minimum, the learning rate will be 1/1000th of the provided max_lr.
  • scale_percentage is a confusing parameter. It dictates the scaling factor of the learning rate in the second half of the training cycle. It is best to test it visually using the plot_clr.py script to ensure there are no mistakes. Leaving it as None makes it default to the value of end_percentage.
  • maximum_momentum and minimum_momentum are preset according to the paper and Fast.ai. If you don't wish to cycle the momentum, set both to the same value; 0.9 is generally preferred as the momentum value for SGD. If you don't want to update the momentum, or are not using SGD (not advisable), set both to None to skip the momentum updates.
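To check how these parameters translate into training iterations, the sketch below computes where the learning rate peaks and where the final sharp reduction starts. It assumes one schedule step per batch and counts a partial final batch as an iteration; one_cycle_phases is a hypothetical helper, and the batch size of 128 is an assumption, since the batch size used for MiniMobileNetV2 is not stated.

```python
import math


def one_cycle_phases(num_samples, num_epochs, batch_size, max_lr, end_percentage=0.1):
    """Iteration counts for the phases of a one-cycle run (illustrative sketch)."""
    iterations_per_epoch = math.ceil(num_samples / batch_size)
    total = iterations_per_epoch * num_epochs
    peak = int(total * (1.0 - end_percentage) / 2)  # iteration with the highest lr
    anneal_start = 2 * peak                         # final sharp reduction begins here
    return {
        "total_iterations": total,
        "peak_iteration": peak,
        "anneal_start": anneal_start,
        "anneal_epochs": (total - anneal_start) / iterations_per_epoch,
        "minimum_lr": max_lr / 1000.0,              # 1/1000th of max_lr
    }


# The 40-epoch MiniMobileNetV2 pass: CIFAR-10 (50k samples), max_lr = 0.02
phases = one_cycle_phases(50_000, 40, 128, 0.02)
print(phases)
```

With end_percentage=0.1, roughly the last 4 of the 40 epochs fall in the annealing phase, ending at a learning rate of 2e-5.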

Results

  • -1.7 is chosen as the maximum learning rate (in log10 space) for the OneCycleLR schedule. Since this is in log10 scale, we use 10^x to get the actual maximum learning rate. Here, 10^-1.7 ≈ 0.01995, so we round up to a maximum learning rate of 0.02.
  • 0.9 is chosen as the maximum momentum from the momentum plot. With cyclic momentum updates, a slightly lower value (0.85) is chosen as the minimum for faster training.
  • 3e-6 is chosen as the weight decay factor.

For the MiniMobileNetV2 model, 2 passes of OneCycleLR with SGD (40 epochs with max lr = 0.02 and 30 epochs with max lr = 0.005) obtained 90.33%. This may not seem like much, but the model has only 650k parameters; in comparison, the same model trained with Adam at an initial learning rate of 2e-3 did not reach the same score in over 100 epochs (89.14%).

Requirements

  • Keras 2.1.6+
  • Tensorflow (tested) / Theano / CNTK for the backend
  • matplotlib to visualize the plots.
