• Stars: 1,225
• Rank: 38,289 (Top 0.8%)
• Language: Python
• License: MIT License
• Created: over 2 years ago
• Updated: over 1 year ago


Repository Details

Unofficial implementation of Palette: Image-to-Image Diffusion Models in PyTorch

Palette: Image-to-Image Diffusion Models

Paper | Project

Brief

This is an unofficial implementation of Palette: Image-to-Image Diffusion Models in PyTorch. It is mainly inherited from its super-resolution counterpart, Image-Super-Resolution-via-Iterative-Refinement. The code template comes from another seed project of mine: distributed-pytorch-template.

Some implementation details, compared with the paper's description:

  • We adapted the U-Net architecture used in Guided-Diffusion, which gives a substantial boost to sample quality.
  • We apply the attention mechanism on low-resolution features (16×16), as in vanilla DDPM.
  • We encode $\gamma$ rather than $t$, as in Palette, and embed it with an affine transformation (see the sketch after this list).
  • We fix the variance $\Sigma_\theta(x_t, t)$ to a constant during inference, as described in Palette.
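
For illustration, here is a minimal sketch of the $\gamma$ conditioning mentioned above, assuming a simple learned affine map; the module and argument names are illustrative, not the repository's exact code:

import torch
import torch.nn as nn

class GammaEmbedding(nn.Module):
    # Embed the cumulative noise level gamma (instead of the timestep t)
    # with a learned affine transformation.
    def __init__(self, channels):
        super().__init__()
        self.affine = nn.Linear(1, channels)

    def forward(self, gamma):
        # gamma: tensor of shape (batch,), values in (0, 1]
        return self.affine(gamma.view(-1, 1))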

Status

Code

  • Diffusion Model Pipeline
  • Train/Test Process
  • Save/Load Training State
  • Logger/Tensorboard
  • Multiple GPU Training (DDP)
  • EMA (exponential moving average of model weights; see the sketch after this list)
  • Metrics (now for FID, IS)
  • Dataset (now for inpainting, uncropping, colorization)
  • Google Colab script 🌟 (now for inpainting)
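
The EMA item above refers to keeping an exponential moving average of the model weights during training. A minimal sketch of the usual update rule (the decay value is an assumption, not necessarily the repository's setting):

def ema_update(ema_model, model, decay=0.999):
    # Blend each EMA parameter toward the current training parameter:
    # ema <- decay * ema + (1 - decay) * param
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.data.mul_(decay).add_(p.data, alpha=1 - decay)

# Initialize once with ema_model = copy.deepcopy(model) (import copy),
# then call ema_update(ema_model, model) after every optimizer step.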

Task

I try to finish the following tasks in order:

  • Inpainting on Celeba-HQ
  • Inpainting on Places2 with 128×128 centering mask

The follow-up experiments are uncertain due to limited time and GPU resources:

  • Uncropping on Places2
  • Colorization on the ImageNet val set

Results

The DDPM model requires significant computational resources, so we have only trained a few example models to validate the ideas in the paper.

Visuals

Celeba-HQ

Results after 200 epochs and 930K iterations, showing the first 100 samples with centering and irregular masks.

[Images: inpainting process on Celeba-HQ sample 02323, centering and irregular masks]

Places2 with 128×128 centering mask

Results after 16 epochs and 660K iterations, showing several picked samples with a centering mask.

[Images: masked inputs and inpainted outputs for four Places365 test samples]

Uncropping on Places2

Results after 8 epochs and 330K iterations, showing several picked uncropping samples.

[Images: uncropping process on two Places365 test samples]

Metrics

Tasks                            Dataset     EMA     FID ↓     IS ↑
Inpainting with centering mask   Celeba-HQ   False   5.7873    3.0705
Inpainting with irregular mask   Celeba-HQ   False   5.4026    3.1221

(FID: lower is better; IS: higher is better)

Usage

Environment

pip install -r requirements.txt

Pre-trained Model

Dataset     Task         Iterations   GPUs×Days×Bs   URL
Celeba-HQ   Inpainting   930K         2×5×3          Google Drive
Places2     Inpainting   660K         4×8×10         Google Drive

Bs indicates the batch size per GPU; the Celeba-HQ model, for example, was trained on 2 GPUs with a total batch size of 2×3 = 6.

Data Prepare

We got most of the datasets from Kaggle, so they may differ slightly from the official versions; you can also download them from the official websites.

We use the default training/evaluation split of these datasets. The file lists we use can be found in Celeba-HQ, Places2.

After preparing your own data, modify the corresponding configuration file to point to it. Take the following as an example:

"which_dataset": {  // import designated dataset using arguments 
    "name": ["data.dataset", "InpaintDataset"], // import Dataset() class
    "args":{ // arguments to initialize dataset
    	"data_root": "your data path",
    	"data_len": -1,
    	"mask_mode": "hybrid"
    } 
},

More options for the dataloader and the validation split can be found in the datasets part of the configuration file.
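
For reference, a ["module", "ClassName"] pair like the one above is usually resolved with a dynamic import. A minimal sketch under that assumption (the helper name init_obj is illustrative, not necessarily the repository's exact function):

import importlib

def init_obj(name, args):
    # name: ["data.dataset", "InpaintDataset"]; args: kwargs for the class
    module_name, class_name = name
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**args)

# dataset = init_obj(["data.dataset", "InpaintDataset"],
#                    {"data_root": "your data path", "data_len": -1, "mask_mode": "hybrid"})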

Training/Resume Training

  1. Download the checkpoints from the given links.
  2. Set resume_state in the configuration file to the directory of the previous checkpoint. This directory contains the training state and the saved model; take the following as an example:
"path": { //set every part file path
	"resume_state": "experiments/inpainting_celebahq_220426_150122/checkpoint/100" 
},
  3. Set your network label in the load_everything function of model.py; the default is Network. With the settings above, the training state and the model weights will be loaded from 100.state and 100_Network.pth respectively.
netG_label = self.netG.__class__.__name__  # e.g. "Network"
self.load_network(network=self.netG, network_label=netG_label, strict=False)  # loads 100_Network.pth
  4. Run the script:
python run.py -p train -c config/inpainting_celebahq.json

We tested the U-Net backbones used in SR3 and Guided Diffusion; the Guided Diffusion one shows more robust performance in our current experiments. More options for the backbone, loss, and metrics can be found in the which_networks part of the configuration file.

Test

  1. Modify the configuration file to point to your data, following the steps in the Data Prepare section.
  2. Set your model path, following the steps in the Training/Resume Training section.
  3. Run the script:
python run.py -p test -c config/inpainting_celebahq.json

Evaluation

  1. Create two folders, one for ground-truth images and one for sampled images; file names must correspond between the two folders.

  2. Run the script:

python eval.py -s [ground image path] -d [sample image path]
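
For example, with hypothetical folder names (point these at wherever your ground-truth and sampled images actually live):

python eval.py -s ./ground_truth -d ./samples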

Acknowledgements

Our work is based on the following theoretical works:

  • Palette: Image-to-Image Diffusion Models
  • Denoising Diffusion Probabilistic Models (DDPM)
  • Guided-Diffusion

and we benefit a lot from the following projects:

  • Image-Super-Resolution-via-Iterative-Refinement
  • distributed-pytorch-template

More Repositories

  1. Image-Super-Resolution-via-Iterative-Refinement
     Unofficial implementation of Image Super-Resolution via Iterative Refinement in PyTorch
     Python · 3,193 stars

  2. distributed-pytorch-template
     A seed project for distributed PyTorch training, built to help you customize your network quickly
     Python · 104 stars

  3. Image-Zooming-Using-Directional-Cubic-Convolution-Interpolation
     Unofficial Python implementation of Image Zooming Using Directional Cubic Convolution Interpolation (DCC) with NumPy
     Python · 12 stars

  4. Cminus-Compiler
     【Yacc】Cminus compiler
     C++ · 6 stars

  5. A-Demo-for-Image-Inpainting-by-React
     A demo for an image inpainting project built with Flask and React
     Python · 6 stars

  6. face-mask-segmentation
     【Tools】Face recognition and attribute segmentation using Python, dlib, and One Millisecond Face Alignment with an Ensemble of Regression Trees (http://www.cv-foundation.org/openaccess/content_cvpr_2014/papers/Kazemi_One_Millisecond_Face_2014_CVPR_paper.pdf)
     Python · 4 stars

  7. Collaborative-Drawboard
     【SpringBoot】A collaborative drawing board using Fabric and Socket.IO communication, with a database platform
     CSS · 4 stars

  8. MobileClass
     【Java】Mobile class system
     HTML · 3 stars

  9. Face-Mask-Segmentation
     【Tools】Face recognition and attribute segmentation using Python, dlib, and One Millisecond Face Alignment with an Ensemble of Regression Trees (http://www.cv-foundation.org/openaccess/content_cvpr_2014/papers/Kazemi_One_Millisecond_Face_2014_CVPR_paper.pdf)
     Python · 3 stars

  10. Janspiry
      Processing · 2 stars

  11. MNIST-PyTorch
      【Pytorch】LeNet-5 implementation for image classification on the MNIST dataset in PyTorch
      Jupyter Notebook · 2 stars

  12. multivariate_normal
      scipy.stats.multivariate_normal.pdf() implemented in C++ with OpenCV
      C++ · 2 stars

  13. ProductWeb
      【HTML】Demo for product promotion
      JavaScript · 2 stars

  14. AssetInvestment
      【Python】Application of reinforcement learning to portfolio management
      Python · 1 star

  15. Shop-System
      【Python】A simple shop where customers can select goods and view their orders, built with Django
      Python · 1 star

  16. Python-Game
      【Python】A right- and left-brain coordination flying game using PyGame
      Python · 1 star