Semi-Supervised Learning with Ladder Networks in Keras

This is an implementation of the Ladder Network in Keras. The Ladder Network is a model for semi-supervised learning. For details, refer to the paper Semi-Supervised Learning with Ladder Networks by A. Rasmus, H. Valpola, M. Honkala, M. Berglund, and T. Raiko.
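
In the paper, each decoder layer reconstructs the clean activation of the matching encoder layer from its corrupted version z_noisy and a top-down signal u, using a denoising function g with ten learned per-unit parameters a1..a10. The NumPy sketch below writes out that function as described in the paper. It is not taken from ladder_net.py; the name combinator and the (10, width) parameter layout are assumptions made for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def combinator(z_noisy, u, a):
    """Denoising function g(z_noisy, u) as described by Rasmus et al.

    z_noisy: corrupted encoder activations of one layer, shape (batch, width)
    u:       top-down signal from the decoder layer above, same shape
    a:       learned parameters a1..a10 stacked as rows, shape (10, width)
             (hypothetical layout: a[0] holds a1, ..., a[9] holds a10)
    """
    mu = a[0] * sigmoid(a[1] * u + a[2]) + a[3] * u + a[4]
    v = a[5] * sigmoid(a[6] * u + a[7]) + a[8] * u + a[9]
    # Equivalent to v * z_noisy + (1 - v) * mu: v decides how much of the
    # noisy value to keep versus the prediction mu made from u.
    return (z_noisy - mu) * v + mu
```

With a10 = 1 and every other parameter 0, v is 1 and mu is 0, so g passes z_noisy through unchanged; with only a5 set, g ignores the noisy input and outputs the constant a5.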

This implementation was used in the official code of our ICLR 2020 paper Unsupervised Clustering using Pseudo-semi-supervised Learning. The code can be found here and the blog post can be found here.

The model achieves 98% test accuracy on MNIST with just 100 labeled examples.

The code only works with the TensorFlow backend.

Requirements

  • Python 2.7+/3.6+
  • TensorFlow (1.4.0)
  • NumPy
  • Keras (2.1.4)
  • scikit-learn (only for the test accuracy example)

Note that other versions of TensorFlow/Keras should also work.

How to use

Load the dataset

from keras.datasets import mnist
import keras
import numpy as np
import random

# get the dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# flatten the 28x28 images into 784-d vectors and scale pixels to [0, 1]
x_train = x_train.reshape(60000, 28*28).astype('float32')/255.0
x_test = x_test.reshape(10000, 28*28).astype('float32')/255.0

# one-hot encode the digit labels
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# only select 100 training samples
idxs_annot = list(range(x_train.shape[0]))  # list() so it can be shuffled in Python 3
random.seed(0)
random.shuffle(idxs_annot)
idxs_annot = idxs_annot[:100]

# all 60,000 training images serve as unlabeled data; 100 of them also keep their labels
x_train_unlabeled = x_train
x_train_labeled = x_train[idxs_annot]
y_train_labeled = y_train[idxs_annot]
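
The selection step above can also be written with NumPy alone, which avoids materializing and shuffling a Python list. The sketch below is a hypothetical alternative (the helper select_labeled_subset is not part of this repository) and will not pick the same 100 images as random.seed(0). It uses fake labels instead of MNIST, and counts how many examples of each class were drawn, since a purely random pick of 100 images is not guaranteed to contain exactly 10 per digit.

```python
import numpy as np


def select_labeled_subset(n_total, n_labeled, seed=0):
    """Hypothetical helper: n_labeled distinct indices drawn uniformly from range(n_total)."""
    rng = np.random.RandomState(seed)
    return rng.permutation(n_total)[:n_labeled]


# stand-in for MNIST: 60,000 fake integer labels cycling through 0..9 (not real data)
fake_labels = np.arange(60000) % 10
idxs = select_labeled_subset(fake_labels.shape[0], 100, seed=0)

# examples drawn per class; random selection can leave some digits under-represented
counts = np.bincount(fake_labels[idxs], minlength=10)
```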

Repeat the labeled examples so that they have as many rows as the unlabeled set

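# repeat the 100 labeled examples 600 times so both inputs have 60,000 rows
n_rep = x_train_unlabeled.shape[0] // x_train_labeled.shape[0]  # integer division, also in Python 3
x_train_labeled_rep = np.concatenate([x_train_labeled]*n_rep)
y_train_labeled_rep = np.concatenate([y_train_labeled]*n_rep)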
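
Keras expects all arrays passed to fit to have the same number of samples. The integer repetition above matches exactly only because 60,000 is divisible by 100; with other sizes the repeated labeled set ends up shorter than the unlabeled set. A minimal sketch that handles any size, assuming it is acceptable to repeat a few labeled examples one extra time (repeat_to_length is a hypothetical helper, and the arrays are small random stand-ins):

```python
import numpy as np


def repeat_to_length(x, n_target):
    """Hypothetical helper: tile x along axis 0, then truncate to exactly n_target rows."""
    n_rep = -(-n_target // x.shape[0])  # ceiling division
    reps = (n_rep,) + (1,) * (x.ndim - 1)
    return np.tile(x, reps)[:n_target]


# small stand-ins: 100 labeled rows with 8 features and one-hot labels over 10 classes
x_labeled = np.random.rand(100, 8).astype('float32')
y_labeled = np.eye(10, dtype='float32')[np.arange(100) % 10]

x_rep = repeat_to_length(x_labeled, 60000)
y_rep = repeat_to_length(y_labeled, 60000)
```

When the sizes divide evenly, np.tile produces the same rows in the same order as np.concatenate([x]*n_rep).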

Initialize the model

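from ladder_net import get_ladder_network_fc

inp_size = 28*28  # size of each flattened MNIST image (784 pixels)
n_classes = 10
# widths of the fully connected layers: input, five hidden layers, class outputs
model = get_ladder_network_fc(layer_sizes=[inp_size, 1000, 500, 250, 250, 250, n_classes])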
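
As a rough sense of scale, the arithmetic below counts the entries of the dense weight matrices connecting consecutive layers in layer_sizes. It deliberately ignores biases, batch-normalization parameters, the decoder and the denoising parameters, whose exact form depends on ladder_net.py, so it is a lower bound on the model's size rather than its parameter count.

```python
layer_sizes = [28 * 28, 1000, 500, 250, 250, 250, 10]

# one dense weight matrix between each pair of consecutive layers
weight_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:]))
n_weights = sum(n_in * n_out for n_in, n_out in weight_shapes)
# 784*1000 + 1000*500 + 500*250 + 250*250 + 250*250 + 250*10 = 1,536,500
```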

Train the model

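# inputs: [repeated labeled images, all training images as unlabeled data]; targets: labels of the labeled images
model.fit([x_train_labeled_rep, x_train_unlabeled], y_train_labeled_rep, epochs=100)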
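
In the paper's formulation, the training objective is a supervised cross-entropy computed from the noisy encoder's predictions on labeled examples, plus, for every layer l, a weighted denoising cost lambda_l / (N * width_l) * sum ||z_l - z_hat_l_BN||^2 between the clean encoder activations and the batch-normalized reconstructions. The NumPy sketch below writes that objective out for one batch. The function names and the lambdas values are hypothetical, and it is not a description of how ladder_net.py assembles its Keras losses.

```python
import numpy as np


def cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean categorical cross-entropy for one-hot targets."""
    return -np.mean(np.sum(y_true * np.log(np.clip(y_prob, eps, 1.0)), axis=1))


def ladder_cost(y_true, y_noisy_prob, z_clean, z_hat_bn, lambdas):
    """Supervised cost plus lambda-weighted per-layer denoising costs.

    y_true:        one-hot labels of the labeled batch
    y_noisy_prob:  class probabilities from the noisy (corrupted) encoder
    z_clean:       list of clean encoder activations, one (batch, width) array per layer
    z_hat_bn:      list of normalized decoder reconstructions, same shapes
    lambdas:       list of per-layer weights (hyperparameters)
    """
    supervised = cross_entropy(y_true, y_noisy_prob)
    denoising = 0.0
    for z, z_hat, lam in zip(z_clean, z_hat_bn, lambdas):
        n, width = z.shape
        # squared error summed over the batch, normalized by batch size and layer width
        denoising += lam * np.sum((z - z_hat) ** 2) / (n * width)
    return supervised + denoising
```

Because the denoising terms need no labels, they can be computed on the unlabeled input, which is what lets the 59,900 unlabeled images contribute to training.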

Get the test accuracy

from sklearn.metrics import accuracy_score

y_test_pr = model.test_model.predict(x_test, batch_size=100)

# convert one-hot labels and predicted probabilities to class indices before scoring
print("test accuracy: {}".format(accuracy_score(y_test.argmax(-1), y_test_pr.argmax(-1))))
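
The score computed above is the fraction of test images whose highest-scoring predicted class equals the true class. For readers without scikit-learn, the same quantity can be computed with NumPy alone; argmax_accuracy is a hypothetical helper name, and the arrays below are toy stand-ins for y_test and y_test_pr.

```python
import numpy as np


def argmax_accuracy(y_true_onehot, y_pred_prob):
    """Fraction of rows where the highest-scoring predicted class equals the true class."""
    return float(np.mean(y_true_onehot.argmax(-1) == y_pred_prob.argmax(-1)))


# toy example: true classes 3, 1, 4, 1; predictions favour classes 3, 1, 5, 1
y_true = np.eye(10)[[3, 1, 4, 1]]
y_pred = np.eye(10)[[3, 1, 5, 1]] * 0.9 + 0.01
acc = argmax_accuracy(y_true, y_pred)  # 3 of 4 rows correct -> 0.75
```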
