Language: Python
License: MIT License
Created about 3 years ago; updated almost 3 years ago.

Video Swin Transformer - PyTorch

Video-Swin-Transformer-Pytorch

This repo offers a simple way to use the official implementation of "Video Swin Transformer".

Introduction

Video Swin Transformer was introduced in the paper "Video Swin Transformer", which advocates an inductive bias of locality in video Transformers. This leads to a better speed-accuracy trade-off than previous approaches, which compute self-attention globally even with spatial-temporal factorization. The locality of the proposed video architecture is realized by adapting the Swin Transformer, designed for the image domain, while continuing to leverage the power of pre-trained image models. The approach achieves state-of-the-art accuracy on a broad range of video recognition benchmarks, including action recognition (84.9% top-1 accuracy on Kinetics-400 and 86.1% top-1 accuracy on Kinetics-600 with ~20x less pre-training data and ~3x smaller model size) and temporal modeling (69.6% top-1 accuracy on Something-Something v2).
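The locality argument can be made concrete by counting query-key pairs in one first-stage self-attention layer. The sketch compares global attention over all tokens with attention restricted to non-overlapping 3D windows. It assumes the settings of the `swin_base_patch244_window1677_sthv2` model used later in this README: a 32x224x224 input, `patch_size=(2,4,4)` and `window_size=(16,7,7)`. It ignores attention heads, padding and the shifted-window step, so it is a rough illustration of the scaling rather than a FLOP count of the real model. The function names are illustrative.

```python
from math import prod


def token_grid(input_thw, patch_size):
    """Token grid (T', H', W') produced by a non-overlapping 3D patch embedding."""
    return tuple(d // p for d, p in zip(input_thw, patch_size))


def attention_pairs(grid, window_size=None):
    """Number of query-key pairs scored by one self-attention layer.

    window_size=None means global attention over every token; otherwise each
    token only attends within its own non-overlapping 3D window. Assumes the
    grid is divisible by the window size (the real model pads instead).
    """
    n_tokens = prod(grid)
    if window_size is None:
        return n_tokens ** 2
    tokens_per_window = prod(window_size)
    n_windows = n_tokens // tokens_per_window
    return n_windows * tokens_per_window ** 2


grid = token_grid((32, 224, 224), patch_size=(2, 4, 4))      # (16, 56, 56)
global_pairs = attention_pairs(grid)                          # 2,517,630,976
local_pairs = attention_pairs(grid, window_size=(16, 7, 7))   # 39,337,984
print(grid, global_pairs, local_pairs, global_pairs // local_pairs)
```

Each window has a fixed size, and the number of windows grows linearly with the number of tokens. The windowed cost therefore scales linearly with clip length and resolution, while the global cost scales quadratically. Here the gap is a factor of 64, the number of windows.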

Usage

Installation

$ pip install -r requirements.txt

Prepare

$ git clone https://github.com/haofanwang/video-swin-transformer-pytorch.git
$ cd video-swin-transformer-pytorch
$ mkdir checkpoints && cd checkpoints
$ wget https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_base_patch244_window1677_sthv2.pth
$ cd ..

Other checkpoints can be downloaded from the official Video-Swin-Transformer repository.
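The checkpoint file names appear to encode the architecture settings. For `swin_base_patch244_window1677_sthv2`, `patch244` corresponds to `patch_size=(2,4,4)` and `window1677` to `window_size=(16,7,7)`, the values used with this checkpoint later in this README.

The hypothetical helper `parse_checkpoint_name` sketches how to recover these values. It assumes that the last two digits of the window field are single-digit spatial sizes and that the remaining leading digits are the temporal size. That rule is inferred from this one name, not taken from a documented convention, so always confirm against the matching config file in the official repository.

```python
import os
import re

# Hypothetical naming pattern inferred from "swin_base_patch244_window1677_sthv2";
# it may not hold for every released checkpoint.
NAME_RE = re.compile(
    r"^swin_(?P<variant>[a-z]+)"
    r"_patch(?P<patch>\d{3})"
    r"_window(?P<window>\d{3,})"
    r"_(?P<dataset>\w+)$"
)


def parse_checkpoint_name(path):
    """Return (variant, patch_size, window_size, dataset) parsed from a checkpoint path."""
    stem = os.path.splitext(os.path.basename(path))[0]
    match = NAME_RE.match(stem)
    if match is None:
        raise ValueError(f"unrecognised checkpoint name: {stem}")
    patch = tuple(int(c) for c in match["patch"])  # "244" -> (2, 4, 4)
    digits = match["window"]
    # Assumption: the last two digits are the spatial window (H, W) and the
    # leading digits are the temporal window, e.g. "1677" -> (16, 7, 7).
    window = (int(digits[:-2]), int(digits[-2]), int(digits[-1]))
    return match["variant"], patch, window, match["dataset"]


print(parse_checkpoint_name("./checkpoints/swin_base_patch244_window1677_sthv2.pth"))
# ('base', (2, 4, 4), (16, 7, 7), 'sthv2')
```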

Inference

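import torch
from video_swin_transformer import SwinTransformer3D

# Default configuration; SwinTransformer3D is the backbone only (no classification head)
model = SwinTransformer3D()
model.eval()
print(model)

# [batch_size, channel, temporal_dim, height, width]
dummy_x = torch.rand(1, 3, 32, 224, 224)
with torch.no_grad():
    feat = model(dummy_x)  # backbone feature map, not class logits
print(feat.shape)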

If you want to use the pre-trained checkpoints without relying on the open-mmlab codebase, you can load them as below.

import torch
from collections import OrderedDict
from video_swin_transformer import SwinTransformer3D

# Architecture matching swin_base_patch244_window1677_sthv2 (see the config linked below)
model = SwinTransformer3D(embed_dim=128,
                          depths=[2, 2, 18, 2],
                          num_heads=[4, 8, 16, 32],
                          patch_size=(2, 4, 4),
                          window_size=(16, 7, 7),
                          drop_path_rate=0.4,
                          patch_norm=True)

# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/configs/recognition/swin/swin_base_patch244_window1677_sthv2.py
checkpoint = torch.load('./checkpoints/swin_base_patch244_window1677_sthv2.pth', map_location='cpu')

# Keep only the backbone weights and strip the leading 'backbone.' from each key;
# cls_head weights are dropped because SwinTransformer3D has no head
prefix = 'backbone.'
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    if k.startswith(prefix):
        new_state_dict[k[len(prefix):]] = v

model.load_state_dict(new_state_dict)
model.eval()

dummy_x = torch.rand(1, 3, 32, 224, 224)
with torch.no_grad():
    feat = model(dummy_x)  # backbone feature map, not class logits
print(feat.shape)
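The remapping in the previous example keeps only the keys that belong to the backbone and removes their leading `backbone.` so they match the standalone `SwinTransformer3D`. The pure-Python sketch below isolates that step. The helper name `strip_prefix` and the toy key names are illustrative, and strings stand in for tensors. It matches the prefix with `str.startswith` rather than a substring test, so a key that merely contains `backbone` somewhere else is not sliced by mistake.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="backbone."):
    """Keep entries whose key starts with `prefix` and drop that prefix.

    Works on any mapping, e.g. checkpoint["state_dict"]; keys outside the
    prefix (such as classification-head weights) are discarded.
    """
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


# Toy state dict with illustrative key names and placeholder values
toy = OrderedDict([
    ("backbone.patch_embed.proj.weight", "w0"),
    ("backbone.layers.0.blocks.0.attn.qkv.weight", "w1"),
    ("cls_head.fc_cls.weight", "w2"),
])
print(list(strip_prefix(toy)))
# ['patch_embed.proj.weight', 'layers.0.blocks.0.attn.qkv.weight']
```

With a real checkpoint, `model.load_state_dict(strip_prefix(checkpoint['state_dict']))` would replace the explicit loop.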

Warning: this is an informal implementation, and there may be errors that are difficult to find. Therefore, I strongly recommend that you use the official code base to load the weights.

Inference with the official codebase

$ git clone https://github.com/SwinTransformer/Video-Swin-Transformer.git
$ cp *.py Video-Swin-Transformer
$ cp -r checkpoints Video-Swin-Transformer
$ cd Video-Swin-Transformer

Then, you can load the pre-trained checkpoint.

import torch
import torch.nn as nn
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmaction.models import build_model

config = './configs/recognition/swin/swin_base_patch244_window1677_sthv2.py'
checkpoint = './checkpoints/swin_base_patch244_window1677_sthv2.pth'

cfg = Config.fromfile(config)
model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
load_checkpoint(model, checkpoint, map_location='cpu')
model.eval()

# [batch_size, channel, temporal_dim, height, width]
dummy_x = torch.rand(1, 3, 32, 224, 224)

with torch.no_grad():
    # SwinTransformer3D without cls_head
    backbone = model.backbone

    # [batch_size, hidden_dim, temporal_dim/2, height/32, width/32]
    feat = backbone(dummy_x)

    # alternatively, through the recognizer's extract_feat
    feat = model.extract_feat(dummy_x)

# mean pooling over the temporal and spatial axes
feat = feat.mean(dim=[2, 3, 4])  # [batch_size, hidden_dim]

# project to a feat_dim-dimensional embedding; this layer is randomly
# initialised here and would need training for a downstream task
batch_size, hidden_dim = feat.shape
feat_dim = 512
proj = nn.Linear(hidden_dim, feat_dim, bias=False)

# final output
output = proj(feat)  # [batch_size, feat_dim]
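The output-shape comment in the example above (`[batch_size, hidden_dim, temporal_dim/2, height/32, width/32]`) follows from the architecture. The patch embedding divides each axis by `patch_size`. Each of the three patch-merging steps between the four stages then halves height and width and doubles the channel count, leaving the temporal axis unchanged.

The sketch below computes the expected backbone output shape from those rules. The function name is illustrative. It assumes inputs divisible by the patch size and by the overall spatial stride, since the real model pads otherwise.

```python
def backbone_output_shape(input_shape, patch_size=(2, 4, 4), embed_dim=128, num_stages=4):
    """Expected [B, C, T, H, W] output of a Video Swin backbone (no classification head).

    input_shape: (batch, channels, frames, height, width).
    Patch embedding divides T, H, W by patch_size; each of the (num_stages - 1)
    patch-merging layers halves H and W and doubles the channels, leaving T as is.
    """
    batch, _, frames, height, width = input_shape
    pt, ph, pw = patch_size
    merges = num_stages - 1
    return (
        batch,
        embed_dim * 2 ** merges,
        frames // pt,
        height // ph // 2 ** merges,
        width // pw // 2 ** merges,
    )


# Defaults match the swin_base_patch244_window1677_sthv2 settings used above
print(backbone_output_shape((1, 3, 32, 224, 224)))  # (1, 1024, 16, 7, 7)
```

After mean pooling over the last three axes, `hidden_dim` is therefore 1024 for this configuration.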

Acknowledgement

The code is adapted from the official Video-Swin-Transformer repository. This project is inspired by swin-transformer-pytorch, which provides the simplest code to get started.

Citation

If you find our work useful in your research, please cite:

@article{liu2021video,
  title={Video Swin Transformer},
  author={Liu, Ze and Ning, Jia and Cao, Yue and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Hu, Han},
  journal={arXiv preprint arXiv:2106.13230},
  year={2021}
}

@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
