PyTorch CapsNet: Capsule Network for PyTorch

[NO MAINTENANCE INTENDED]

A CUDA-enabled PyTorch implementation of CapsNet (Capsule Network) based on this paper: Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017

The current test error is 0.21% and the best test error is 0.20%. The current test accuracy is 99.31% and the best test accuracy is 99.32%.

What is a Capsule

A Capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or object part.

You can learn more about Capsule Networks here.
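To make the definition concrete: each capsule outputs a vector whose length encodes the probability that the entity exists and whose orientation encodes the instantiation parameters. Below is a minimal sketch (written against a modern PyTorch API, not necessarily this repo's exact code) of the squash non-linearity the paper uses to keep that length between 0 and 1:

import torch

def squash(s, dim=-1):
    # Shrinks short vectors toward zero length and long vectors toward
    # (but never reaching) unit length, while preserving direction.
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / (squared_norm.sqrt() + 1e-8)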

Why another CapsNet implementation?

I wanted a decent PyTorch implementation of CapsNet and I couldn't find one when I started. The goal of this implementation is to help newcomers learn and understand the CapsNet architecture and the idea of Capsules; it is NOT focused on rigorous correctness of the results, and the code is not optimized for speed. To make the code easier to read and understand, it comes with ample comments, and the Python classes and functions are documented with docstrings.

I will try my best to check and fix reported issues. Contributions are highly welcome. If you find any bugs or errors in the code, please do not hesitate to open an issue or a pull request. Thank you.

Status and Latest Updates:

See the CHANGELOG

Datasets

The model was trained on the standard MNIST data.

Note: you don't have to manually download, preprocess, and load the MNIST dataset as TorchVision will take care of this step for you.
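For illustration, the data loading boils down to something like this (the ./data path is illustrative; the repo may store the dataset elsewhere):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# TorchVision downloads and caches MNIST on first use;
# subsequent runs reuse the local copy.
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)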

I have tried using other datasets. See the Other Datasets section below for more details.

Requirements

  • Python 3
    • Tested with version 3.6.4
  • PyTorch
    • Tested with version 0.3.0.post4
    • Migrate existing code to work in version 0.4.0. [Work-In-Progress]
    • Code will not run with version 0.1.2 because keepdim is not available in that version.
    • Code will not run with version 0.2.0 because its softmax function does not take a dimension argument (see the snippet after this list).
  • CUDA 8 and above
    • Tested with CUDA 8 and CUDA 9.
  • TorchVision
  • tensorboardX
  • tqdm
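The two version constraints above come down to two operations the code depends on. A minimal sketch of both (the tensor shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 16)
norms = (x ** 2).sum(dim=2, keepdim=True).sqrt()  # keepdim is missing in PyTorch 0.1.2
probs = F.softmax(x, dim=1)                       # softmax in 0.2.0 does not accept a dim argument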

Usage

Training and Evaluation

Step 1. Clone this repository with git and install project dependencies.

$ git clone https://github.com/cedrickchee/capsule-net-pytorch.git
$ cd capsule-net-pytorch
$ pip install -r requirements.txt

Step 2. Start the CapsNet on MNIST training and evaluation:

  • Training with default settings:
$ python main.py
  • Training on 8 GPUs with 30 epochs and 1 routing iteration:
$ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py --epochs 30 --num-routing 1 --threads 16 --batch-size 128 --test-batch-size 128
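The multi-GPU run above presumably relies on PyTorch's standard data parallelism. A minimal sketch of that pattern (the model below is a stand-in, not this repo's Net class):

import torch
import torch.nn as nn

model = nn.Linear(10, 10)           # stand-in for the CapsNet model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicates the model and splits each batch across visible GPUs
if torch.cuda.is_available():
    model = model.cuda()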

Step 3. Test a pre-trained model:

If you trained a model in Step 2 above, the weights for the trained model are saved to results/trained_model/model_epoch_10.pth. [WIP] Now run the following command to get the test results.

$ python main.py --is-training 0 --weights results/trained_model/model_epoch_10.pth
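Under the hood, this is the usual PyTorch checkpoint-loading pattern. A sketch, assuming the .pth file holds a bare model state dict (the actual checkpoint layout may differ, since the optimizer state is saved as well):

import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # stand-in; in practice this is an instance of the repo's Net class
state = torch.load('results/trained_model/model_epoch_10.pth',
                   map_location='cpu')  # map_location lets CPU-only machines load GPU-trained weights
model.load_state_dict(state)
model.eval()  # switch off training-only behaviour before testing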

Pre-trained Model

You can download the weights for the pre-trained model from my Google Drive. We saved the weights (model state dict) and the optimizer state for the model at the end of every training epoch.

Uncompress and put the weights (.pth files) into ./results/trained_model/.

Note: the model was last trained on 2017-11-26 and the weights last updated on 2017-11-28.

The Default Hyper Parameters

Parameter                                          | Value  | CLI arguments
-------------------------------------------------- | ------ | -----------------------------
Training epochs                                    | 10     | --epochs 10
Learning rate                                      | 0.01   | --lr 0.01
Training batch size                                | 128    | --batch-size 128
Testing batch size                                 | 128    | --test-batch-size 128
Log interval                                       | 10     | --log-interval 10
Disables CUDA training                             | false  | --no-cuda
Num. of channels produced by the convolution       | 256    | --num-conv-out-channel 256
Num. of input channels to the convolution          | 1      | --num-conv-in-channel 1
Num. of primary units                              | 8      | --num-primary-unit 8
Primary unit size                                  | 1152   | --primary-unit-size 1152
Num. of digit classes                              | 10     | --num-classes 10
Output unit size                                   | 16     | --output-unit-size 16
Num. of routing iterations                         | 3      | --num-routing 3
Use reconstruction loss                            | true   | --use-reconstruction-loss
Regularization coefficient for reconstruction loss | 0.0005 | --regularization-scale 0.0005
Dataset name (mnist, cifar10)                      | mnist  | --dataset mnist
Input image width to the convolution               | 28     | --input-width 28
Input image height to the convolution              | 28     | --input-height 28
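For reference, defaults like these typically map to CLI flags through argparse; a sketch with a few of the flags (not the repo's exact code):

import argparse

parser = argparse.ArgumentParser(description='CapsNet on MNIST/CIFAR10')
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--num-routing', type=int, default=3)
args = parser.parse_args()  # e.g. `python main.py --epochs 30` overrides a default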

Results

Test Error

CapsNet classification test error on MNIST. The MNIST average and standard deviation results are reported from 3 trials.

The results can be reproduced by running the following commands.

 python main.py --epochs 50 --num-routing 1 --use-reconstruction-loss no --regularization-scale 0.0       #CapsNet-v1
 python main.py --epochs 50 --num-routing 1 --use-reconstruction-loss yes --regularization-scale 0.0005   #CapsNet-v2
 python main.py --epochs 50 --num-routing 3 --use-reconstruction-loss no --regularization-scale 0.0       #CapsNet-v3
 python main.py --epochs 50 --num-routing 3 --use-reconstruction-loss yes --regularization-scale 0.0005   #CapsNet-v4
Method     | Routing | Reconstruction | MNIST (%) | Paper
---------- | ------- | -------------- | --------- | -------------
Baseline   | --      | --             | --        | 0.39
CapsNet-v1 | 1       | no             | --        | 0.34 (0.032)
CapsNet-v2 | 1       | yes            | --        | 0.29 (0.011)
CapsNet-v3 | 3       | no             | --        | 0.35 (0.036)
CapsNet-v4 | 3       | yes            | 0.21      | 0.25 (0.005)

Training Loss and Accuracy

The training losses and accuracies for CapsNet-v4 (50 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):

Training accuracy. Highest training accuracy: 100%

Training loss. Lowest training loss: 0.1938%
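For reference, the loss tracked in these plots combines the paper's margin loss with the scaled reconstruction loss. A minimal sketch of the margin term, using the paper's constants m+ = 0.9, m- = 0.1, and lambda = 0.5:

import torch
import torch.nn.functional as F

def margin_loss(lengths, targets, m_pos=0.9, m_neg=0.1, lam=0.5):
    # lengths: [batch, num_classes] output capsule vector lengths.
    # targets: [batch, num_classes] one-hot labels.
    present = targets * F.relu(m_pos - lengths) ** 2
    absent = lam * (1.0 - targets) * F.relu(lengths - m_neg) ** 2
    return (present + absent).sum(dim=1).mean()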

Test Loss and Accuracy

The test losses and accuracies for CapsNet-v4 (50 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):

Test accuracy. Highest test accuracy: 99.32%

Test loss. Lowest test loss: 0.2002%

Training Speed

  • Around 5.97s/batch or ~8 min/epoch on a single Tesla K80 GPU with a batch size of 704.
  • Around 3.25s/batch or ~25 min/epoch on a single Tesla K80 GPU with a batch size of 128.

These are the hyperparameters I used for this training setup:

  • batch size: 128
  • Epochs: 50
  • Num. of routing: 3
  • Use reconstruction loss: yes
  • Regularization scale for reconstruction loss: 0.0005

Reconstruction

The results of CapsNet-v4.

Digits at left are reconstructed images.

[WIP] Ground truth image from dataset
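The reconstruction term behind these images is, per the paper, a sum-of-squared-differences loss between the decoder output and the flattened input image, scaled down (the --regularization-scale factor) so it does not dominate the margin loss. A sketch:

import torch.nn.functional as F

def reconstruction_loss(reconstructed, images, scale=0.0005):
    # reconstructed: decoder output, a flat 784-vector per MNIST image.
    flat = images.view(images.size(0), -1)
    return scale * F.mse_loss(reconstructed, flat, reduction='sum')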

Model Design

Model architecture:
------------------

Net (
  (conv1): ConvLayer (
    (conv0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (relu): ReLU (inplace)
  )
  (primary): CapsuleLayer (
    (conv_units): ModuleList (
      (0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (1): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (2): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (3): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (4): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (5): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (6): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
      (7): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
  )
  (digits): CapsuleLayer (
  )
  (decoder): Decoder (
    (fc1): Linear (160 -> 512)
    (fc2): Linear (512 -> 1024)
    (fc3): Linear (1024 -> 784)
    (relu): ReLU (inplace)
    (sigmoid): Sigmoid ()
  )
)
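The digits CapsuleLayer above is where dynamic routing happens. A compact sketch of routing-by-agreement (Procedure 1 in the paper), reusing the squash helper sketched earlier; the shapes are illustrative:

import torch
import torch.nn.functional as F

def dynamic_routing(u_hat, num_iterations=3):
    # u_hat: [batch, in_caps, out_caps, out_dim] prediction vectors,
    # i.e. input capsule outputs already multiplied by the weight matrices.
    b = torch.zeros_like(u_hat[..., 0])           # routing logits [batch, in_caps, out_caps]
    for _ in range(num_iterations):
        c = F.softmax(b, dim=2)                   # coupling coefficients over output capsules
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # weighted sum over input capsules
        v = squash(s)                             # [batch, out_caps, out_dim]
        b = b + (u_hat * v.unsqueeze(1)).sum(-1)  # agreement update: dot(u_hat, v)
    return v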

Parameters and size:
-------------------

conv1.conv0.weight: [256, 1, 9, 9]
conv1.conv0.bias: [256]
primary.conv_units.0.weight: [32, 256, 9, 9]
primary.conv_units.0.bias: [32]
primary.conv_units.1.weight: [32, 256, 9, 9]
primary.conv_units.1.bias: [32]
primary.conv_units.2.weight: [32, 256, 9, 9]
primary.conv_units.2.bias: [32]
primary.conv_units.3.weight: [32, 256, 9, 9]
primary.conv_units.3.bias: [32]
primary.conv_units.4.weight: [32, 256, 9, 9]
primary.conv_units.4.bias: [32]
primary.conv_units.5.weight: [32, 256, 9, 9]
primary.conv_units.5.bias: [32]
primary.conv_units.6.weight: [32, 256, 9, 9]
primary.conv_units.6.bias: [32]
primary.conv_units.7.weight: [32, 256, 9, 9]
primary.conv_units.7.bias: [32]
digits.weight: [1, 1152, 10, 16, 8]
decoder.fc1.weight: [512, 160]
decoder.fc1.bias: [512]
decoder.fc2.weight: [1024, 512]
decoder.fc2.bias: [1024]
decoder.fc3.weight: [784, 1024]
decoder.fc3.bias: [784]

Total number of parameters (with the reconstruction network): 8,227,088 (~8.2 million)
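The total can be reproduced for any PyTorch model with a small helper:

import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    # Total number of learnable parameter values (weights and biases).
    return sum(p.numel() for p in model.parameters())

# count_parameters(net) should report 8227088 for the model above.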

TensorBoard

We logged the training and test losses and accuracies using tensorboardX. TensorBoard helps us visualize how the model learns over time. We can track statistics such as how the objective function changes and how the weights and accuracy vary during training.

TensorBoard operates by reading TensorFlow data (events files).
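A minimal sketch of how such event files get written with tensorboardX (the tag names and log directory here are illustrative, not necessarily the ones this repo uses):

from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/example')  # event files are written under ./runs
for step in range(100):
    writer.add_scalar('train/loss', 1.0 / (step + 1), step)
writer.close()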

How to Use TensorBoard

  1. Download a copy of the events files for the latest run from my Google Drive.
  2. Uncompress the file and put it into ./runs.
  3. Make sure you have TensorFlow (CPU version) installed; we need it for the TensorBoard server and dashboard.
  4. Start TensorBoard.
$ tensorboard --logdir runs
  5. Open the TensorBoard dashboard in your web browser using this URL: http://localhost:6006

Other Datasets

CIFAR10

In the spirit of experimentation, I have tried other datasets. I have updated the implementation so that it supports and works with CIFAR10. Note that I have not thoroughly tested the capsule model on CIFAR10.

Here's how we can train and test the model on CIFAR10:

python main.py --dataset cifar10 --num-conv-in-channel 3 --input-width 32 --input-height 32 --primary-unit-size 2048 --epochs 80 --num-routing 1 --use-reconstruction-loss yes --regularization-scale 0.0005
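Note that --primary-unit-size grows from 1152 to 2048 because CIFAR10 images are 32x32 rather than 28x28. A small sketch of where those numbers come from, assuming the paper's two 9x9 convolutions (stride 1, then stride 2, no padding) and 32 channels per capsule unit:

def primary_unit_size(input_hw, channels_per_unit=32):
    after_conv1 = input_hw - 9 + 1              # 9x9 conv, stride 1, no padding
    after_primary = (after_conv1 - 9) // 2 + 1  # 9x9 conv, stride 2, no padding
    return channels_per_unit * after_primary ** 2

print(primary_unit_size(28))  # 1152 for MNIST (28x28)
print(primary_unit_size(32))  # 2048 for CIFAR10 (32x32)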
Training Loss and Accuracy

The training losses and accuracies for CapsNet-v4 (80 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):

  • Highest training accuracy: 100%
  • Lowest training loss: 0.3589%
Test Loss and Accuracy

The test losses and accuracies for CapsNet-v4 (80 epochs, 3 routing iterations, using reconstruction, regularization scale of 0.0005):

  • Highest test accuracy: 71%
  • Lowest test loss: 0.5735%

TODO

  • Publish results.
  • More testing.
  • Inference mode - command to test a pre-trained model.
  • Jupyter Notebook version.
  • Create a sample to show how we can apply CapsNet to a real-world application.
  • Experiment with CapsNet:
    • Try using another dataset.
    • Come up with a more creative model structure.
  • Pre-trained model and weights.
  • Add visualization for training and evaluation metrics.
  • Implement reconstruction loss.
  • Check algorithm for correctness.
  • Update results from TensorBoard after making improvements and bug fixes.
  • Publish updated pre-trained model weights.
  • Log the original and reconstructed images using TensorBoard.
  • Update results with reconstructed image and original image.
  • Resume training by loading model checkpoint.
  • Migrate existing code to work in PyTorch 0.4.0.

WIP is an acronym for Work-In-Progress.

Credits

I referenced these implementations mainly for sanity checks:

  1. TensorFlow implementation by @naturomics

Learning Resources

Here are some resources that we think will be helpful if you want to learn more about Capsule Networks:

Other Implementations

Real-world Application of CapsNet

The following are a few samples in the wild that show how CapsNet can be applied to real-world use cases.
