  • Stars: 180
  • Rank: 211,860 (Top 5 %)
  • Language: Python
  • License: MIT License
  • Created over 4 years ago
  • Updated over 1 year ago

Repository Details

[CVPR'20] Collaborative Distillation for Ultra-Resolution Universal Style Transfer (PyTorch)

Collaborative-Distillation

Official PyTorch code for our CVPR 2020 poster paper "Collaborative Distillation for Ultra-Resolution Universal Style Transfer". We propose a new knowledge distillation method that reduces the number of filters in VGG-19, enabling ultra-resolution universal style transfer on a single 12GB GPU. We focus on model compression rather than new stylization schemes; for stylization, our method builds upon WCT.

One stylized sample of 10240 x 4096 pixels
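To give a rough sense of scale for an image of this size, the following back-of-the-envelope sketch estimates the memory of a single dense tensor at 10240 x 4096 resolution. The float32 storage and the 64-channel feature map (the width of the first convolution in a standard VGG-19) are assumptions made for illustration; the numbers are not measurements from this repository.

```python
# Back-of-the-envelope memory estimate for one ultra-resolution image.
# Assumptions (not taken from the repository): float32 storage, and a
# 64-channel feature map as produced by the first conv of a standard VGG-19.

WIDTH, HEIGHT = 10240, 4096
BYTES_PER_FLOAT32 = 4
GIB = 1024 ** 3


def tensor_gib(channels: int, height: int, width: int) -> float:
    """Size in GiB of a dense float32 tensor of shape (channels, height, width)."""
    return channels * height * width * BYTES_PER_FLOAT32 / GIB


rgb_input = tensor_gib(3, HEIGHT, WIDTH)
conv1_features = tensor_gib(64, HEIGHT, WIDTH)

print(f"RGB input:          {rgb_input:.2f} GiB")
print(f"64-channel feature: {conv1_features:.2f} GiB")
```

Under these assumptions, a single full-resolution 64-channel activation already takes about 10 GiB, which hints at why reducing the number of filters matters on a 12GB GPU.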

Environment

  • OS: Linux (tested on Ubuntu 14.04 and 16.04; most Linux platforms should work. Windows and macOS are not tested.)
  • python==3.6.9 (using conda to manage the environment is suggested)
  • The required libraries are listed in requirements.txt. Install them with pip install -r requirements.txt
  • CUDA (cuDNN is not necessary)

After the installation, download the code:

git clone https://github.com/MingSun-Tse/Collaborative-Distillation.git

Test (style transfer)

Step 1: Prepare images

  • Content images are placed in PytorchWCT/content, with ultra-res contents in PytorchWCT/content/UHD_content. The styles follow the same path layout.
  • Since the ultra-res images can be quite large, we only include two samples in this repo. For more ultra-res contents and styles presented in our paper, please download them from this Google Drive.

Image copyrights: We use the UHD images from this wallpaper website. All copyrights belong to them, and we thank them!

Step 2: Prepare models

  • For original WCT: Download the unpruned models (which are from the official WCT implementation). Unzip and place them under trained_models/original_wct_models.
  • For ultra-resolution WCT: We use our pruned VGG-19. The models are already included in trained_models/wct_se_16x_new (encoders) and trained_models/wct_se_16x_new_sd (decoders).

Step 3: Stylization

  • Under the PytorchWCT folder, please run the following scripts. The stylized results will be saved in PytorchWCT/stylized_results.
# use original VGG-19, normal images
CUDA_VISIBLE_DEVICES=0 python WCT.py --debug --mode original

# use original VGG-19, ultra-res images
CUDA_VISIBLE_DEVICES=0 python WCT.py --debug --mode original --UHD

# use our pruned VGG-19, normal images
CUDA_VISIBLE_DEVICES=0 python WCT.py --debug --mode 16x

# use our pruned VGG-19, ultra-res images
CUDA_VISIBLE_DEVICES=0 python WCT.py --debug --mode 16x --UHD

# If your RAM cannot accommodate some large images, you can change the content and style size via '--content_size' and '--style_size'
CUDA_VISIBLE_DEVICES=0 python WCT.py --debug --mode 16x --UHD --content_size 3000 --style_size 2000

By default, the above scripts test all possible content-style combinations (e.g., 3 contents with 4 styles yield 3x4 = 12 stylized results). To test a specific pair only, say content "green_park-wallpaper-3840x2160.jpg" with style "Vincent_2K.jpg", use the options --picked_content_mark and --picked_style_mark. For example, the following command only chooses the content whose name includes "green_park" and the style whose name includes "Vincent".

CUDA_VISIBLE_DEVICES=0 python WCT.py --debug --mode 16x --UHD --picked_content_mark green_park --picked_style_mark Vincent
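The pairing behaviour described above can be pictured with the following minimal sketch. It is not the code in WCT.py: the helper names `pick` and `pairs` are hypothetical, and every file name other than the two quoted in the text is a placeholder. It only illustrates a Cartesian product over contents and styles, narrowed by substring matching on the file name.

```python
from itertools import product


def pick(names, mark=None):
    """Keep the names that contain `mark`; keep everything when mark is None."""
    return [n for n in names if mark is None or mark in n]


def pairs(contents, styles, content_mark=None, style_mark=None):
    """All (content, style) combinations that survive the substring filters."""
    return list(product(pick(contents, content_mark), pick(styles, style_mark)))


# Placeholder file lists: 3 contents and 4 styles.
contents = ["green_park-wallpaper-3840x2160.jpg", "content_b.jpg", "content_c.jpg"]
styles = ["Vincent_2K.jpg", "style_b.jpg", "style_c.jpg", "style_d.jpg"]

all_pairs = pairs(contents, styles)
picked = pairs(contents, styles, content_mark="green_park", style_mark="Vincent")

print(len(all_pairs))  # 3 x 4 combinations
print(picked)
```

Because matching is by substring in this sketch, a mark that appears in several file names would select all of them, so a mark should be specific enough to identify the intended image.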

Train (model compression)

Step 1: Prepare dataset

Download the MS-COCO 2014 training set and unzip it at path data/COCO/train2014.
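Before starting a long training run, it can help to confirm that the dataset was unzipped where the scripts expect it. The helper below is a hypothetical convenience, not part of this repository; it simply counts image files directly under a folder. The standard MS-COCO 2014 training release contains 82,783 images, so a much smaller count suggests an incomplete download or a wrong path.

```python
from pathlib import Path


def count_images(folder, suffixes=(".jpg", ".jpeg", ".png")):
    """Count image files directly inside `folder`; return 0 if it does not exist."""
    root = Path(folder)
    if not root.is_dir():
        return 0
    return sum(
        1 for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in suffixes
    )


num_images = count_images("data/COCO/train2014")
print(f"Found {num_images} images in data/COCO/train2014")
```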

Step 2: Prepare models

To train the SE (small encoder), we need the original decoder (big decoder, or BD). We trained our own BDs following the WCT paper. You can download them from this Google Drive and put them at path trained_models/our_BD.

Step 3: Train the compressed encoders

Under the root folder, run

CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_se --pretrained_init --screen --stage 5 -p wct_se_stage5
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_se --pretrained_init --screen --stage 4 -p wct_se_stage4
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_se --pretrained_init --screen --stage 3 -p wct_se_stage3
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_se --pretrained_init --screen --stage 2 -p wct_se_stage2
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_se --pretrained_init --screen --stage 1 -p wct_se_stage1
  • The log and trained models will be saved in a newly created project folder under Experiments.
  • --pretrained_init indicates that base models are used for initialization. These are obtained by pruning the filters with the smallest L1-norms (see also 2017-ICLR-Filter Pruning).
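The idea of pruning the filters with the smallest L1-norms can be sketched as follows. This is a toy illustration in plain Python on nested lists, not the pruning code used to produce the base models; the function names are hypothetical, and a real implementation would operate on PyTorch weight tensors.

```python
def l1_norm(weights):
    """Sum of absolute values over a (possibly nested) list of numbers."""
    if isinstance(weights, (int, float)):
        return abs(weights)
    return sum(l1_norm(w) for w in weights)


def keep_largest_l1(filters, num_keep):
    """Indices (in original order) of the `num_keep` filters with the largest L1-norm."""
    norms = [l1_norm(f) for f in filters]
    ranked = sorted(range(len(filters)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:num_keep])


# Toy conv layer: 4 filters, each with 1 input channel and a 2x2 kernel.
filters = [
    [[[0.1, -0.1], [0.0, 0.2]]],    # L1-norm 0.4
    [[[1.0, -2.0], [0.5, 0.5]]],    # L1-norm 4.0
    [[[0.0, 0.0], [0.0, 0.05]]],    # L1-norm 0.05
    [[[-0.7, 0.3], [0.6, -0.4]]],   # L1-norm 2.0
]

kept = keep_largest_l1(filters, num_keep=2)
pruned_layer = [filters[i] for i in kept]
print(kept)  # [1, 3]
```

Note that removing output filters from one layer also removes the matching input channels of the next layer; the sketch leaves that bookkeeping out.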

Step 4: Train the corresponding decoders

CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_sd --pretrained_init --screen --lw_perc 0.01 --stage 5 -p wct_sd_stage5 --SE <SE path>
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_sd --pretrained_init --screen --lw_perc 0.01 --stage 4 -p wct_sd_stage4 --SE <SE path>
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_sd --pretrained_init --screen --lw_perc 0.01 --stage 3 -p wct_sd_stage3 --SE <SE path>
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_sd --pretrained_init --screen --lw_perc 0.01 --stage 2 -p wct_sd_stage2 --SE <SE path>
CUDA_VISIBLE_DEVICES=0 python main.py --mode wct_sd --pretrained_init --screen --lw_perc 0.01 --stage 1 -p wct_sd_stage1 --SE <SE path>
  • <SE path> is to specify the small encoder model trained in Step 3. A path example for stage5 is Experiments/*wct_se_stage5*/weights/*.pth
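Resolving <SE path> by hand for every stage is easy to get wrong, so a small helper can expand the example pattern above and assemble the decoder command. The sketch below is a hypothetical convenience script, not part of this repository. Choosing the most recently modified checkpoint when several match is an assumption; pick the file you actually want if your runs differ.

```python
import glob
import os


def find_se_checkpoint(stage, root="Experiments"):
    """Return the newest .pth under <root>/*wct_se_stage<stage>*/weights, or None."""
    pattern = os.path.join(root, f"*wct_se_stage{stage}*", "weights", "*.pth")
    matches = glob.glob(pattern)
    if not matches:
        return None
    # Assumption: the most recently modified checkpoint is the one to use.
    return max(matches, key=os.path.getmtime)


def decoder_command(stage, se_path):
    """Argument list mirroring the Step 4 command for one stage."""
    return [
        "python", "main.py", "--mode", "wct_sd", "--pretrained_init", "--screen",
        "--lw_perc", "0.01", "--stage", str(stage),
        "-p", f"wct_sd_stage{stage}", "--SE", se_path,
    ]


for stage in (5, 4, 3, 2, 1):
    se_path = find_se_checkpoint(stage)
    if se_path is None:
        print(f"stage {stage}: no small-encoder checkpoint found")
    else:
        print(" ".join(decoder_command(stage, se_path)))
```

The script only prints the commands. To run one, pass the list to `subprocess.run` with `CUDA_VISIBLE_DEVICES` set in the environment, or copy the printed line into a shell.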

Results

Acknowledgments

In this code we refer to the following implementations: PytorchWCT, UniversalStyleTransfer, pytorch-AdaIN, AdaIN-style. Great thanks to them!

Reference

Please cite this paper in your publication if our work helps your research. If you have any questions, feel free to reach out to Huan Wang ([email protected]).

@inproceedings{wang2020collaborative,
  Author = {Wang, Huan and Li, Yijun and Wang, Yuehai and Hu, Haoji and Yang, Ming-Hsuan},
  Title = {Collaborative Distillation for Ultra-Resolution Universal Style Transfer},
  Booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  Year = {2020}
}
