• This repository has been archived on 04/Jan/2023
• Stars: 601
• Rank: 74,537 (Top 2%)
• Language: Python
• Created over 6 years ago
• Updated almost 3 years ago

Repository Details

PyTorch Boilerplate For Research

Model Statistics

Number of Parameters

num_params = sum(p.numel() for p in model.parameters()) # Total parameters
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)  # Trainable parameters

Number of FLOPS

[...]
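
One common way to estimate FLOPs is to count multiply-accumulate operations (MACs) with forward hooks. Below is a minimal sketch covering only Conv2d and Linear layers; count_macs and input_size are illustrative names and not part of this repository:

# rough MAC count via forward hooks (illustrative sketch)
import torch
import torch.nn as nn

def count_macs(model, input_size=(1, 3, 224, 224)):
    macs = []

    def hook(module, inputs, output):
        if isinstance(module, nn.Conv2d):
            # every output element costs (in_channels / groups) * kH * kW multiply-adds
            kernel_ops = module.in_channels // module.groups * module.kernel_size[0] * module.kernel_size[1]
            macs.append(output.numel() * kernel_ops)
        elif isinstance(module, nn.Linear):
            macs.append(output.numel() * module.in_features)

    handles = [m.register_forward_hook(hook) for m in model.modules()
               if isinstance(m, (nn.Conv2d, nn.Linear))]
    with torch.no_grad():
        model(torch.zeros(*input_size))
    for h in handles:
        h.remove()
    return sum(macs)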

Weight Initialization

PyTorch layers are initialized by default in their respective reset_parameters() method. For example:

  • nn.Linear
    • weight and bias: uniform distribution [-limit, +limit] where limit is 1. / sqrt(fan_in) and fan_in is the number of input units in the weight tensor.
  • nn.Conv2d
    • weight and bias: uniform distribution [-limit, +limit] where limit is 1. / sqrt(fan_in) and fan_in is the number of input units in the weight tensor.

With this implementation, the variance of the weights is Var(W) = 1 / (3 * fan_in), which isn't the best initialization strategy out there.

Note that PyTorch provides convenience functions for many of these initializations in torch.nn.init. The fan-in and fan-out of a weight tensor are computed with the helper _calculate_fan_in_and_fan_out(), and nn.init.calculate_gain() returns a factor that scales the standard deviation to suit a particular activation.
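
A minimal sketch, assuming the (private) helper torch.nn.init._calculate_fan_in_and_fan_out, that re-applies the default uniform scheme described above by hand:

# manually re-apply the default uniform init described above (illustrative)
import math
import torch.nn as nn
from torch.nn.init import _calculate_fan_in_and_fan_out

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        fan_in, _ = _calculate_fan_in_and_fan_out(m.weight)
        limit = 1. / math.sqrt(fan_in)
        nn.init.uniform_(m.weight, -limit, limit)
        if m.bias is not None:
            nn.init.uniform_(m.bias, -limit, limit)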

Xavier Initialization

This initialization is general-purpose and meant to "work" pretty well for any activation in practice.

# default xavier init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)

You can tailor this initialization to your specific activation by passing gain=nn.init.calculate_gain(act):

# xavier init scaled for relu
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))

He et al. Initialization

This is a similarly derived initialization tailored specifically to ReLU activations, whose outputs do not have zero mean.

# he initialization
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')

For mode='fan_in', the variance of the activations is preserved in the forward pass, while for mode='fan_out', the variance of the gradients is preserved in the backward pass.
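
A minimal variant of the loop above, if you instead want to preserve the gradient variance in the backward pass:

# he initialization preserving backward-pass variance (sketch)
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')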

SELU Initialization

Again, this initialization is derived specifically for the SELU activation function. The authors use the fan_in strategy and mention that there is no significant difference between sampling from a Gaussian, a truncated Gaussian, or a uniform distribution.

# selu init
import math

for m in model.modules():
    if isinstance(m, nn.Conv2d):
        fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
        nn.init.normal_(m.weight, 0, math.sqrt(1. / fan_in))
    elif isinstance(m, nn.Linear):
        fan_in = m.in_features
        nn.init.normal_(m.weight, 0, math.sqrt(1. / fan_in))

Orthogonal Initialization

Orthogonality is a desirable property for NN weight matrices in part because multiplication by an orthogonal matrix is norm-preserving: it rotates the input but cannot scale or shear it. This property is valuable in deep or recurrent networks, where repeated matrix multiplication can cause signals to vanish or explode.

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal_(m.weight)

Batch Norm Initialization

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

Weight Regularization

L2 Regularization

Heavily penalizes peaky weight vectors and encourages diffuse weight vectors. It has the appealing property of encouraging the network to use all of its inputs a little rather than some of its inputs a lot.

with torch.enable_grad():
    reg = 1e-6
    l2_loss = torch.zeros(1)
    # accumulate the squared-weight penalty over all non-bias parameters
    for name, param in model.named_parameters():
        if 'bias' not in name:
            l2_loss = l2_loss + 0.5 * reg * torch.sum(torch.pow(param, 2))
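
The accumulated penalty is then added to the task loss before calling backward(). A hypothetical usage sketch, where criterion, output, and target stand in for your own loss function and batch:

# add the penalty to the task loss (illustrative)
loss = criterion(output, target) + l2_loss
loss.backward()

In practice, a similar effect is often obtained by passing weight_decay to the optimizer instead of building the penalty by hand.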

L1 Regularization

Encourages sparsity, meaning we encourage the network to select its most useful inputs/features rather than using all of them.

with torch.enable_grad():
    reg = 1e-6
    l1_loss = torch.zeros(1)
    for name, param in model.named_parameters():
        if 'bias' not in name:
            l1_loss = l1_loss + reg * torch.sum(torch.abs(param))

Orthogonal Regularization

Improves gradient flow by penalizing the deviation of W W^T from the identity, i.e. by keeping the weight matrices close to orthogonal.

with torch.enable_grad():
    reg = 1e-6
    orth_loss = torch.zeros(1)
    for name, param in model.named_parameters():
        if 'bias' not in name:
            param_flat = param.view(param.shape[0], -1)
            sym = torch.mm(param_flat, torch.t(param_flat))
            sym -= torch.eye(param_flat.shape[0])
            orth_loss = orth_loss + (reg * sym.abs().sum())

Max Norm Constraint

If the L2 norm L of a hidden unit's weight vector ever exceeds a chosen maximum value c, rescale the weight vector by c/L. Enforce this immediately after each weight update, or after every X gradient updates.

This constraint is another form of regularization. While L2 penalizes high weights through the loss function, "max norm" acts directly on the weights. L2 exerts a constant pressure to move the weights toward zero, which can throw away useful information when the loss function provides no incentive for the weights to stay far from zero. "Max norm", on the other hand, never drives the weights toward zero: as long as the norm is below the constraint value, the constraint has no effect.

def max_norm(model, max_val=3, eps=1e-8):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'bias' not in name:
                norm = param.norm(2, dim=0, keepdim=True)
                desired = torch.clamp(norm, 0, max_val)
                # rescale in place so the constraint actually modifies the weights
                param *= desired / (eps + norm)
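
A hypothetical usage inside a training loop, where loss and optimizer stand in for your own setup, applying the constraint right after each parameter update:

# re-project the weights after every optimizer step (illustrative)
loss.backward()
optimizer.step()
max_norm(model, max_val=3)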

Batch Normalization

[...]

Dropout

[...]

Optimization Misc.

Correct Validation Strategies

[...]

References

  • Thanks to Zijun Deng for inspiring the code for the segmentation metrics.

More Repositories

  • spatial-transformer-network: A Tensorflow implementation of Spatial Transformer Networks. (Python, 978 stars)
  • recurrent-visual-attention: A PyTorch Implementation of "Recurrent Models of Visual Attention" (Python, 468 stars)
  • torchnca: A PyTorch implementation of Neighbourhood Components Analysis. (Python, 400 stars)
  • mjctrl: Minimal, clean, single-file implementations of common robotics controllers in MuJoCo. (Python, 204 stars)
  • mink: Python inverse kinematics based on MuJoCo (Python, 184 stars)
  • obj2mjcf: A CLI for processing composite Wavefront OBJ files for use in MuJoCo. (Python, 155 stars)
  • torchkit: Research boilerplate for PyTorch. (Python, 150 stars)
  • mujoco_scanned_objects: MuJoCo Models for Google's Scanned Objects Dataset (145 stars)
  • clip_playground: An ever-growing playground of notebooks showcasing CLIP's impressive zero-shot capabilities (Jupyter Notebook, 144 stars)
  • tsne-viz: Python Wrapper for t-SNE Visualization (Python, 126 stars)
  • ibc: A PyTorch implementation of Implicit Behavioral Cloning (Python, 93 stars)
  • form2fit: [ICRA 2020] Train generalizable policies for kit assembly with self-supervised dense correspondence learning. (Python, 82 stars)
  • blog-code: My blog's code repository. (Jupyter Notebook, 76 stars)
  • learn-linalg: Learning some numerical linear algebra. (Python, 70 stars)
  • dexterity: Software and tasks for dexterous multi-fingered hand manipulation, powered by MuJoCo (Python, 59 stars)
  • x-magical: [CoRL 2021] A robotics benchmark for cross-embodiment imitation. (Python, 58 stars)
  • mjc_viewer: A browser-based 3D viewer for MuJoCo (Python, 55 stars)
  • torchsdf-fusion: Benchmarking PyTorch variants of TSDF fusion. (Python, 47 stars)
  • robopianist-rl: RL code for training piano-playing policies for RoboPianist. (Python, 42 stars)
  • mujoco_tips_and_tricks (32 stars)
  • walle: My robotics research toolkit. (Python, 22 stars)
  • mujoco_cube: A 3x3x3 puzzle cube model for MuJoCo. (Python, 21 stars)
  • coffee: Infrastructure for PyBullet research (Python, 20 stars)
  • robopianist-demo (C, 20 stars)
  • learn-ransac: Learning about RANSAC. (Python, 19 stars)
  • dm_env_wrappers: Standalone library of frequently-used wrappers for dm_env environments. (Python, 18 stars)
  • root-locus: Python implementation of the Root Locus method. (Python, 17 stars)
  • nanorl: A tiny reinforcement learning codebase for continuous control, built on top of JAX. (Python, 12 stars)
  • software: My open-source software contributions. (9 stars)
  • kinetics: Python script to mine the Kinetics dataset. (Python, 6 stars)
  • cloneformer: BC with Transformers (Python, 5 stars)
  • mujoco_utils (Python, 5 stars)
  • learn-blur: Learning about various image blurring techniques. (Python, 3 stars)
  • pymenagerie: Composer classes for MuJoCo Menagerie models. (Python, 3 stars)
  • learn-volumetric-fusion: Learning about volumetric fusion. (Python, 2 stars)