  • Stars: 546
  • Rank: 78,685 (Top 2%)
  • Language: Jupyter Notebook
  • License: MIT License
  • Created: over 2 years ago
  • Updated: 6 months ago


Repository Details

The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.

Get started with JAX! 💻

The goal of this repo is to make it easier to get started with JAX, Flax, and Haiku!

The JAX ecosystem is becoming an increasingly popular alternative to PyTorch and TensorFlow. 😎





Note: I'm only going to recommend content here that I've personally analyzed and found useful. If you want a comprehensive list, check out the awesome-jax repo.

Table of Contents

My Machine Learning with JAX Tutorials

Tip on how to use the notebooks: just open the notebook directly in Google Colab (you'll see a button at the top of the Jupyter file that takes you to Colab). This way you avoid having to set up the Python env! (This was especially convenient for me since I'm on Windows, which JAX still doesn't support.)

Tutorial #1: From Zero to Hero

In this video, we start from the basics and then gradually dig into the nitty-gritty details of jit, grad, vmap, and various other idiosyncrasies of JAX.

YouTube Video (Tutorial #1)
Accompanying Jupyter Notebook

JAX from zero to hero!
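If you just want a quick taste before watching, here is a minimal sketch of the three core transforms the tutorial covers (jit, grad, vmap). The toy loss function below is illustrative, not taken from the notebook:

```python
import jax
import jax.numpy as jnp

# A plain Python function on JAX arrays (illustrative toy example).
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

grad_loss = jax.grad(loss)                    # gradient w.r.t. the first argument (w)
fast_grad = jax.jit(grad_loss)                # JIT-compile with XLA for speed
batched = jax.vmap(loss, in_axes=(None, 0))   # vectorize over a batch of x

w = jnp.ones((3,))
x = jnp.arange(6.0).reshape(2, 3)             # batch of 2 examples

print(fast_grad(w, x[0]))                     # compiled gradient for one example
print(batched(w, x))                          # per-example losses, shape (2,)
```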

Tutorial #2: From Hero to HeroPro+

In this video, we learn about all the additional components needed to train ML models (such as NNs) on multiple machines! We'll train a simple MLP model, and we'll even train a model on 8 TPU cores!

YouTube Video (Tutorial #2)
Accompanying Jupyter Notebook

JAX from Hero to HeroPro+!
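As a rough illustration of the multi-device part (not the exact notebook code), jax.pmap replicates a computation across all locally available devices, e.g. the 8 TPU cores on a Colab TPU runtime:

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()      # e.g. 8 on a Colab TPU runtime

# Illustrative per-device computation; the notebook trains a real MLP instead.
def per_device_sum(x):
    return jnp.sum(x ** 2)

# One leading-axis entry per device; pmap runs the function on each device in parallel.
x = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)
print(jax.pmap(per_device_sum)(x))        # shape (n_devices,)
```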

Tutorial #3: Building a Neural Network from Scratch

Watch me code a neural network from scratch in this 3rd video of the JAX tutorial series! 🥳

In this video, I build an MLP and train it as a classifier on MNIST using PyTorch's data loader (although it's trivial to use a more complex dataset) - all this in "pure" JAX (no Flax/Haiku/Optax).

I then do an additional analysis:

  • Visualize MLP's learned weights
  • Visualize embeddings of a batch of images using t-SNE
  • Finally, I analyze whether we have too many dead ReLU neurons in our network

YouTube Video (Tutorial #3)
Accompanying Jupyter Notebook (Note: I'll refactor it soon, but I'll keep a link to the original)

Building a Neural Network from Scratch in pure JAX!
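For orientation, here is a condensed sketch of the pure-JAX pattern the tutorial follows: parameters live in an explicit pytree, and jit/grad/vmap do the rest. Layer sizes and function names are illustrative, not the notebook's exact code:

```python
import jax
import jax.numpy as jnp

def init_mlp(layer_sizes, key):
    # Parameters live in a plain pytree (a list of (W, b) tuples).
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, w_key = jax.random.split(key)
        params.append((jax.random.normal(w_key, (n_out, n_in)) * 0.01,
                       jnp.zeros(n_out)))
    return params

def predict(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)
    w, b = params[-1]
    return w @ x + b  # logits

def loss(params, x, y):
    # Cross-entropy for a single example with integer label y.
    return -jax.nn.log_softmax(predict(params, x))[y]

# vmap over the batch, grad over params, jit the whole update step.
def batched_loss(params, xs, ys):
    return jnp.mean(jax.vmap(loss, in_axes=(None, 0, 0))(params, xs, ys))

@jax.jit
def update(params, xs, ys, lr=0.01):
    grads = jax.grad(batched_loss)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```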


Tutorial #4: Machine Learning with Flax - From Zero to Hero

In this video, I cover everything you need to know to get started with Flax!

We cover init, apply, TrainState, etc., as well as other idiosyncrasies like the mutable and rngs keyword arguments.

YouTube Video (Tutorial #4)
Accompanying Jupyter Notebook

Flax from Zero to Hero!
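As a teaser, here is a minimal sketch of the init/apply/TrainState pattern, assuming flax and optax are installed; the model itself is illustrative, not the one from the notebook:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

# Illustrative model, not the notebook's exact architecture.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        return nn.Dense(10)(x)

model = MLP()
x = jnp.ones((1, 784))
variables = model.init(jax.random.PRNGKey(0), x)   # init: build the parameter pytree
logits = model.apply(variables, x)                 # apply: run the forward pass

# TrainState bundles the params, the apply fn, and the optimizer state.
state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables['params'], tx=optax.sgd(0.01))
```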


Tutorial #5 (coming up): Machine Learning with Haiku - From Zero to Hero

todo

Other useful content

Aside from the official docs, here are some resources that helped me.

Videos

Blogs

Acknowledgements

  • The notebooks were heavily inspired by the official JAX, Flax, and Haiku docs.

Citation

If you find this content useful, please cite the following:

@misc{Gordic2021GetStartedWithJAX,
  author = {Gordić, Aleksa},
  title = {Get started with JAX},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/gordicaleksa/get-started-with-JAX}},
}

Connect With Me

If you'd love to have some more AI-related content in your life 🤓, consider:

License

License: MIT

More Repositories

1. pytorch-GAT (Jupyter Notebook, 2,253 stars): My implementation of the original GAT paper (Veličković et al.). I've additionally included the playground.py file for visualizing the Cora dataset, GAT embeddings, an attention mechanism, and entropy histograms. I've supported both Cora (transductive) and PPI (inductive) examples!
2. pytorch-original-transformer (Jupyter Notebook, 880 stars): My implementation of the original transformer model (Vaswani et al.). I've additionally included the playground.py file for visualizing otherwise seemingly hard concepts. Currently includes IWSLT pretrained models.
3. pytorch-GANs (Python, 366 stars): My implementation of various GAN (generative adversarial network) architectures like vanilla GAN (Goodfellow et al.), cGAN (Mirza et al.), DCGAN (Radford et al.), etc.
4. Open-NLLB (Python, 364 stars): Effort to open-source NLLB checkpoints.
5. pytorch-deepdream (Jupyter Notebook, 352 stars): PyTorch implementation of the DeepDream algorithm (Mordvintsev et al.). Additionally, I've included playground.py to help you better understand the basic concepts behind the algo.
6. pytorch-neural-style-transfer (Python, 343 stars): Reconstruction of the original paper on neural style transfer (Gatys et al.). I've additionally included reconstruction scripts which allow you to reconstruct only the content or the style of the image - for a better understanding of how NST works.
7. stable_diffusion_playground (Python, 203 stars): Playing around with stable diffusion. Generated images are reproducible because I save the metadata and latent information. You can generate and then later interpolate between the images of your choice.
8. pytorch-learn-reinforcement-learning (Python, 140 stars): A collection of various RL algorithms like policy gradients, DQN, and PPO. The goal of this repo is to make it a go-to resource for learning about RL: how to visualize, debug, and solve RL problems. I've additionally included playground.py for learning more about OpenAI gym, etc.
9. pytorch-neural-style-transfer-johnson (Python, 110 stars): Reconstruction of fast neural style transfer (Johnson et al.). Some portions of the paper have been improved by follow-up work, like instance normalization, etc. Check out transformer_net.py's header for details.
10. serbian-llm-eval (Python, 81 stars): Serbian LLM Eval.
11. pytorch-naive-video-neural-style-transfer (Python, 73 stars): Create naive (no temporal loss) NST for videos with person segmentation. Just place your videos in data/, run it, and you get your stylized and segmented videos.
12. OpenGemini (17 stars): Effort to open-source the 10.5 trillion parameter Gemini model.
13. gordicaleksa (12 stars): GitHub's new feature: a repo with the same name as your GitHub username, initialized with a README.md, will show up on your landing page!
14. digital-image-processing (MATLAB, 7 stars): Projects I did for the Digital Image Processing course at my university.
15. streamlit_playground (Python, 4 stars): Simple Streamlit app.
16. Open-NLLB-stopes (Python, 3 stars): A library for preparing data for machine translation research (monolingual preprocessing, bitext mining, etc.) for the Open-NLLB effort.
17. MachineLearningMicrosoftPetnica (C++, 3 stars): Problems I solved for the Microsoft ML summer camp in Petnica, Serbia.
18. competitive_programming (C++, 2 stars): Contains algorithms and snippets I found useful when solving problems for TopCoder, Google Code Jam, etc.
19. slovenian-llm-eval (Python, 2 stars): Slovenian LLM Eval.
20. MicrosoftBubbleCup2018 (C++, 1 star): My solutions for Bubble Cup 2018.
21. .dotfiles (Shell, 1 star): Configuration files for my vim editor, bash, etc.
22. GoogleCodeJam2018 (C++, 1 star): My solutions for Google Code Jam 2018.