  • Stars: 119
  • Rank: 298,034 (Top 6 %)
  • Language: Python
  • License: Apache License 2.0
  • Created about 4 years ago
  • Updated almost 2 years ago

Repository Details

EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax

EfficientNet JAX - Flax Linen and Objax

Acknowledgements

Verification of training code was made possible with Cloud TPUs via Google's TPU Research Cloud (TRC) (https://www.tensorflow.org/tfrc)

Intro

This is very much a giant steaming work in progress. JAX, jaxlib, and the NN libraries I'm using are shifting week to week.

This code base currently supports Flax Linen and Objax variants of the models.

This is essentially an adaptation of my PyTorch EfficientNet generator code (https://github.com/rwightman/gen-efficientnet-pytorch and also found in https://github.com/rwightman/pytorch-image-models) to JAX.

I started this to

  • learn JAX by working with familiar code / models as a starting point,
  • figure out which JAX modelling interface libraries ('frameworks') I liked,
  • compare the training / inference runtime traits of non-trivial models across combinations of PyTorch, JAX, GPU and TPU in order to drive cost optimizations for scaling up of future projects

Where are we at:

  • Training works on single node, multi-GPU and TPU v3-8 for Flax Linen variants w/ Tensorflow Datasets based pipeline
  • The Objax and Flax Linen (nn.compact) variants of models are working (for inference)
  • Weights are ported from PyTorch (my timm training) and Tensorflow (original paper author releases) and are organized in a zoo of sorts (borrowed PyTorch code)
  • Tensorflow and PyTorch data pipeline based validation scripts work with models and weights. For PT pipeline with PT models and TF pipeline with TF models the results are pretty much exact.

TODO:

  • Fix model weight inits (working for Flax Linen variants)
  • Fix dropout/drop path impl and other training specifics (verified for Flax Linen variants)
  • Add more instructions / help in the README on how to get an optimal environment with JAX up and running (with GPU support)
  • Add basic training code. The main point of this is to scale up training.
  • Add a more advanced data augmentation pipeline
  • Training on lots of GPUs
  • Training on lots of TPUs

Some odd things:

  • Objax layers are reimplemented to make my initial work easier, scratch some itches, and make them more consistent with PyTorch (because why not?)
  • Flax Linen layers are by default fairly consistent with Tensorflow (left as is)
  • I use wrappers around Flax Linen layers for some argument consistency and reduced visual noise (no redundant tuples)
  • I made a 'LIKE' padding mode, sort of like 'SAME' but different, hence the name. It calculates symmetric padding for PyTorch models.
  • Models with Tensorflow 'SAME' padding and TF origin weights are prefixed with tf_. Models with PyTorch trained weights and symmetric PyTorch style padding ('LIKE' here) are prefixed with pt_
  • I use pt and tf to refer to PyTorch and Tensorflow for both the models and environments. These two do not need to be used together. pt models with 'LIKE' padding will work fine running in a Tensorflow based environment and vice versa. I did this to show the full flexibility here, that one can use JAX models with PyTorch data pipelines and datasets or with Tensorflow based data pipelines and TFDS.
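The difference between the two padding modes and the tf_ / pt_ naming can be illustrated with a small sketch. It assumes 'LIKE' follows the usual PyTorch convention of one symmetric pad amount per side, and that 'SAME' follows Tensorflow's input-size-dependent split. The function names here are hypothetical and are not taken from this repo.

```python
import math


def like_padding(kernel_size, stride=1, dilation=1):
    """Symmetric per-side padding in the PyTorch style (assumed formula)."""
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2


def same_padding(in_size, kernel_size, stride=1, dilation=1):
    """Tensorflow 'SAME' padding as (before, after) for one spatial dim."""
    out_size = math.ceil(in_size / stride)
    effective_kernel = dilation * (kernel_size - 1) + 1
    total = max((out_size - 1) * stride + effective_kernel - in_size, 0)
    before = total // 2
    return before, total - before


def padding_mode_for(model_name):
    """Map the tf_ / pt_ config-name prefix to its padding mode."""
    if model_name.startswith("tf_"):
        return "SAME"
    if model_name.startswith("pt_"):
        return "LIKE"
    raise ValueError(f"unknown model prefix: {model_name}")


if __name__ == "__main__":
    # 3x3 conv, stride 2, over a 224-pixel dimension
    print(like_padding(3, stride=2))       # one value, applied to both sides
    print(same_padding(224, 3, stride=2))  # (before, after) can differ
    print(padding_mode_for("tf_efficientnet_b0"))
```

For a 3x3, stride-2 convolution over a 224-pixel dimension, this sketch gives a pad of 1 on each side for 'LIKE' and (0, 1) for 'SAME', which is why weights trained under one convention are paired with the matching padding mode.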

Models

Supported models and their papers

Models by their config name w/ valid pretrained weights that should be working here:

pt_mnasnet_100
pt_semnasnet_100
pt_mobilenetv2_100
pt_mobilenetv2_110d
pt_mobilenetv2_120d
pt_mobilenetv2_140
pt_fbnetc_100
pt_spnasnet_100
pt_efficientnet_b0
pt_efficientnet_b1
pt_efficientnet_b2
pt_efficientnet_b3
tf_efficientnet_b0
tf_efficientnet_b1
tf_efficientnet_b2
tf_efficientnet_b3
tf_efficientnet_b4
tf_efficientnet_b5
tf_efficientnet_b6
tf_efficientnet_b7
tf_efficientnet_b8
tf_efficientnet_b0_ap
tf_efficientnet_b1_ap
tf_efficientnet_b2_ap
tf_efficientnet_b3_ap
tf_efficientnet_b4_ap
tf_efficientnet_b5_ap
tf_efficientnet_b6_ap
tf_efficientnet_b7_ap
tf_efficientnet_b8_ap
tf_efficientnet_b0_ns
tf_efficientnet_b1_ns
tf_efficientnet_b2_ns
tf_efficientnet_b3_ns
tf_efficientnet_b4_ns
tf_efficientnet_b5_ns
tf_efficientnet_b6_ns
tf_efficientnet_b7_ns
tf_efficientnet_l2_ns_475
tf_efficientnet_l2_ns
pt_efficientnet_es
pt_efficientnet_em
tf_efficientnet_es
tf_efficientnet_em
tf_efficientnet_el
pt_efficientnet_lite0
tf_efficientnet_lite0
tf_efficientnet_lite1
tf_efficientnet_lite2
tf_efficientnet_lite3
tf_efficientnet_lite4
pt_mixnet_s
pt_mixnet_m
pt_mixnet_l
pt_mixnet_xl
tf_mixnet_s
tf_mixnet_m
tf_mixnet_l
pt_mobilenetv3_large_100
tf_mobilenetv3_large_075
tf_mobilenetv3_large_100
tf_mobilenetv3_large_minimal_100
tf_mobilenetv3_small_075
tf_mobilenetv3_small_100
tf_mobilenetv3_small_minimal_100

Environment

Working with JAX, I've found the best approach for having a working GPU compatible environment that performs well is to use Docker containers based on the latest NVIDIA NGC releases. I've found it challenging or flaky getting local conda/pip venvs or Tensorflow docker containers working well with good GPU performance, proper NCCL distributed support, etc. I use a CPU JAX install in a conda env for dev/debugging.
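A quick way to confirm which backend an environment actually ended up with is to ask JAX what devices it can see. The snippet below is a minimal sketch using only public JAX calls; the helper name backend_summary is hypothetical and not part of this repo.

```python
import jax


def backend_summary():
    """Collect basic facts about the devices JAX can see."""
    devices = jax.devices()
    return {
        "backend": jax.default_backend(),  # platform name, e.g. 'cpu'
        "device_count": jax.device_count(),  # devices across all hosts
        "local_device_count": jax.local_device_count(),  # devices on this host
        "device_kinds": sorted({d.device_kind for d in devices}),
    }


if __name__ == "__main__":
    for key, value in backend_summary().items():
        print(f"{key}: {value}")
```

If a GPU container reports a CPU backend here, the jaxlib install most likely does not match the CUDA setup, which is worth checking before running validation or training.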

Dockerfiles

There are several container definitions in docker/. They use NGC containers as their parent image, so you'll need to be set up to pull NGC containers: https://www.nvidia.com/en-us/gpu-cloud/containers/ . I'm currently using recent NGC containers w/ CUDA 11.1 support. The host system will need a very recent NVIDIA driver to support this, but doesn't need a matching CUDA 11.1 / cuDNN 8 install.

Current dockerfiles:

  • pt_git.Dockerfile - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax.
  • pt_pip.Dockerfile - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax.
  • tf_git.Dockerfile - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax.
  • tf_pip.Dockerfile - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax.

The 'git' containers take some time to build jaxlib. They pull the master branches of all respective repos, so they are at the bleeding edge but more likely to have the regressions or incompatibilities that go with that. The pip install containers are quite a bit quicker to get up and running and are based on the latest pip versions of all repos.

Docker Usage (GPU)

  1. Make sure you have a recent version of docker and the NVIDIA Container Toolkit set up (https://github.com/NVIDIA/nvidia-docker)
  2. Build the container: docker build -f docker/tf_pip.Dockerfile -t jax_tf_pip .
  3. Run the container, ideally mapping jeffnet and datasets (ImageNet) into the container
    • For tf containers: docker run --gpus all -it -v /path/to/tfds/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_tf_pip
    • For pt containers: docker run --gpus all -it -v /path/to/imagenet/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_pt_pip
  4. Model validation w/ pretrained weights (once inside the running container):
    • For tf, in /workspace/jeffnet: python tf_linen_validate.py /data/ --model tf_efficientnet_b0_ns
    • For pt, in /workspace/jeffnet: python pt_objax_validate.py /data/validation --model pt_efficientnet_b0
  5. Training (within the container)
    • In /workspace/jeffnet: python tf_linen_train.py --config train_configs/tf_efficientnet_b0-gpu_24gb_x2.py --config.data_dir /data

TPU

I've successfully used this codebase on TPU VM environments as is. Any of the tpu_x8 training configs should work out of the box on a v3-8 TPU. I have not tackled training with TPU Pods.
