• Stars: 343
• Rank: 123,371 (Top 3%)
• Language: Python
• License: MIT License
• Created: almost 4 years ago
• Updated: 10 months ago

Repository Details

PyTorch implementation of Normalizer-Free Networks (NFNets) and Adaptive Gradient Clipping (AGC) for SGD. Find an explanation at tourdeml.github.io/blog/.

Paper: https://arxiv.org/abs/2102.06171

Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets

Blog post: https://tourdeml.github.io/blog/posts/2021-03-31-adaptive-gradient-clipping/. Feel free to subscribe to the newsletter, and leave a comment if you have anything to add or suggest.

Do star this repository if it helps your work, and don't forget to cite if you use this code in your research!

Installation

Install from PyPI:

pip3 install nfnets-pytorch

or install the latest code using:

pip3 install git+https://github.com/vballoli/nfnets-pytorch
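
Either way, the package should import cleanly afterwards (a quick sanity check; the import name matches the usage examples below):

python3 -c "import nfnets"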

Usage

WSConv2d

Use WSConv1d, WSConv2d, ScaledStdConv2d (from timm), and WSConvTranspose2d like any other torch.nn.Conv2d or torch.nn.ConvTranspose2d module.

from torch import nn
from nfnets import WSConv2d, WSConvTranspose2d, ScaledStdConv2d

# Drop-in replacements: same constructor signature as their torch.nn counterparts.
conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)  # weight-standardized Conv2d

conv_t = nn.ConvTranspose2d(3, 6, 3)
w_conv_t = WSConvTranspose2d(3, 6, 3)  # weight-standardized ConvTranspose2d
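
As a quick sanity check that these are drop-in replacements, here is a short sketch comparing output shapes (the input size is arbitrary):

import torch
from torch import nn
from nfnets import WSConv2d

x = torch.randn(1, 3, 32, 32)    # NCHW input
out_ref = nn.Conv2d(3, 6, 3)(x)  # plain convolution
out_ws = WSConv2d(3, 6, 3)(x)    # weight-standardized convolution
assert out_ref.shape == out_ws.shape == (1, 6, 30, 30)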

Generic AGC (recommended)

import torch
from torch import nn
from torchvision.models import resnet18

from nfnets import WSConv2d
from nfnets.agc import AGC  # Needs testing

conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

# Wrap any torch.optim optimizer in AGC (avoid rebinding the optim module name).
optimizer = torch.optim.SGD(conv.parameters(), lr=1e-3)
optimizer_agc = AGC(conv.parameters(), optimizer)  # Needs testing

# Ignore the final fc layer of a model while applying AGC.
model = resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = AGC(model.parameters(), optimizer, model=model, ignore_agc=['fc'])
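
Since AGC wraps an existing optimizer, a training step looks the same as with any torch.optim optimizer. A minimal sketch, assuming AGC forwards zero_grad and step to the wrapped optimizer (the batch and labels are dummy placeholders):

import torch
from torch import nn
from torchvision.models import resnet18
from nfnets.agc import AGC

model = resnet18()
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = AGC(model.parameters(), base_optimizer, model=model, ignore_agc=['fc'])

x = torch.randn(8, 3, 224, 224)   # dummy batch
y = torch.randint(0, 1000, (8,))  # dummy labels
loss = nn.CrossEntropyLoss()(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # gradients are adaptively clipped before the SGD update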

SGD - Adaptive Gradient Clipping

Similarly, use SGD_AGC like torch.optim.SGD:

# The generic AGC is preferable, since the paper recommends not applying AGC to the last fc layer.
import torch
from torch import nn
from nfnets import WSConv2d, SGD_AGC

conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

optimizer = torch.optim.SGD(conv.parameters(), lr=1e-3)  # plain SGD
optimizer_agc = SGD_AGC(conv.parameters(), lr=1e-3)      # SGD with AGC built in
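
For intuition, the AGC rule from the paper rescales a parameter's gradient G whenever the ratio of its norm to the (eps-guarded) parameter norm exceeds a threshold. The sketch below illustrates that rule in plain PyTorch, per tensor for brevity (the paper applies it unit-wise, per output channel); it is not this repo's exact implementation:

import torch

def agc_clip_(param, clip=0.01, eps=1e-3):
    # Illustrative AGC: shrink param.grad in place when
    # ||G|| / max(||W||, eps) exceeds the clip threshold.
    if param.grad is None:
        return
    w_norm = param.detach().norm().clamp(min=eps)
    g_norm = param.grad.detach().norm()
    max_norm = clip * w_norm
    if g_norm > max_norm:
        param.grad.detach().mul_(max_norm / g_norm)

w = torch.nn.Parameter(torch.randn(6, 3, 3, 3))
w.grad = 10.0 * torch.randn_like(w)  # deliberately oversized gradient
agc_clip_(w)  # call between loss.backward() and optimizer.step()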

Using it within any non-residual PyTorch model

replace_conv replaces every convolution in your (non-residual) model with the given convolution class, and replaces every batch-norm layer with an identity. While the identity is not ideal, it shouldn't cause a major difference in latency.

Note that, as per the paper, replace_conv is only valid for non-residual models (VGG, MobileNetV1, etc.). See the above-mentioned blog post for more details.

from torch import nn
from torchvision.models import vgg16

from nfnets import replace_conv, WSConv2d, ScaledStdConv2d

model = vgg16()
replace_conv(model, WSConv2d)  # this repo's original implementation
# or, on a fresh model:
# replace_conv(model, ScaledStdConv2d)  # from timm

# Any nn.Conv2d subclass works as the replacement:
# class YourCustomClass(nn.Conv2d):
#     ...
# replace_conv(model, YourCustomClass)

Docs

Find the docs at readthedocs

Cite Original Work

To cite the original paper, use:

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}

Cite this repository

To cite this repository, use:

@misc{nfnets2021pytorch,
  author = {Vaibhav Balloli},
  title = {A PyTorch implementation of NFNets and Adaptive Gradient Clipping},
  year = {2021},
  howpublished = {\url{https://github.com/vballoli/nfnets-pytorch}}
}

More Repositories

1. vit-flax: Implementation of Vision Transformers in Flax (Python, 16 stars)
2. ENAS-CNN: Implementation of ENAS for CNNs on CIFAR-10 (Python, 10 stars)
3. robust-representation-random-convolutions: Robust and Generalizable Visual Representation Learning via Random Convolutions in PyTorch (Python, 8 stars)
4. sam: Generic Sharpness-Aware Minimization wrapper in PyTorch (Python, 4 stars)
5. abel-pytorch: ABEL implemented in PyTorch (Python, 3 stars)
6. compression-framework-pytorch: Reproducing "An End-to-End Compression Framework Based on Convolutional Neural Network" in PyTorch (Python, 3 stars)
7. movie-recommender-system: CS F469 Information Retrieval course project, a movie recommender system (Python, 2 stars)
8. flipkart-search: CS F469 Information Retrieval assignment, a domain-specific search engine for mobiles sold on Flipkart, including specs and reviews (Python, 2 stars)
9. arithmetic-intensity: Arithmetic intensity calculator for PyTorch models (Python, 1 star)
10. PPO.jl: Proximal Policy Optimization implemented in Julia using Flux.jl (Julia, 1 star)
11. swift-meta: Meta-learning library in Swift for TensorFlow (Swift, 1 star)
12. mips-processor: Un-pipelined partial MIPS processor implementation in Verilog (Verilog, 1 star)
13. bezier-curve: Bezier curve implementation in OpenGL (C, 1 star)
14. Ripple: A PyTorch library for risk-aware ML (Python, 1 star)
15. GradCAM.jl: Julia implementation of GradCAM, visualizing what CNNs see (Julia, 1 star)
16. offlax: Offline reinforcement learning framework in JAX (Python, 1 star)
17. blog-old: Being Codest, a personal blog deployed on GitHub Pages, powered by Ghost and Buster (HTML, 1 star)
18. Swift-Jupyter-Starter: Essentials for prototyping in a Jupyter notebook in a Swift environment (Jupyter Notebook, 1 star)
19. vballoli.github.io: Personal academic website (HTML, 1 star)
20. Project592: EECS 592 course project code (Jupyter Notebook, 1 star)
21. decentralised-chat: Solidity-based decentralized chat app built using Node.js (JavaScript, 1 star)