• Stars: 359
  • Rank: 118,537 (Top 3%)
  • Language: Python
  • License: MIT License
  • Created: about 2 years ago
  • Updated: almost 2 years ago


Repository Details

Code for our NeurIPS 2022 paper

Gradient Descent: The Ultimate Optimizer

[image: gdtuo_turtles]

Abstract

Working with any gradient-based machine learning algorithm involves the tedious task of tuning the optimizer's hyperparameters, such as the step size. Recent work has shown how the step size can itself be "learned" on-line by gradient descent, by manually deriving expressions for "hypergradients" ahead of time.

We show how to automatically compute hypergradients with a simple and elegant modification to backpropagation. This allows us to apply the method to other hyperparameters besides the step size, such as the momentum coefficient. We can even recursively apply the method to its own hyper-hyperparameters, and so on ad infinitum. As these towers of optimizers grow taller, they become less sensitive to the initial choice of hyperparameters. We present experiments validating this for MLPs, CNNs, and RNNs.
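For intuition, the hand-derived rule alluded to above (for plain SGD with step size alpha) nudges alpha by the dot product of consecutive gradients. The sketch below is not this repository's code; it is a toy illustration on a quadratic objective, with the hyper-step size kappa and all constants chosen arbitrarily.

import torch

# Toy illustration of a manually derived hypergradient (not part of this package).
# For the update w <- w - alpha * grad, d f(w_t)/d alpha = -grad_t . grad_{t-1},
# so gradient descent on alpha gives: alpha <- alpha + kappa * grad_t . grad_{t-1}.
w = torch.tensor([5.0, -3.0])        # f(w) = 0.5 * ||w||^2, so grad f(w) = w
alpha, kappa = 0.01, 0.001           # step size and (arbitrary) hyper-step size
prev_grad = None

for _ in range(100):
    grad = w.clone()
    if prev_grad is not None:
        alpha = alpha + kappa * torch.dot(grad, prev_grad).item()
    w = w - alpha * grad
    prev_grad = grad

print(w, alpha)  # w approaches 0 while alpha adapts on-line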

This repository contains an implementation of the algorithm in our paper.

Citation

@article{chandra2022gradient,
    title = {Gradient Descent: The Ultimate Optimizer},
    author = {Chandra, Kartik and Xie, Audrey and Ragan-Kelley, Jonathan and Meijer, Erik},
    journal = {NeurIPS},
    year = {2022},
    url = {https://arxiv.org/abs/1909.13371}
}

Install

# install pytorch for your specific machine
...

# install our package
pip install gradient-descent-the-ultimate-optimizer

Example

First, build the MLP and initialize data loaders as you would normally in PyTorch.

import math
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

class MNIST_FullyConnected(nn.Module):
    """
    A fully-connected NN for the MNIST task. This is Optimizable but not itself
    an optimizer.
    """
    def __init__(self, num_inp, num_hid, num_out):
        super(MNIST_FullyConnected, self).__init__()
        self.layer1 = nn.Linear(num_inp, num_hid)
        self.layer2 = nn.Linear(num_hid, num_out)

    def initialize(self):
        nn.init.kaiming_uniform_(self.layer1.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.layer2.weight, a=math.sqrt(5))

    def forward(self, x):
        """Compute a prediction."""
        x = self.layer1(x)
        x = torch.tanh(x)
        x = self.layer2(x)
        x = torch.tanh(x)
        x = F.log_softmax(x, dim=1)
        return x

BATCH_SIZE = 256
EPOCHS = 5
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

mnist_train = torchvision.datasets.MNIST('./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST('./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=10000, shuffle=False)

model = MNIST_FullyConnected(28 * 28, 128, 10).to(DEVICE)

Next, import our package and initialize a stack of hyperoptimizers. This example uses the stack Adam/SGD.

from gradient_descent_the_ultimate_optimizer import gdtuo

optim = gdtuo.Adam(optimizer=gdtuo.SGD(1e-5))
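Reading the call above, the optimizer passed via optimizer= adjusts the hyperparameters of the optimizer it is passed to, so the tower can presumably be grown further in the same way. The sketch below assumes the SGD hyperoptimizer also accepts an optimizer= keyword; the step sizes are arbitrary illustrative values, and the rest of this example continues with the two-level optim above.

# A taller tower (sketch, assuming SGD also takes optimizer=): the innermost SGD
# tunes the middle SGD's hyperparameters, which in turn tunes Adam's.
optim_tall = gdtuo.Adam(optimizer=gdtuo.SGD(1e-5, optimizer=gdtuo.SGD(1e-8)))

As the abstract notes, taller stacks tend to be less sensitive to the initial choice of these values.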

gdtuo.ModuleWrapper allows any nn.Module to be optimized by hyperoptimizers.

mw = gdtuo.ModuleWrapper(model, optimizer=optim)
mw.initialize()
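For contrast, the object that mw replaces in a conventional setup would be a fixed-hyperparameter PyTorch optimizer, for example:

# Conventional PyTorch alternative (fixed hyperparameters, no hypergradients);
# shown only for comparison, not needed when using gdtuo.
baseline_opt = torch.optim.SGD(model.parameters(), lr=1e-5)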

Lastly, use mw instead of a PyTorch optimizer to optimize the model. The train loop is nearly identical to what you would typically implement in PyTorch (differences are marked by comments).

for i in range(1, EPOCHS+1):
    running_loss = 0.0
    for j, (features_, labels_) in enumerate(dl_train):
        mw.begin() # call this before each step, enables gradient tracking on desired params
        features, labels = torch.reshape(features_, (-1, 28 * 28)).to(DEVICE), labels_.to(DEVICE)
        pred = mw.forward(features)
        loss = F.nll_loss(pred, labels)
        mw.zero_grad()
        loss.backward(create_graph=True) # important! use create_graph=True
        mw.step()
        running_loss += loss.item() * features_.size(0)
    train_loss = running_loss / len(dl_train.dataset)
    print("EPOCH: {}, TRAIN LOSS: {}".format(i, train_loss))

Note that on the first step of the training loop, PyTorch will emit the following warning:

UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak.

This is normal and to be expected.
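If the warning clutters your logs once you have confirmed the setup is correct, it can be filtered with Python's standard warnings module; a minimal sketch (the pattern below matches the start of the message quoted above):

import warnings

# Silence only this specific create_graph warning; other warnings still surface.
warnings.filterwarnings("ignore", message=r"Using backward\(\) with create_graph=True")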

More Repositories

1. nearley: 📜🔜🌲 Simple, fast, powerful parser toolkit for JavaScript. (JavaScript, 3,584 stars)
2. recreational-rosette: Some fun examples of solving problems with symbolic execution (Racket, 109 stars)
3. tower-of-power: What is hip? Tell me, tell me (if you think you know) (Python, 108 stars)
4. memo: A language for mental models (Jupyter Notebook, 28 stars)
5. designing-perceptual-puzzles-by-differentiating-probabilistic-programs: Supplementary materials for our SIGGRAPH 2022 paper (Jupyter Notebook, 27 stars)
6. shabdle: Shabdle is Wordle in Hindi (Mathematica, 24 stars)
7. neural-ambigrams: Generating digits that are secretly *other* digits doing handstands (HTML, 19 stars)
8. acting-as-inverse-inverse-planning: Code for our SIGGRAPH 2023 paper, "Acting as Inverse Inverse Planning" (Jupyter Notebook, 16 stars)
9. chaos-game-fractal-foliage: Code for "Learning to Play the Chaos Game: Dreaming of fractal foliage by differentiating iterated function systems" (Jupyter Notebook, 14 stars)
10. Snapin8r: A Scratch 2.0->Snap! converter (JavaScript, 13 stars)
11. eddie: An automatic first-order theorem prover in Haskell (Haskell, 12 stars)
12. emo.7: The man page for emoticons. (Groff, 9 stars)
13. torchsaber: Elegant dimensions for a more civilized age (Python, 8 stars)
14. kesar: A Python library for quickly building human subject studies (Python, 8 stars)
15. shock: It's simple… it's static… it's shock! (JavaScript, 8 stars)
16. jigsaw: An Escher-esque jigsaw puzzle generator (Standard ML, 8 stars)
17. prufrock: A literary proof assistant built on the affine calculus of inductive constraints (Standard ML, 8 stars)
18. gifblocks: Make animated GIFs from your Scratch projects! (HTML, 8 stars)
19. haskell-lambda-calculus: A simple lambda calculus interpreter in Haskell. (Haskell, 7 stars)
20. hamelin: Revenge of the The Py'd Piper (Python, 7 stars)
21. lowtex: Low-tech text processing (JavaScript, 7 stars)
22. baobab: Interactive Fiction with Racket and love (JavaScript, 6 stars)
23. voxel: Yet another raytracer, because we don't have enough of those already (Scheme, 5 stars)
24. sublime-nearley: A nearley syntax plugin for TextMate/Sublime Text (Makefile, 5 stars)
25. softraxterizer: A small softras implemented in JAX (Python, 4 stars)
26. hell: Because http://xkcd.com/724/ (JavaScript, 4 stars)
27. hootow-hyperlapse: Using classic computer vision algorithms to align hundreds of images of Hoover Tower (Jupyter Notebook, 4 stars)
28. quackoverflow: Meow. (JavaScript, 3 stars)
29. boxcars: Lively box-and-pointer diagrams for Racket. (Racket, 3 stars)
30. bellhop: A simple, informative bell schedule app. (JavaScript, 3 stars)
31. optimally-framing-roger-rabbit: accelerating accelerators with differentiable kD-trees (Python, 3 stars)
32. englipsum: Usable loremtext (JavaScript, 2 stars)
33. Legible: Regexes for Humans (JavaScript, 2 stars)
34. Human: Humanist documentation (Python, 2 stars)
35. ketchup: A temporal RSS proxy (JavaScript, 2 stars)
36. turtlegrad: Bidirectional programming by gradient descent (JavaScript, 2 stars)
37. scotty: Specify Characters On a TTY - Readline for binary input (C, 2 stars)
38. watchat (Racket, 2 stars)
39. jeopardy-wagering-under-uncertainty (Python, 1 star)
40. indexme: Quick, portable, versatile directory listing generator (Shell, 1 star)
41. jokebot: A simple demo IRC bot in Python (Python, 1 star)
42. poison-ivy: Create a graphical representation of dependency relationships between Ivy conjectures. (Python, 1 star)
43. generative-adversarial-web-development: Substance… and style! (HTML, 1 star)
44. lagrange-climbs-a-hill: Interpolating Lagrangian mechanics by AD and gradient descent (Python, 1 star)
45. bentley-blizzard-blossoms: A frosty MNIST alternative :) (Jupyter Notebook, 1 star)