
Implements stochastic line search

Sls - Stochastic Line Search (NeurIPS 2019) [paper][video]

Train faster and better with the SLS optimizer. Follow the three steps below to get started.
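
For orientation, the stochastic line search picks each SGD step size by backtracking on the current minibatch $i_k$. It starts from a trial value and shrinks $\eta_k$ by a factor $\beta \in (0, 1)$ until the stochastic Armijo condition below holds, where $c > 0$ is a sufficient-decrease constant. This summary uses our own notation for the method in the paper cited at the end of this page. See the paper for the exact rules on how the trial step size is reset between iterations.

$$
w_{k+1} = w_k - \eta_k \nabla f_{i_k}(w_k),
\qquad
f_{i_k}\big(w_k - \eta_k \nabla f_{i_k}(w_k)\big) \le f_{i_k}(w_k) - c\,\eta_k \,\lVert \nabla f_{i_k}(w_k) \rVert^2
$$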

1. Installation

pip install git+https://github.com/IssamLaradji/sls.git

2. Usage

Use Sls in your training loop as follows:

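import torch
import sls

# model, X and y are placeholders: a torch.nn.Module, an input batch, and its targets
opt = sls.Sls(model.parameters())

for epoch in range(100):
    # create a loss closure; the line search re-evaluates it at trial step sizes,
    # so it only computes and returns the loss
    closure = lambda: torch.nn.MSELoss()(model(X), y)

    # update parameters
    opt.zero_grad()
    loss = opt.step(closure=closure)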
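
The `step(closure)` pattern exists because a line search must re-evaluate the minibatch loss at several trial step sizes. The pure-Python sketch below mimics that pattern on a one-parameter least-squares problem, so it runs without PyTorch. It is a hypothetical illustration, not the `sls` implementation. The class `ToyArmijoSGD`, its parameters (`eta_max`, `c`, `beta`, `max_backtracks`), and the convention that the closure takes the parameters and returns `(loss, grad)` are all assumptions. That last convention is needed because there is no autograd here.

```python
import random


class ToyArmijoSGD:
    """Hypothetical toy optimizer: SGD with stochastic Armijo backtracking.

    Not the sls package's code; parameters are plain floats, and the closure
    returns (loss, grad) because there is no autograd in this sketch.
    """

    def __init__(self, params, eta_max=1.0, c=0.5, beta=0.9, max_backtracks=100):
        self.params = list(params)   # list of floats (stand-in for tensors)
        self.eta_max = eta_max       # trial step size at the start of every step
        self.c = c                   # Armijo sufficient-decrease constant
        self.beta = beta             # backtracking shrink factor in (0, 1)
        self.max_backtracks = max_backtracks
        self.last_step_size = None

    def step(self, closure):
        # Loss and gradient on the current minibatch at the current parameters.
        loss, grad = closure(self.params)
        grad_sq = sum(g * g for g in grad)
        eta = self.eta_max
        for _ in range(self.max_backtracks):
            trial = [p - eta * g for p, g in zip(self.params, grad)]
            trial_loss, _ = closure(trial)  # re-evaluate the SAME minibatch
            if trial_loss <= loss - self.c * eta * grad_sq:
                break                       # Armijo condition satisfied
            eta *= self.beta                # otherwise shrink the step size
        else:
            # No trial step satisfied the condition: leave parameters unchanged.
            eta, trial = 0.0, self.params
        self.params = trial
        self.last_step_size = eta
        return loss


def make_closure(batch):
    # Loss 0.5 * mean((w*x - y)^2) and its gradient for one minibatch.
    def closure(params):
        w = params[0]
        n = len(batch)
        loss = 0.5 * sum((w * x - y) ** 2 for x, y in batch) / n
        grad = [sum((w * x - y) * x for x, y in batch) / n]
        return loss, grad
    return closure


rng = random.Random(0)
data = [(x, 3.0 * x) for x in [rng.uniform(-1.0, 1.0) for _ in range(64)]]

opt = ToyArmijoSGD([0.0], eta_max=10.0, c=0.5, beta=0.5)
for step_idx in range(50):
    batch = rng.sample(data, 8)       # a fresh minibatch each step
    opt.step(make_closure(batch))     # the closure is bound to that minibatch

print(f"estimated w = {opt.params[0]:.6f}")  # should be very close to 3.0
```

The toy targets satisfy $y = 3x$ exactly, so every minibatch is minimized at $w = 3$. This is the interpolation setting the paper studies, and it lets the backtracked steps converge without a decaying step-size schedule.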

3. Experiments

Install the experiment requirements:

pip install -r requirements.txt

3.1 MNIST

python trainval.py -e mnist -sb ../results -d ../data -r 1

where -e is the experiment group, -sb is the results directory, and -d is the dataset directory.

3.2 CIFAR-100

python trainval.py -e cifar100 -sb ../results -d ../data -r 1

4. Results

4.1 Launch Jupyter by running the following in a terminal:

jupyter nbextension enable --py widgetsnbextension --sys-prefix
jupyter notebook

4.2 In a Jupyter cell, run the following script:

from haven import haven_jupyter as hj
from haven import haven_results as hr
from haven import haven_utils as hu

# path to where the experiments got saved
savedir_base = '../results'

# filter exps
filterby_list = [{'dataset':'cifar100', 'opt':{'c':0.5}}, 
                 {'dataset':'cifar100', 'opt':{'name':'adam'}}]
                 
# get experiments
rm = hr.ResultManager(savedir_base=savedir_base, 
                      filterby_list=filterby_list, 
                      verbose=0)
                      
# dashboard variables
legend_list = ['opt.name']
title_list = ['dataset', 'model']
y_metrics = ['train_loss', 'val_acc']

# launch dashboard
hj.get_dashboard(rm, vars(), wide_display=True)
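
The `filterby_list` above reads as a list of alternatives. The sketch below shows the matching rule we assume it follows: an experiment is kept if it matches any one filter dict, and a nested dict such as `'opt'` matches when its listed keys agree. This is an assumption written out in plain Python, not haven's code. The functions `matches` and `filter_exps` are hypothetical, and the experiment dictionaries (including the optimizer name `'sls'`) are made up for illustration.

```python
def matches(exp_dict, filt):
    """Hypothetical rule: every key in filt must agree, recursing into nested dicts."""
    for key, want in filt.items():
        if key not in exp_dict:
            return False
        have = exp_dict[key]
        if isinstance(want, dict):
            # Nested filters such as {'opt': {'c': 0.5}} match on a subset of keys.
            if not isinstance(have, dict) or not matches(have, want):
                return False
        elif have != want:
            return False
    return True


def filter_exps(exp_list, filterby_list):
    """Keep an experiment if it matches at least one filter (an OR of ANDs)."""
    return [e for e in exp_list if any(matches(e, f) for f in filterby_list)]


# Made-up experiment configurations, for illustration only.
exp_list = [
    {'dataset': 'cifar100', 'model': 'resnet34', 'opt': {'name': 'sls', 'c': 0.5}},
    {'dataset': 'cifar100', 'model': 'resnet34', 'opt': {'name': 'sls', 'c': 0.1}},
    {'dataset': 'cifar100', 'model': 'resnet34', 'opt': {'name': 'adam'}},
    {'dataset': 'mnist', 'model': 'mlp', 'opt': {'name': 'adam'}},
]

filterby_list = [{'dataset': 'cifar100', 'opt': {'c': 0.5}},
                 {'dataset': 'cifar100', 'opt': {'name': 'adam'}}]

kept = filter_exps(exp_list, filterby_list)
for exp in kept:
    print(exp['dataset'], exp['opt'])
```

Under this rule, the CIFAR-100 run with `c = 0.5` and the CIFAR-100 Adam run are kept. The `c = 0.1` run and the MNIST run are dropped.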


Citation

@inproceedings{vaswani2019painless,
  title={Painless stochastic gradient: Interpolation, line-search, and convergence rates},
  author={Vaswani, Sharan and Mishkin, Aaron and Laradji, Issam and Schmidt, Mark and Gidel, Gauthier and Lacoste-Julien, Simon},
  booktitle={Advances in Neural Information Processing Systems},
  pages={3727--3740},
  year={2019}
}

This work is a collaboration between labs at MILA, Element AI, and UBC.
