  • Stars: 363
  • Rank: 116,700 (Top 3%)
  • Language: Python
  • License: MIT License
  • Created about 5 years ago
  • Updated over 2 years ago

Run PyTorch models in the browser using ONNX.js

Run PyTorch models in the browser with JavaScript by first converting your PyTorch model into the ONNX format and then loading that ONNX model in your website or app using ONNX.js. In the video tutorial below, I take you through this process using the demo example of a handwritten digit recognition model trained on the MNIST dataset.
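Before the browser can run the converted model, a drawing has to be turned into the input tensor the model expects. The Python sketch below mirrors that step under stated assumptions: the model takes one 28×28 grayscale image as a float tensor of shape [1, 1, 28, 28] (NCHW), as in PyTorch's MNIST example, and inputs are normalized with that example's mean (0.1307) and standard deviation (0.3081). The helper name `to_model_input` is hypothetical; in the demo itself, the equivalent step would be done in JavaScript.

```python
from typing import List

# Assumption: normalization constants from PyTorch's MNIST example.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081
SIDE = 28  # MNIST images are 28x28 pixels


def to_model_input(pixels: List[List[int]]) -> List[float]:
    """Flatten a 28x28 grayscale image (values 0-255) into a normalized buffer.

    With a batch size of 1 and a single channel, NCHW order is plain
    row-major order, so the buffer holds 1 * 1 * 28 * 28 = 784 floats.
    """
    if len(pixels) != SIDE or any(len(row) != SIDE for row in pixels):
        raise ValueError("expected a 28x28 image")
    buffer = []
    for row in pixels:
        for value in row:
            scaled = value / 255.0  # 0..255 -> 0..1, like torchvision's ToTensor
            buffer.append((scaled - MNIST_MEAN) / MNIST_STD)
    return buffer


if __name__ == "__main__":
    image = [[0] * SIDE for _ in range(SIDE)]  # an all-black canvas
    image[14][14] = 255                        # one white pixel near the center
    data = to_model_input(image)
    print(len(data))  # 784
```

In the browser, those 784 values would then be copied into a 32-bit float array and wrapped in a tensor with dimensions [1, 1, 28, 28] before being passed to the loaded ONNX model.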

Tutorial

https://www.youtube.com/watch?v=Vs730jsRgO8

Live Demo and Code Sandbox

Note: The model used in this demo is not very accurate; it will often misclassify digits. It's meant only as a proof of concept. It's the same model that was used in PyTorch's MNIST example. You can find more accurate image classification models here: Papers With Code - Image Classification

The files in this repo (and a description of what they do)

├── degug_demo
│   ├── debug.html (A debug test to make sure the generated ONNX model works. 
│   │               Uses ONNX.js to load and run the generated ONNX model.)
│   │ 
│   └── onnx_model.onnx (A copy of the generated ONNX model that will be loaded
│                        for debugging.)
│
├── full_demo
│   ├── index.html (The full demo's HTML code.)
│   │ 
│   ├── onnx_model.onnx (A copy of the generated ONNX model. Used by script.js.)
│   │ 
│   ├── script.js (The full demo's JS code. Loads onnx_model.onnx and
│   │              predicts the drawn numbers.)
│   │ 
│   └── style.css (The full demo's CSS.)
│                            
├── convert_to_onnx.py (Converts a trained PyTorch model into an ONNX model.)
│
├── inference_mnist_model.py (The PyTorch model definition. Used by
│                             convert_to_onnx.py to generate the ONNX model.)
│                             
├── inputs_batch_preview.png (A preview of a batch of augmented input data. 
│                             Generated by preview_mnist_dataset.py.)
│
├── onnx_model.onnx (The ONNX model generated by convert_to_onnx.py.)
│
├── preview_mnist_dataset.py (For testing out different types of data augmentation.)
│
├── pytorch_model.pt (The trained PyTorch model parameters. Generated by
│                     train_mnist_model.py and used by convert_to_onnx.py to
│                     generate the ONNX model.)
│
└── train_mnist_model.py (Trains the PyTorch model and saves the trained
                          parameters as pytorch_model.pt.)
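After running the model, script.js still has to turn the model's output into a predicted digit. Assuming the model returns ten scores, one per digit 0 to 9 (raw logits, or log-probabilities such as those produced by `log_softmax` in PyTorch's MNIST example), a minimal version of that step takes the index of the highest score as the prediction and applies a softmax to get a probability. The function names `softmax` and `predict_digit` are hypothetical, not taken from script.js.

```python
import math
from typing import List, Tuple


def softmax(scores: List[float]) -> List[float]:
    """Convert raw scores into probabilities that sum to 1."""
    top = max(scores)  # subtract the max before exp() for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_digit(scores: List[float]) -> Tuple[int, float]:
    """Return (digit, probability) for a list of 10 per-digit scores."""
    if len(scores) != 10:
        raise ValueError("expected one score per digit 0-9")
    probs = softmax(scores)
    digit = max(range(10), key=lambda i: probs[i])  # argmax
    return digit, probs[digit]


if __name__ == "__main__":
    # Hypothetical model output in which digit 7 has the highest score.
    output = [-2.0, -1.5, 0.3, -0.7, -3.1, 0.0, -2.2, 4.1, 1.2, -0.4]
    digit, prob = predict_digit(output)
    print(digit)  # 7
    print(f"probability: {prob:.3f}")
```

Because the softmax of log-probabilities gives back the original probabilities, the same code works whether or not the model's last layer already applies `log_softmax`.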

The benefits of running a model in the browser:

  • Faster inference times with smaller models (no network round trip to a server).
  • Easy to host and scale (only static files).
  • Offline support.
  • User privacy (can keep the data on the device).

The benefits of using a backend server:

  • Faster load times (don't have to download the model).
  • Faster and more consistent inference times with larger models (the server can use GPUs or other accelerators).
  • Model privacy (don't have to share your model if you want to keep it private).
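The load-time trade-off above is largely a question of file size. With weights stored as 32-bit floats, a model costs roughly 4 bytes per parameter to download, before any HTTP compression and ignoring the small amount of graph metadata in the ONNX file. The sketch below does that arithmetic. The layer shapes are an assumption modeled on the current PyTorch MNIST example (two 3×3 convolutions and two fully connected layers) and may not match this repo's inference_mnist_model.py; the helper names are hypothetical.

```python
from typing import List, Tuple

BYTES_PER_FLOAT32 = 4


def conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Conv2d weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    """Linear weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features


def estimate_size_mb(param_counts: List[int]) -> Tuple[int, float]:
    """Return (total parameters, approximate size in megabytes as float32)."""
    total = sum(param_counts)
    return total, total * BYTES_PER_FLOAT32 / 1e6


if __name__ == "__main__":
    # Hypothetical layer shapes, for illustration only.
    layers = [
        conv_params(1, 32, 3),
        conv_params(32, 64, 3),
        linear_params(9216, 128),
        linear_params(128, 10),
    ]
    total, mb = estimate_size_mb(layers)
    print(f"{total:,} parameters, about {mb:.1f} MB as float32")
    # 1,199,882 parameters, about 4.8 MB as float32
```

Almost all of the parameters in this example sit in the first fully connected layer, so shrinking that layer is the most direct way to cut the download for in-browser use.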

License

MIT

More Repositories

  1. thumbnail-rating-bar-for-youtube (JavaScript, 249 stars)
     A Chrome and Firefox extension for YouTube that adds a rating bar (likes/dislikes ratio) to the bottom of every thumbnail.

  2. rule-30-and-game-of-life (Python, 163 stars)
     Generates a 2D animation of Rule 30 (or other rules) being fed into Conway's Game of Life.

  3. hammerspoon-config (Lua, 40 stars)
     My Hammerspoon configuration files.

  4. pytorch-hooks-tutorial (Python, 29 stars)
     Examples of using PyTorch hooks, as covered in my YouTube tutorial video.

  5. pycairo-animations (Python, 21 stars)
     Pycairo Animation Library

  6. jetbrains-godot-theme (GDScript, 16 stars)
     Godot Theme for JetBrains IDEs (PyCharm, IntelliJ, etc.)

  7. jetbrains-cyberpunk-theme (Python, 16 stars)
     A theme for JetBrains IDEs (PyCharm, IntelliJ, etc.) inspired by Cyberpunk 2077.

  8. times-table-animation (Python, 14 stars)
     A script for generating "Times Table" animations using Python, Numpy, and PyCairo.

  9. softmax-logit-paths (Python, 7 stars)
     Plots how the logit values that are passed into the softmax function change over time as the model is trained.

  10. my-setup (Shell, 4 stars)
      All the apps and settings I use.

  11. nim-opengl-tutorials-by-the-cherno (Nim, 4 stars)
      The Cherno's OpenGL Tutorials written in Nim.

  12. cleantube (HTML, 2 stars)
      A Chrome and Firefox extension that declutters YouTube.

  13. timestream (2 stars)
      Timestream Timer App

  14. udemy-faster-autoplay (Python, 2 stars)
      A Chrome and Firefox extension for Udemy that removes the 3-second delay before autoplaying the next video in a playlist.

  15. object-detection-website-template (JavaScript, 1 star)
      A website template for running object detection in the browser.