  • Stars: 1,401
  • Rank: 33,554 (Top 0.7 %)
  • Language: Python
  • License: MIT License
  • Created: over 6 years ago
  • Updated: 7 months ago

Repository Details

Bayesian Convolutional Neural Network with Variational Inference based on Bayes by Backprop in PyTorch.

Python 3.7+ | PyTorch 1.3 | License: MIT | arXiv

We introduce Bayesian convolutional neural networks with variational inference, a variant of convolutional neural networks (CNNs) in which the intractable posterior probability distributions over weights are inferred by Bayes by Backprop. We demonstrate that our proposed variational inference method achieves performance comparable to frequentist inference in identical architectures on several datasets (MNIST, CIFAR10, CIFAR100), as described in the paper.


Filter weight distributions in a Bayesian Vs Frequentist approach

Distribution over weights in a CNN's filter.


Fully Bayesian perspective of an entire CNN

Distributions must be over weights in convolutional layers and weights in fully-connected layers.


Layer types

This repository contains two types of Bayesian layer implementations:

  • BBB (Bayes by Backprop):
    Based on this paper. This layer samples all the weights individually and then combines them with the inputs to compute a sample from the activations.

  • BBB_LRT (Bayes by Backprop w/ Local Reparametrization Trick):
    This layer combines Bayes by Backprop with local reparametrization trick from this paper. This trick makes it possible to directly sample from the distribution over activations.
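
The difference between the two layer types can be illustrated on a single fully-connected layer. The following is a minimal NumPy sketch, assuming a factorized Gaussian posterior with a mean `mu` and standard deviation `sigma` per weight. The function names are hypothetical and are not part of this repository's API.

```python
import numpy as np

def sample_activations_bbb(x, mu, sigma, rng):
    """BBB-style: sample every weight, then apply the layer."""
    eps = rng.standard_normal(mu.shape)
    w = mu + sigma * eps                  # w ~ N(mu, sigma^2), elementwise
    return x @ w.T                        # one weight draw shared by the batch

def sample_activations_lrt(x, mu, sigma, rng):
    """LRT-style: sample the activations from their implied Gaussian."""
    act_mu = x @ mu.T                     # mean of each activation
    act_var = (x ** 2) @ (sigma ** 2).T   # variance of each activation
    eps = rng.standard_normal(act_mu.shape)
    return act_mu + np.sqrt(act_var) * eps  # independent noise per example

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))           # batch of 4 inputs with 8 features
mu = rng.standard_normal((3, 8)) * 0.1    # 3 output units
sigma = np.full((3, 8), 0.05)

out_bbb = sample_activations_bbb(x, mu, sigma, rng)
out_lrt = sample_activations_lrt(x, mu, sigma, rng)
```

Both functions produce activations with the same per-example mean and variance. The second one draws noise in activation space, so it needs one random number per output unit and example rather than one per weight.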


How to make a custom Bayesian network?

To make a custom Bayesian network, inherit from layers.misc.ModuleWrapper instead of torch.nn.Module, and use BBBLinear and BBBConv2d from either of the given layer types (BBB or BBB_LRT) instead of torch.nn.Linear and torch.nn.Conv2d. There is no need to define a forward method; ModuleWrapper takes care of it automatically.

For example:

import torch.nn as nn

class Net(nn.Module):

  def __init__(self):
    super().__init__()
    # nn.Conv2d takes `stride`, not `strides`
    self.conv = nn.Conv2d(3, 16, 5, stride=2)
    self.bn = nn.BatchNorm2d(16)
    self.relu = nn.ReLU()
    self.fc = nn.Linear(800, 10)

  def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    x = self.relu(x)
    # flatten the feature maps before the fully-connected layer
    x = x.view(-1, 800)
    x = self.fc(x)
    return x

The network above can be converted to a Bayesian network as follows:

class Net(ModuleWrapper):

  def __init__(self):
    super().__init__()
    self.conv = BBBConv2d(3, 16, 5, stride=2)
    self.bn = nn.BatchNorm2d(16)
    self.relu = nn.ReLU()
    self.flatten = FlattenLayer(800)
    self.fc = BBBLinear(800, 10)
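
The automatically generated forward pass can be pictured with the following sketch. It is an illustration only and not the repository's ModuleWrapper: `ToyBayesLinear`, `SketchWrapper`, `SketchNet`, and the `kl_loss()` method name are hypothetical stand-ins, and the KL value is a fixed placeholder.

```python
import torch
import torch.nn as nn

class ToyBayesLinear(nn.Module):
    """Hypothetical stand-in for a Bayesian layer that exposes kl_loss()."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

    def kl_loss(self):
        # Placeholder: a real layer would return KL(posterior || prior).
        return torch.tensor(1.0)

class SketchWrapper(nn.Module):
    """Runs direct children in registration order and sums their KL terms."""

    def forward(self, x):
        for module in self.children():
            x = module(x)
        kl = torch.tensor(0.0)
        for module in self.modules():
            if hasattr(module, "kl_loss"):
                kl = kl + module.kl_loss()
        return x, kl

class SketchNet(SketchWrapper):
    def __init__(self):
        super().__init__()
        self.fc1 = ToyBayesLinear(8, 4)
        self.relu = nn.ReLU()
        self.fc2 = ToyBayesLinear(4, 2)

logits, kl = SketchNet()(torch.randn(5, 8))
```

Because the layers run in the order they are assigned in `__init__`, a wrapper of this kind only suits purely sequential architectures, which is also why flattening has to be a layer of its own.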

Notes:

  1. Add a FlattenLayer before the first BBBLinear block.
  2. The forward method of the model returns a tuple (logits, kl).
  3. Priors can be passed as an argument to the layers. The default value is:
priors={
    'prior_mu': 0,
    'prior_sigma': 0.1,
    'posterior_mu_initial': (0, 0.1),  # (mean, std) normal_
    'posterior_rho_initial': (-3, 0.1),  # (mean, std) normal_
}
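
To give a feel for these defaults, the sketch below works through the numbers for a single weight. It assumes that the posterior standard deviation is obtained from rho through a softplus, sigma = log(1 + exp(rho)), as in Bayes by Backprop, and it uses the closed-form KL divergence between two univariate Gaussians. The function names are hypothetical and the code is independent of the repository.

```python
import math

def softplus(rho):
    """Map an unconstrained rho to a positive standard deviation."""
    return math.log1p(math.exp(rho))

def kl_gaussian(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ) for scalar Gaussians."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)

prior_mu, prior_sigma = 0.0, 0.1   # 'prior_mu' and 'prior_sigma' defaults

# Mean of 'posterior_rho_initial' is -3, which gives sigma of about 0.0486.
sigma_init = softplus(-3.0)

# KL contribution of one weight whose posterior mean sits at the prior mean.
kl_per_weight = kl_gaussian(0.0, sigma_init, prior_mu, prior_sigma)
```

In a variational training objective, a KL term of this kind, summed over all weights and usually scaled, is added to the negative log-likelihood computed from the logits.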

How to perform standard experiments?

Currently, the following datasets and models are supported.

  • Datasets: MNIST, CIFAR10, CIFAR100
  • Models: AlexNet, LeNet, 3Conv3FC

Bayesian

python main_bayesian.py

  • set hyperparameters in config_bayesian.py

Frequentist

python main_frequentist.py

  • set hyperparameters in config_frequentist.py

Directory Structure:

layers/: Contains ModuleWrapper, FlattenLayer, BBBLinear and BBBConv2d.
models/BayesianModels/: Contains standard Bayesian models (BBBLeNet, BBBAlexNet, BBB3Conv3FC).
models/NonBayesianModels/: Contains standard Non-Bayesian models (LeNet, AlexNet).
checkpoints/: Checkpoint directory: Models will be saved here.
tests/: Basic unittest cases for layers and models.
main_bayesian.py: Train and Evaluate Bayesian models.
config_bayesian.py: Hyperparameters for main_bayesian file.
main_frequentist.py: Train and Evaluate non-Bayesian (Frequentist) models.
config_frequentist.py: Hyperparameters for main_frequentist file.


Uncertainty Estimation:

There are two types of uncertainty: aleatoric and epistemic.
Aleatoric uncertainty measures the variation inherent in the data, while epistemic uncertainty stems from the model itself.
uncertainty_estimation.py provides two methods, 'softmax' and 'normalized', based respectively on equation 4 from this paper and equation 15 from this paper.
uncertainty_estimation.py can also be used to compare the uncertainties of a Bayesian neural network on the MNIST and notMNIST datasets. You can provide the following arguments:

  1. net_type: lenet, alexnet or 3conv3fc. Default is lenet.
  2. weights_path: Weights for the given net_type. Default is 'checkpoints/MNIST/bayesian/model_lenet.pt'.
  3. not_mnist_dir: Directory of notMNIST dataset. Default is 'data\'.
  4. num_batches: Number of batches for which uncertainties need to be calculated.

Notes:

  1. You need to download the notMNIST dataset from here.
  2. The parameters layer_type and activation_type used in uncertainty_estimation.py need to be set in config_bayesian.py so that they match the provided weights.
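
As a rough illustration of the two uncertainty types, the sketch below applies one common decomposition to the softmax outputs of several stochastic forward passes: the aleatoric part averages diag(p_t) - p_t p_t^T over the passes, and the epistemic part is the covariance of the predictions across passes. Whether this matches the 'softmax' method in uncertainty_estimation.py term by term is an assumption, and `decompose_uncertainty` is a hypothetical name.

```python
import numpy as np

def decompose_uncertainty(probs):
    """Split predictive variance into aleatoric and epistemic parts.

    probs: array of shape (T, K), one softmax vector per stochastic pass.
    Returns two length-K arrays holding the per-class diagonal terms.
    """
    num_passes = probs.shape[0]
    p_bar = probs.mean(axis=0)
    # Aleatoric: mean over passes of diag(p_t) - p_t p_t^T
    aleatoric = np.diag(p_bar) - (probs.T @ probs) / num_passes
    # Epistemic: mean over passes of (p_t - p_bar)(p_t - p_bar)^T
    centered = probs - p_bar
    epistemic = (centered.T @ centered) / num_passes
    return np.diag(aleatoric), np.diag(epistemic)

# Identical predictions on every pass: the model does not disagree with itself.
same = np.tile([0.7, 0.2, 0.1], (3, 1))
alea_same, epi_same = decompose_uncertainty(same)

# Predictions that change from pass to pass: non-zero epistemic uncertainty.
varied = np.array([[0.90, 0.05, 0.05],
                   [0.10, 0.80, 0.10],
                   [0.40, 0.30, 0.30]])
alea_varied, epi_varied = decompose_uncertainty(varied)
```

For every class the two parts add up to p_bar * (1 - p_bar), the variance of the averaged prediction, so the decomposition only redistributes that total between data noise and model disagreement.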

If you are using this work, please cite:

@article{shridhar2019comprehensive,
  title={A comprehensive guide to bayesian convolutional neural network with variational inference},
  author={Shridhar, Kumar and Laumann, Felix and Liwicki, Marcus},
  journal={arXiv preprint arXiv:1901.02731},
  year={2019}
}
@article{shridhar2018uncertainty,
  title={Uncertainty estimations by softplus normalization in bayesian convolutional neural networks with variational inference},
  author={Shridhar, Kumar and Laumann, Felix and Liwicki, Marcus},
  journal={arXiv preprint arXiv:1806.05978},
  year={2018}
}

More Repositories

  1. Master-Thesis-BayesianCNN (TeX, 262 stars): Master Thesis on Bayesian Convolutional Neural Network using Variational Inference
  2. Know-Your-Intent (Jupyter Notebook, 134 stars): State of the Art results in Intent Classification using Semantic Hashing for three datasets: AskUbuntu, Chatbot and WebApplication.
  3. Screws (Python, 26 stars): SCREWS: A Modular Framework for Reasoning with Revisions
  4. CNN_Architectures (Jupyter Notebook, 15 stars): Keras Implementation of major CNN architectures
  5. ProbAct-Probabilistic-Activation-Function (Jupyter Notebook, 13 stars): Official PyTorch implementation of the paper: ProbAct: A Probabilistic Activation Function for Deep Neural Networks.
  6. Very-Deep-Learning-CNN (Jupyter Notebook, 11 stars): Everything you need to know about CNN in PyTorch
  7. PyTorch-Super-Resolution (Python, 9 stars): Super Resolution of low resolution Images in PyTorch
  8. Indian-Names-Generator-Using-Deep-Learning (Python, 9 stars): Generate Indian people Name using LSTM
  9. LongtoNotes (8 stars): LongtoNotes: OntoNotes with Longer Coreference Chains
  10. PyTorch-Bayesian-Super-Resolution (Python, 7 stars): Super Resolution using Bayesian CNN
  11. NER-Benchmarks (7 stars): NER Benchmark
  12. Distiiling-LM (6 stars): The code for the paper: Distilling Reasoning Capabilities into Smaller Language Models
  13. MastersInMachineLearningOnline (4 stars): Online Content equivalent to Machine Learning Masters Degree at an University
  14. CV (3 stars): My Curriculum vitae (CV)
  15. GMM-vs-K-Means (Jupyter Notebook, 2 stars): Gaussian Mixture Models vs K-Means
  16. PyTorch-Bayesian-DCGAN (2 stars): Bayesian Version of DCGAN
  17. Detect-Language (Python, 2 stars): Detect the language from the given sentence
  18. Online-Toxicity-Detection (Python, 2 stars): APOLLO-1: Online Toxicity Detection
  19. Twitter_Sentiment_Analysis (Python, 1 star): Analyse sentiments of tweets posted on Twitter
  20. HackathonLulea (Jupyter Notebook, 1 star): hack night Lulea
  21. AddressExtractorAPI (Python, 1 star): Extract address using the API
  22. Very-Deep-Learning-NLP (Jupyter Notebook, 1 star): NLP exercise for Very Deep Learning Lecture at TU Kl
  23. LTU_NLP (1 star): NLP Projects at LTU
  24. GoT-Data-Visualization-using-t-SNE (Python, 1 star): Visualize multidimensional Game of Thrones data-set using t-SNE.
  25. QWERTY-based-Mistake-Probability (Jupyter Notebook, 1 star): Add noise in training data with QWERTY based Mistake Probability
  26. Business-Card-Detector (Python, 1 star): Detect if an image is a business card or not!