  • Language: Jupyter Notebook
  • License: Apache License 2.0
  • Created over 2 years ago
  • Updated about 2 years ago


Probing the representations of Vision Transformers.

Probing ViTs

TensorFlow 2.8 | Hugging Face

By Aritra Roy Gosthipaty and Sayak Paul (equal contribution)

In this repository, we provide tools to probe the representations learned by different families of Vision Transformers (supervised pre-training on ImageNet-21k and ImageNet-1k, distillation, and self-supervised pre-training):

  • Original ViT [1]
  • DeiT [2]
  • DINO [3]

We hope these tools will prove useful to the community. Please follow along with this post on keras.io to navigate the repository more easily.

Updates

Self-attention visualization

[Figure: original image | attention maps | attention maps overlaid on the image]

[Video: output-dino.mp4, with a link to the original video source]

[Video: output-dog.mp4, with a link to the original video source]

Supervised salient representations

In the DINO blog post, the authors show a video with the following caption:

The original video is shown on the left. In the middle is a segmentation example generated by a supervised model, and on the right is one generated by DINO.

A screenshot of the video is as follows:

[Figure: screenshot of the video from the DINO blog post]

The attention maps generated by the supervised pre-trained model are noticeably less salient than those generated by DINO, and we observe similar behaviour in our own experiments. The figure below shows the attention heatmaps extracted with a ViT-B16 model pre-trained (supervised) on ImageNet-1k:

[Figure: attention heatmaps for the dinosaur and dog examples]

We used this Colab Notebook to conduct this experiment.
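Per-head heatmaps like these are typically read off the attention that the [CLS] token pays to the image patches in a given layer (DINO [3] visualizes its last layer this way). The sketch below is plain Python so it runs on its own; it assumes one layer's attention weights laid out as (heads, tokens, tokens), with the [CLS] token at index 0 and the patch tokens following in row-major order. The function name is hypothetical, not taken from the repository.

```python
import math
from typing import List

Matrix = List[List[float]]


def cls_attention_heatmaps(attn: List[Matrix]) -> List[Matrix]:
    """Turn one layer's attention weights into one 2-D heatmap per head.

    attn[h][i][j] is the weight that query token i of head h puts on key
    token j. Token 0 is assumed to be [CLS]; tokens 1..N are image patches
    laid out row by row on a square grid.
    """
    heatmaps = []
    for head in attn:
        cls_row = head[0][1:]  # what [CLS] attends to, patches only
        side = math.isqrt(len(cls_row))
        if side * side != len(cls_row):
            raise ValueError("patch tokens do not form a square grid")
        lo, hi = min(cls_row), max(cls_row)
        span = (hi - lo) or 1.0  # guard against a perfectly flat map
        # Min-max normalise to [0, 1] so the map can be rendered as an image.
        norm = [(v - lo) / span for v in cls_row]
        # Cut the flat patch sequence back into grid rows.
        heatmaps.append([norm[r * side:(r + 1) * side] for r in range(side)])
    return heatmaps
```

For a ViT-B16 at 224 x 224 resolution this yields 14 x 14 maps (196 patches), which would then be upsampled by the patch size to overlay them on the input image.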

Hugging Face Spaces

You can now probe the ViTs with your own input images using the following Hugging Face Spaces:

  • Attention Heat Maps
  • Attention Rollout

Visualizing mean attention distances
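Mean attention distance [1, 4] measures how far, in pixels, each head looks: for every query patch, the distances to all key patches are weighted by the attention weights and summed, and the result is averaged over query patches (and, in practice, over a batch of images). The plain-Python sketch below assumes a single head's patch-to-patch attention matrix with the [CLS] token already removed, patches in row-major grid order, and distances measured between patch positions scaled by the patch size. The function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def patch_distances(grid: int, patch_size: int) -> Matrix:
    """Pixel distance between the centres of every pair of patches."""
    coords = [(r, c) for r in range(grid) for c in range(grid)]  # row-major
    return [[patch_size * math.dist(p, q) for q in coords] for p in coords]


def mean_attention_distance(attn: Matrix, grid: int, patch_size: int) -> float:
    """Attention-weighted pixel distance for one head, averaged over queries.

    attn[i][j] is the weight patch i puts on patch j ([CLS] excluded).
    """
    dist = patch_distances(grid, patch_size)
    # For each query patch: sum of attention weight times distance to each key.
    per_query = [
        sum(w * d for w, d in zip(attn_row, dist_row))
        for attn_row, dist_row in zip(attn, dist)
    ]
    return sum(per_query) / len(per_query)
```

With ViT-B16 at 224 x 224 you would call it with grid=14 and patch_size=16 for every head of every layer; plotting the per-head values against layer depth gives the kind of view used in [1] and [4] to compare local and global attention.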

Methods

We don't propose any novel methods for probing the representations of neural networks. Instead, we take existing works and implement them in TensorFlow.

  • Mean attention distance [1, 4]
  • Attention Rollout [5]
  • Visualization of the learned projection filters [1]
  • Visualization of the learned positional embeddings
  • Attention maps from individual attention heads [3]
  • Generation of attention heatmaps from videos [3]
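Attention Rollout [5] propagates attention through the layers: each layer's attention is fused over heads, mixed with the identity matrix to account for residual connections (0.5 A + 0.5 I), renormalized so each row sums to 1, and the resulting matrices are multiplied from the first layer to the last. The sketch below follows that recipe in plain Python. Averaging over heads is an assumption here (one common fusion choice; taking the max or min over heads are alternatives seen in other implementations), and the function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def attention_rollout(layers: List[List[Matrix]]) -> Matrix:
    """Roll attention out from the first layer to the last.

    layers[l][h][i][j] is the weight that query token i of head h in
    layer l puts on key token j.
    """
    n = len(layers[0][0])
    rollout = [[float(i == j) for j in range(n)] for i in range(n)]  # identity
    for heads in layers:
        # Fuse heads by averaging them.
        fused = [[sum(h[i][j] for h in heads) / len(heads) for j in range(n)]
                 for i in range(n)]
        # Model the residual connection: 0.5 * A + 0.5 * I ...
        mixed = [[0.5 * fused[i][j] + 0.5 * float(i == j) for j in range(n)]
                 for i in range(n)]
        # ... then re-normalise every row so it sums to 1.
        sums = [sum(row) for row in mixed]
        mixed = [[v / s for v in row] for row, s in zip(mixed, sums)]
        # Later layers multiply on the left: A_L @ ... @ A_2 @ A_1.
        rollout = matmul(mixed, rollout)
    return rollout
```

The [CLS] row of the result, restricted to the patch tokens (rollout[0][1:]), can then be reshaped to the patch grid (14 x 14 for ViT-B16 at 224 x 224) to obtain the rollout map.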

Another interesting repository that also visualizes ViTs in PyTorch: https://github.com/jacobgil/vit-explain.
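Of the methods listed above, learned positional embeddings are commonly visualized as in the ViT paper [1]: compute the cosine similarity between the position embedding of one patch and those of all patches, then lay the scores out on the patch grid. A minimal plain-Python sketch, assuming an (N, D) embedding table with the [CLS] entry removed and patches in row-major order; the function names are hypothetical.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def position_similarity_map(pos_emb: Matrix, patch: int) -> Matrix:
    """Similarity of one patch's position embedding to every patch's, as a grid.

    pos_emb[k] is the learned embedding of patch k (row-major, [CLS] removed).
    """
    side = math.isqrt(len(pos_emb))
    if side * side != len(pos_emb):
        raise ValueError("position embeddings do not form a square grid")
    sims = [cosine(pos_emb[patch], e) for e in pos_emb]
    return [sims[r * side:(r + 1) * side] for r in range(side)]
```

Repeating this for every patch index and tiling the resulting grids gives one small similarity map per patch position.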

Notes

We first implemented the above-mentioned architectures in TensorFlow and then populated them with the pre-trained parameters from the official codebases. To validate the ports, we evaluated the implementations on the ImageNet-1k validation set and ensured that they matched the reported top-1 accuracies.
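The validation step described above boils down to computing top-1 accuracy on the ImageNet-1k validation set and comparing it with the figure reported by the original authors. A minimal sketch of that comparison, assuming model logits and integer labels are already available; the function names and the tolerance value are hypothetical, not taken from the repository.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of examples whose highest-scoring class equals the label."""
    if len(logits) != len(labels) or len(labels) == 0:
        raise ValueError("need one logits row per label and at least one example")
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label  # argmax vs. label
        for row, label in zip(logits, labels)
    )
    return correct / len(labels)


def matches_reported(measured: float, reported: float, tol: float = 0.001) -> bool:
    """True if the measured accuracy is within `tol` (absolute) of the reported one."""
    return abs(measured - reported) <= tol
```

Papers usually report accuracy as a percentage, so a reported value would be divided by 100 before being passed to matches_reported.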

We value the spirit of open source, so if you spot any bugs in the code or see scope for improvement, don't hesitate to open an issue or contribute a PR. We'd very much appreciate it.

Navigating through the codebase

Our ViT implementations are in the vit directory. We provide utility notebooks in the notebooks directory, which contains the following:

DeiT-related code lives in its own repository: https://github.com/sayakpaul/deit-tf.

Models

Here are the links to the models populated with the pre-trained parameters:

Training and visualizing with small datasets

Coming soon!

References

[1] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale: https://arxiv.org/abs/2010.11929

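[2] Training data-efficient image transformers & distillation through attention (DeiT): https://arxiv.org/abs/2012.12877

[3] Emerging Properties in Self-Supervised Vision Transformers (DINO): https://arxiv.org/abs/2104.14294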

[4] Do Vision Transformers See Like Convolutional Neural Networks?: https://arxiv.org/abs/2108.08810

[5] Quantifying Attention Flow in Transformers: https://arxiv.org/abs/2005.00928

Acknowledgements
