

Repository Details

MNIST digit classification with scikit-learn and Support Vector Machine (SVM) algorithm.

SVM MNIST digit classification in python using scikit-learn

The project presents the well-known problem of MNIST handwritten digit classification. For the purpose of this tutorial, I will use the Support Vector Machine (SVM) algorithm with raw pixel features. The solution is written in Python using scikit-learn, an easy-to-use machine learning library.

Sample MNIST digits visualization

The goal of this project is not to achieve state-of-the-art performance, but rather to teach you how to train an SVM classifier on image data using the SVM implementation from sklearn. Although the solution isn't optimized for high accuracy, the results are quite good (see table below).

If you want to reach top performance, these two resources will show you the current state-of-the-art solutions:

The table below shows some results in comparison with other models:

Method                                      Accuracy  Comments
Random forest                               0.937
Simple one-layer neural network             0.926
Simple 2 layer convolutional network        0.981
SVM RBF                                     0.9852    C=5, gamma=0.05
Linear SVM + Nystroem kernel approximation
Linear SVM + Fourier kernel approximation

Project Setup

This tutorial was written and tested on Ubuntu 18.10. The project contains a Pipfile listing all necessary libraries:

  • Python - version >= 3.6
  • pipenv - package and virtual environment management
  • numpy
  • matplotlib
  • scikit-learn
  1. Install Python.
  2. Install pipenv.
  3. Clone the repository with git.
  4. Install all necessary Python packages by running the following commands in a terminal:
git clone https://github.com/ksopyla/svm_mnist_digit_classification.git
cd svm_mnist_digit_classification
pipenv install

Solution

In this tutorial, I use two approaches to SVM learning. The first uses a classical SVM with an RBF kernel. Its drawback is rather long training time on big datasets, although the accuracy with good parameters is high. The second uses a linear SVM, which allows for training in O(n) time. To achieve high accuracy with it, we use a trick: we approximate the RBF kernel with an explicit embedding into a high-dimensional space. The theory behind this is quite complicated; however, sklearn has ready-to-use classes for kernel approximation. We will use:

  • Nystroem kernel approximation
  • Fourier kernel approximation
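
To give some intuition for the second approach, here is a minimal numpy sketch of the random Fourier features idea that underlies Fourier-style approximations of the RBF kernel. It is not code from this repository: the data is random, and every name and size (gamma, n_components, the sample counts) is an assumption chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

gamma = 0.05          # RBF width, k(x, y) = exp(-gamma * ||x - y||^2)
n_samples = 50        # illustrative sizes, not taken from the project
n_features = 20
n_components = 5000   # dimension of the explicit embedding

X = rng.normal(size=(n_samples, n_features))

# Exact RBF kernel matrix, computed from pairwise squared distances.
sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
K_exact = np.exp(-gamma * sq_dists)

# Random Fourier features: frequencies drawn from N(0, 2 * gamma),
# phases drawn uniformly from [0, 2 * pi).
W = rng.normal(scale=np.sqrt(2.0 * gamma), size=(n_features, n_components))
b = rng.uniform(0.0, 2.0 * np.pi, size=n_components)
Z = np.sqrt(2.0 / n_components) * np.cos(X @ W + b)

# Inner products in the embedded space approximate the kernel values.
K_approx = Z @ Z.T
max_error = np.abs(K_exact - K_approx).max()
print("largest absolute approximation error:", max_error)
```

A linear SVM trained on Z instead of X then behaves approximately like an RBF SVM trained on X, which is the trick described above. The error shrinks as n_components grows.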

The code was tested with python 3.6.

How the project is organized

The project consists of three files:

  • mnist_helpers.py - contains some visualization functions: MNIST digits visualization and confusion matrix
  • svm_mnist_classification.py - script for SVM with RBF kernel classification
  • svm_mnist_embedings.py - script for linear SVM with embeddings

SVM with RBF kernel

The svm_mnist_classification.py script downloads the MNIST database and visualizes some random digits. Next, it standardizes the data (mean=0, std=1) and launches a grid search with cross-validation to find the best parameters.
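
The standardize-then-grid-search flow can be sketched roughly as follows. This is not the repository script: to keep the run short it assumes sklearn's small built-in digits dataset (8x8 images) instead of MNIST, wraps the scaler and the classifier in a pipeline, and uses an assumed 15% test split and random seed.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Small stand-in dataset (assumption): 1797 digits of 8x8 pixels.
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=42
)

# Standardize features (mean=0, std=1), then fit an RBF SVM.
pipeline = make_pipeline(StandardScaler(), SVC(kernel="rbf"))

# 4x4 grid over C and gamma; the "svc__" prefix addresses the SVC step.
param_grid = {
    "svc__C": [0.1, 0.5, 1, 5],
    "svc__gamma": [0.01, 0.05, 0.1, 0.5],
}

# 3-fold cross-validation over every (C, gamma) pair.
search = GridSearchCV(pipeline, param_grid, cv=3)
search.fit(X_train, y_train)

test_accuracy = search.score(X_test, y_test)
print("best parameters:", search.best_params_)
print("test accuracy:", test_accuracy)
```

Putting the scaler inside the pipeline means it is refit on the training part of each fold, so no statistics leak from the validation part. The best parameters found on this small dataset will not necessarily match the MNIST results reported here.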

  1. MNIST SVM kernel RBF Param search C=[0.1,0.5,1,5], gamma=[0.01,0.05,0.1,0.5].

Grid search was done for the parameters C and gamma, where C=[0.1,0.5,1,5] and gamma=[0.01,0.05,0.1,0.5]. So far I have examined only 4x4 parameter pairs with 3-fold cross-validation (4x4x3=48 models). This procedure takes 3687.2 min :) (2 days, 13:56:42.531223 exactly) on a single CPU core.

Param space was generated with numpy logspace and outer matrix multiplication.

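import numpy as np

# C grid: powers of ten [0.1, 1] times [1, 5]
C_range = np.outer(np.logspace(-1, 0, 2), np.array([1, 5]))
# flatten the 2x2 matrix into a 1D numpy array
C_range = C_range.flatten()

# gamma grid: powers of ten [0.01, 0.1] times [1, 5]
gamma_range = np.outer(np.logspace(-2, -1, 2), np.array([1, 5]))
gamma_range = gamma_range.flatten()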

Of course, you can broaden the range of parameters, but this will increase the computation time.
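
As a small illustration of broadening the range, the same logspace-and-outer-product construction can be wrapped in a helper. The function name half_decade_grid is hypothetical and not part of the repository; the wider ranges are chosen to reproduce the C and gamma values quoted for the second search in this README.

```python
import numpy as np


def half_decade_grid(low_exp, high_exp):
    """Return 1x and 5x every power of ten from 10**low_exp to 10**high_exp."""
    n_decades = high_exp - low_exp + 1
    decades = np.logspace(low_exp, high_exp, n_decades)
    # Outer product pairs each power of ten with the multipliers 1 and 5;
    # flatten() turns the resulting matrix into a sorted 1D array.
    return np.outer(decades, np.array([1, 5])).flatten()


# The 4x4 grid used in the first search.
C_range = half_decade_grid(-1, 0)
gamma_range = half_decade_grid(-2, -1)

# A broader grid: one more decade at the top for C, one more at the bottom for gamma.
C_wide = half_decade_grid(-1, 1)
gamma_wide = half_decade_grid(-3, -1)

print("C:", C_range, "gamma:", gamma_range)
print("C wide:", C_wide, "gamma wide:", gamma_wide)
```

Every extra value multiplies the number of models to train, so the cost of the grid search grows with the product of the two grid sizes and the number of folds.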

SVM RBF param space

Grid search is a very time-consuming process, so you can use my best parameters (found in the range C=[0.1,5], gamma=[0.01,0.5]):

  • C = 5
  • gamma = 0.05
  • accuracy = 0.9852
Confusion matrix:
[[1014    0    2    0    0    2    2    0    1    3]
 [   0 1177    2    1    1    0    1    0    2    1]
 [   2    2 1037    2    0    0    0    2    5    1]
 [   0    0    3 1035    0    5    0    6    6    2]
 [   0    0    1    0  957    0    1    2    0    3]
 [   1    1    0    4    1  947    4    0    5    1]
 [   2    0    1    0    2    0 1076    0    4    0]
 [   1    1    8    1    1    0    0 1110    2    4]
 [   0    4    2    4    1    6    0    1 1018    1]
 [   3    1    0    7    5    2    0    4    9  974]]
Accuracy=0.985238095238
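
The reported accuracy can be recomputed from the confusion matrix: it is the sum of the diagonal (correct predictions) divided by the total number of test samples. The snippet below is a small check, assuming the usual sklearn convention that rows are true labels and columns are predicted labels; the per-class recall is an extra not reported in the original.

```python
import numpy as np

# Confusion matrix copied from the results above.
cm = np.array([
    [1014,    0,    2,    0,    0,    2,    2,    0,    1,    3],
    [   0, 1177,    2,    1,    1,    0,    1,    0,    2,    1],
    [   2,    2, 1037,    2,    0,    0,    0,    2,    5,    1],
    [   0,    0,    3, 1035,    0,    5,    0,    6,    6,    2],
    [   0,    0,    1,    0,  957,    0,    1,    2,    0,    3],
    [   1,    1,    0,    4,    1,  947,    4,    0,    5,    1],
    [   2,    0,    1,    0,    2,    0, 1076,    0,    4,    0],
    [   1,    1,    8,    1,    1,    0,    0, 1110,    2,    4],
    [   0,    4,    2,    4,    1,    6,    0,    1, 1018,    1],
    [   3,    1,    0,    7,    5,    2,    0,    4,    9,  974],
])

# Overall accuracy: correctly classified samples over all samples.
accuracy = np.trace(cm) / cm.sum()

# Per-class recall, assuming rows hold the true labels.
recall_per_digit = np.diag(cm) / cm.sum(axis=1)

print("test samples:", cm.sum())
print("accuracy:", accuracy)
for digit, recall in enumerate(recall_per_digit):
    print(f"digit {digit}: recall {recall:.4f}")
```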
  2. MNIST SVM kernel RBF Param search C=[0.1,0.5,1,5,10,50], gamma=[0.001,0.005,0.01,0.05,0.1,0.5].

This much broader search over 6x8 params with 3-fold cross-validation gives 6x8x3=144 models. This procedure takes 13024.3 min (9 days, 1:33:58.999782 exactly) on a single CPU core.

SVM RBF param space

Best parameters:

  • C = 5
  • gamma = 0.05
  • accuracy = 0.9852

Linear SVM with different embeddings

Linear SVMs (SVMs with a linear kernel) have the advantage that many O(n) training algorithms exist for them. They are really fast in comparison with nonlinear SVMs (most of which are O(n^2)). This technique is really useful if you want to train on big data.

Linear SVM algorithm examples (papers and software):

Unfortunately, linear SVM isn't powerful enough to classify data with accuracy comparable to RBF SVM.

On the other hand, training an SVM with an RBF kernel can be time-consuming. To get a more expressive model without that cost, we approximate the nonlinear kernel: we map the vectors explicitly into a higher-dimensional space and use a fast linear SVM in this new space. This works extremely well!

The script svm_mnist_embedings.py presents an accuracy summary and training times for the full RBF kernel, linear SVC, and linear SVC with two kernel approximations: Nystroem and Fourier.
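
A comparison of this kind can be sketched as below. This is not the contents of svm_mnist_embedings.py: it assumes sklearn's small built-in digits dataset instead of MNIST, uses sklearn's Nystroem and RBFSampler classes for the two approximations, and the values of gamma, n_components, the split, and max_iter are illustrative guesses rather than tuned settings.

```python
import time

from sklearn.datasets import load_digits
from sklearn.kernel_approximation import Nystroem, RBFSampler
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC, LinearSVC

# Small stand-in dataset (assumption); pixel values scaled to [0, 1].
X, y = load_digits(return_X_y=True)
X = X / 16.0
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=0
)

gamma = 0.2          # illustrative value, not tuned
n_components = 300   # size of the approximate feature map

models = {
    "RBF SVC": SVC(kernel="rbf", gamma=gamma),
    "Linear SVC": LinearSVC(max_iter=5000),
    "Nystroem + linear SVC": make_pipeline(
        Nystroem(kernel="rbf", gamma=gamma, n_components=n_components, random_state=0),
        LinearSVC(max_iter=5000),
    ),
    "Fourier + linear SVC": make_pipeline(
        RBFSampler(gamma=gamma, n_components=n_components, random_state=0),
        LinearSVC(max_iter=5000),
    ),
}

scores = {}
for name, model in models.items():
    start = time.perf_counter()
    model.fit(X_train, y_train)
    elapsed = time.perf_counter() - start
    scores[name] = model.score(X_test, y_test)
    print(f"{name}: accuracy={scores[name]:.4f}, training time={elapsed:.2f}s")
```

On a dataset this small the timing differences are not meaningful; the speed advantage of the approximate feature maps only becomes visible as the number of training samples grows.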

Further improvements

  • Augmenting the training set with artificial samples
  • Using Randomized param search
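
The randomized search idea could look roughly like this with sklearn's RandomizedSearchCV. It is only a sketch of a possible improvement, not something the repository implements: the dataset is again the small built-in digits set, and the sampling ranges, n_iter, and random seed are assumptions.

```python
from scipy.stats import loguniform
from sklearn.datasets import load_digits
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)

pipeline = make_pipeline(StandardScaler(), SVC(kernel="rbf"))

# Sample C and gamma on a log scale instead of enumerating a fixed grid.
param_distributions = {
    "svc__C": loguniform(0.1, 50),
    "svc__gamma": loguniform(1e-3, 0.5),
}

# n_iter bounds the number of sampled (C, gamma) pairs, and so the total cost.
search = RandomizedSearchCV(
    pipeline, param_distributions, n_iter=10, cv=3, random_state=0
)
search.fit(X, y)

print("best parameters:", search.best_params_)
print("best cross-validated accuracy:", search.best_score_)
```

Compared with a full grid, the budget is fixed up front by n_iter, which matters when a single grid search takes days.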

Useful SVM MNIST learning materials

More Repositories

1

awesome-nlp-polish

A curated list of resources dedicated to Natural Language Processing (NLP) in Polish. Models, tools, datasets.
250
star
2

KMLib

Kernel Machine Library - fast GPU SVM in .NET. Implemented kernels on CPU and GPU (Linear, RBF, Chi-Square, Exp Chi-Square). Library includes GPU SVM solvers for sparse problems.
Perl
27
star
3

tensorflow-mnist-convnets

Neural nets for MNIST classification, simple single layer NN, 5 layer FC NN and convolutional neural networks with different architectures
Python
22
star
4

decaptcha

Decoding captcha with a convolutional neural network
Python
16
star
5

CudaDotProd

Different implementations of sparse matrix multiplication. All matrices are in CSR format. The code contains several CUDA kernels for multiplying a sparse matrix by a dense vector and a sparse matrix by another sparse matrix.
C#
16
star
6

primal_svm

Python implementation of fast linear SVM
Python
13
star
7

pytorch_neural_networks

🔥 Pytorch neural network tutorial. Build: feedforward, convolutional, recurrent/LSTM neural network.
Python
13
star
8

seq2seq-attention-pytorch-lightning

Pytorch-Lightning Seq2seq model with the use of recurrent neural network
Python
10
star
9

pyKMLib

Python SVM with CUDA support.
Python
7
star
10

numbers_recognition

Convolutional network for number recognition
Python
5
star
11

HbaseStargate.net

.net client library for stargate (rest) interface to hbase
C#
5
star
12

blipface

WPF interface for polish twitter clone http://blip.pl , more on http://blipface.pl
C#
5
star
13

Pandas_Worldbank_GDP

World Bank GDP analysis of 10 European countries with some plots
Python
5
star
14

scikit-learn-tutorial

Scikit-learn tutorial for beginners. How to perform classification and regression. How to measure machine learning model performance: accuracy, precision, recall, ROC.
Python
5
star
15

MorfeuszNet

.NET wrapper for the Polish morphological analyzer from http://nlp.ipipan.waw.pl/~wolinski/morfeusz/
C#
4
star
16

numpy-tutorial

Python numpy tutorial based on stanford cs231
Python
3
star
17

numpy-pandas-tutorial

Introduction to numpy and pandas for data visualization in python.
Python
3
star
18

pytorch_fundamentals

Set of simple PyTorch scripts for better understanding of PyTorch fundamentals such as tensors, loss functions, modules, etc.
Python
3
star
19

image_convolution_example

Python
2
star
20

winmole

Windows app launcher
C#
1
star
21

hdf5_python_playground

HDF5 and h5py basic tutorial
Python
1
star
22

Matplotlib_examples

Introduction to matplotlib with examples
Python
1
star
23

char_rnn_for_learning

This is a repo which contains working scripts with char RNN implementations.
Python
1
star