  • Stars: 194
  • Rank: 200,219 (Top 4%)
  • Language: Python
  • License: MIT License
  • Created over 6 years ago
  • Updated over 4 years ago

Repository Details

Batch normalization fusion for PyTorch

Batch Norm Fusion for PyTorch

About

In this repository, we present a simple implementation of batchnorm fusion for the most popular CNN architectures in PyTorch. The package aims to speed up inference at test time, with an expected speedup of 30%. In the future

How it works

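Both convolution and batch normalization (with its statistics frozen at inference time) are linear, strictly speaking affine, operations on the data point x, so their composition can be written in terms of matrix multiplications: $T_{bn} S_{bn} \mathrm{Conv}_W(x)$. Here we first apply the convolution to the data, then scale the result with $S_{bn}$ and finally shift it with $T_{bn}$, both built from the batchnorm parameters (the learned scale and shift plus the running statistics). Since the composition is again a single affine map, it can be folded into one convolution with modified weights and bias.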
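To make the decomposition concrete, here is the standard derivation of the fused parameters. It assumes a convolution with weights $W$ and bias $b$ (take $b = 0$ if the layer has no bias) followed by a BatchNorm layer in inference mode with running mean $\mu$, running variance $\sigma^2$, affine parameters $\gamma, \beta$ and constant $\epsilon$; these symbols are our notation, not taken from the repository, and $c$ indexes output channels.

```latex
\begin{aligned}
s_c &= \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}, \\
\mathrm{BN}\big(\mathrm{Conv}_W(x)\big)_c
  &= s_c \big( (W_c \ast x) + b_c - \mu_c \big) + \beta_c \\
  &= (s_c W_c) \ast x + \big( s_c (b_c - \mu_c) + \beta_c \big), \\
\hat{W}_c &= s_c W_c, \qquad \hat{b}_c = s_c (b_c - \mu_c) + \beta_c .
\end{aligned}
```

In the notation above, $S_{bn} = \mathrm{diag}(s)$ and $T_{bn}$ adds $\beta - s \odot \mu$, so $T_{bn} S_{bn} \mathrm{Conv}_W(x)$ equals a single convolution with weights $\hat{W}$ and bias $\hat{b}$. The identity only holds while BatchNorm uses its running statistics, i.e. in eval mode.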

Supported architectures

We support any architecture in which Conv and BN layers are combined in a Sequential module. If you want to optimize your own networks with this tool, just follow this design. For convenience, we wrapped the VGG, ResNet and SeNet families to demonstrate how your models can be converted into this format.

  • VGG from torchvision.
  • ResNet family from torchvision.
  • SeNet family from pretrainedmodels.
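The Sequential-module requirement can be illustrated with a minimal sketch of the idea. It is not the code behind `fuse_bn_recursively`; the helpers `fuse_conv_bn` and `fuse_bn_in_sequential` are hypothetical names, and the sketch assumes `Conv2d` followed directly by `BatchNorm2d` with tracked running statistics.

```python
import torch
import torch.nn as nn


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a new Conv2d that computes bn(conv(x)) with bn in eval mode.

    Hypothetical helper; assumes bn.track_running_stats is True.
    """
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,  # the fused layer always needs a bias term
        padding_mode=conv.padding_mode,
    ).to(device=conv.weight.device, dtype=conv.weight.dtype)

    with torch.no_grad():
        mean, var = bn.running_mean, bn.running_var
        # gamma/beta default to 1/0 when the BN layer has no affine parameters
        gamma = bn.weight if bn.affine else torch.ones_like(var)
        beta = bn.bias if bn.affine else torch.zeros_like(mean)
        bias = conv.bias if conv.bias is not None else torch.zeros_like(mean)

        scale = gamma / torch.sqrt(var + bn.eps)  # s_c, one value per output channel
        # W_hat_c = s_c * W_c, broadcast over (in_channels / groups, kH, kW)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        # b_hat_c = s_c * (b_c - mu_c) + beta_c
        fused.bias.copy_(scale * (bias - mean) + beta)
    return fused


def fuse_bn_in_sequential(module: nn.Module) -> nn.Module:
    """Fold every BatchNorm2d that directly follows a Conv2d inside an
    nn.Sequential (at any nesting depth) into that Conv2d, in place."""
    for child in module.children():
        fuse_bn_in_sequential(child)  # handle nested containers first

    if isinstance(module, nn.Sequential):
        for i in range(len(module) - 1):
            if isinstance(module[i], nn.Conv2d) and isinstance(module[i + 1], nn.BatchNorm2d):
                module[i] = fuse_conv_bn(module[i], module[i + 1])
                # Identity keeps the indices of the later layers unchanged
                module[i + 1] = nn.Identity()
    return module
```

Only adjacent pairs inside `nn.Sequential` are found, which is why layers stored as separate attributes (as in a residual block) have to be regrouped first. The fused model reproduces eval-mode behavior only, so it is meant for inference, not for further training with BatchNorm semantics.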

How to use

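import torchvision.models as models
from bn_fusion import fuse_bn_recursively

# Load a pretrained VGG-16 with batch normalization.
# (torchvision 0.13+ deprecates pretrained=True in favor of the weights= argument.)
net = getattr(models, 'vgg16_bn')(pretrained=True)
# Fold the BatchNorm layers into the preceding convolutions.
net = fuse_bn_recursively(net)
net.eval()
# Make inference with the converted model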
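Since tests are still on the TODO list, one way to sanity-check a conversion is to compare the original and converted networks on the same input. The sketch below uses a hypothetical helper `max_abs_difference` that is not part of `bn_fusion`; because we have not checked whether `fuse_bn_recursively` modifies the model in place, keeping a `copy.deepcopy` of the original before converting is the safe option.

```python
import torch
import torch.nn as nn


def max_abs_difference(reference: nn.Module, converted: nn.Module,
                       x: torch.Tensor) -> float:
    """Largest element-wise |reference(x) - converted(x)|, both in eval mode."""
    # In train mode BatchNorm normalizes with batch statistics, while fusion
    # bakes in the running statistics, so the comparison only makes sense in eval mode.
    reference.eval()
    converted.eval()
    with torch.no_grad():
        return (reference(x) - converted(x)).abs().max().item()
```

For a correct fusion, `max_abs_difference(original, net, torch.randn(1, 3, 224, 224))` should be close to zero, on the order of float32 rounding error.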

TODO

  • Tests.
  • Performance benchmarks.

Acknowledgements

Thanks to @ZFTurbo for the idea, the discussions and his Keras implementation.

More Repositories

  1. solt (Jupyter Notebook, 263 stars): Streaming over lightweight data transformations
  2. OAProgression (Python, 75 stars): Multimodal Machine Learning-based Knee Osteoarthritis Progression Prediction from Plain Radiographs and Clinical Data
  3. DeepKnee (Python, 64 stars): Code for the paper "Automatic Knee Osteoarthritis Diagnosis from Plain Radiographs: A Deep Learning-Based Approach"
  4. KNEEL (Python, 43 stars): Hourglass Networks for Knee Anatomical Landmark Localization: PyTorch Implementation
  5. stambo (Jupyter Notebook, 24 stars): Statistical model comparison with bootstrap and beyond
  6. KneeOARSIGrading (Python, 15 stars): Grading individual knee osteoarthritis features using Deep Learning
  7. AdaTriplet (Python, 14 stars)
  8. CLIMATv2 (Python, 13 stars): CLIMATv2: Clinically-Inspired Multi-Agent Transformers for Disease Trajectory Forecasting from Multimodal Data
  9. KneeLocalizer (Python, 12 stars): Code for the paper "A novel method for automatic localization of joint area on knee plain radiographs" by A. Tiulpin et al.
  10. semixup (Python, 12 stars): Semixup: In- and Out-of-Manifold Regularization for Deep Semi-Supervised Knee Osteoarthritis Severity Grading from Plain Radiographs
  11. greedy_ensembles_training (Python, 11 stars): Greedy Bayesian Posterior Approximation with Deep Ensembles. A. Tiulpin and M. B. Blaschko. (2021)
  12. DeepWrist (Python, 10 stars): Deep Learning Pipeline for Wrist Fracture Detection
  13. LoG-VMamba (Python, 7 stars): LoG-VMamba: Local-Global Vision Mamba for Medical Image Segmentation
  14. CLIMAT (Python, 6 stars): CLIMAT: Clinically-Inspired Multi-Agent Transformers for Disease Trajectory Forecasting from Multi-modal Data
  15. OAProgressionMR (Python, 5 stars): Prediction of knee osteoarthritis progression from DESS MRI
  16. SAUNA (Python, 2 stars): SAUNA: Image-level Regression for Uncertainty-aware Retinal Image Segmentation
  17. PfirrmannGrading (Python, 1 star)
  18. SiNGR (Python, 1 star): SiNGR: Brain Tumor Segmentation via Signed Normalized Geodesic Transform Regression