  • Stars: 708
  • Rank: 63,520 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created: almost 3 years ago
  • Updated: almost 2 years ago

Repository Details

v-diffusion-pytorch

v objective diffusion inference code for PyTorch, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman).

The models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process, allowing the models to generate samples from the learned data distributions starting from random noise. The models are also trained on continuous timesteps. They use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI). Guided diffusion sampling scripts (https://arxiv.org/abs/2105.05233) are included, specifically CLIP guided diffusion. This repo also includes a diffusion model conditioned on CLIP text embeddings that supports classifier-free guidance (https://openreview.net/pdf?id=qw8AKxfYbI), similar to GLIDE (https://arxiv.org/abs/2112.10741). Sampling methods include DDPM, DDIM (https://arxiv.org/abs/2010.02502), and PRK/PLMS (https://openreview.net/forum?id=PlKWVd2yBkY).
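
To make the 'v' objective and classifier-free guidance described above concrete, here is a minimal sketch. It assumes the common continuous-time convention alpha_t = cos(t*pi/2), sigma_t = sin(t*pi/2); the schedule and function names here are illustrative assumptions, not taken from this repo's code:

import math

def t_to_alpha_sigma(t):
    """Signal and noise levels for a timestep t in [0, 1] (assumed schedule)."""
    return math.cos(t * math.pi / 2), math.sin(t * math.pi / 2)

def v_to_pred(z_t, v, t):
    """Recover the predicted clean image and noise from a model output v.

    Forward process: z_t = alpha * x + sigma * eps; training target:
    v = alpha * eps - sigma * x. Since alpha**2 + sigma**2 == 1, both x and
    eps can be solved for exactly.
    """
    alpha, sigma = t_to_alpha_sigma(t)
    pred_x0 = alpha * z_t - sigma * v
    pred_eps = sigma * z_t + alpha * v
    return pred_x0, pred_eps

def cfg_combine(v_uncond, v_cond, guidance_scale):
    """Classifier-free guidance: push the output toward the conditional one."""
    return v_uncond + guidance_scale * (v_cond - v_uncond)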

Thank you to stability.ai for compute to train these models!

Installation

pip install v-diffusion-pytorch

or clone the repository and run pip install -e . in its root (see the example below).
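
For the editable install, something like the following should work (the repository URL is assumed from the author's GitHub handle):

git clone https://github.com/crowsonkb/v-diffusion-pytorch
cd v-diffusion-pytorch
pip install -e .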

Model checkpoints (a hash-verification sketch follows the list):

  • CC12M_1 CFG 256x256, SHA-256 4fc95ee1b3205a3f7422a07746383776e1dbc367eaf06a5b658ad351e77b7bda

A 602M parameter CLIP conditioned model trained on Conceptual 12M for 3.1M steps and then fine-tuned for classifier-free guidance for 250K additional steps. This is the recommended model to use.

  • CC12M_1 256x256, SHA-256 63946d1f6a1cb54b823df818c305d90a9c26611e594b5f208795864d5efe0d1f

As above, before CFG fine-tuning. The model from the original release of this repo.

  • YFCC_1 512x512, SHA-256 a1c0f6baaf89cb4c461f691c2505e451ff1f9524744ce15332b7987cc6e3f0c8

A 481M parameter unconditional model trained on a 33 million image original resolution subset of Yahoo Flickr Creative Commons 100 Million.

  • YFCC_2 512x512, SHA-256 69ad4e534feaaebfd4ccefbf03853d5834231ae1b5402b9d2c3e2b331de27907

A 968M parameter unconditional model trained on a 33 million image original resolution subset of Yahoo Flickr Creative Commons 100 Million.

The repository also contains PyTorch ports of the four models from v-diffusion-jax (danbooru_128, imagenet_128, wikiart_128, wikiart_256):

  • Danbooru SFW 128x128, SHA-256 1728940d3531504246dbdc75748205fd8a24238a17e90feb82a64d7c8078c449

  • ImageNet 128x128, SHA-256 cac117cd0ed80390b2ae7f3d48bf226fd8ee0799d3262c13439517da7c214a67

  • WikiArt 128x128, SHA-256 b3ca8d0cf8bd47dcbf92863d0ab6e90e5be3999ab176b294c093431abdce19c1

  • WikiArt 256x256, SHA-256 da45c38aa31cd0d2680d29a3aaf2f50537a4146d80bba2ca3e7a18d227d9b627
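
To check a download against the SHA-256 hashes above, a minimal sketch (the checkpoint path below is illustrative):

import hashlib

def sha256sum(path, chunk_size=1 << 20):
    """Stream the file in chunks so large checkpoints don't fill memory."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# Expected hash for CC12M_1 CFG 256x256, from the list above.
expected = '4fc95ee1b3205a3f7422a07746383776e1dbc367eaf06a5b658ad351e77b7bda'
assert sha256sum('checkpoints/cc12m_1_cfg.pth') == expected  # path is illustrative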

Sampling

Example

If the model checkpoint for cc12m_1_cfg is stored in checkpoints/, the following will generate four images:

./cfg_sample.py "the rise of consciousness":5 -n 4 -bs 4 --seed 0

If it is stored elsewhere, specify the path to the checkpoint with --checkpoint.

Colab

A cc12m_1_cfg Colab notebook (a simplified version of cfg_sample.py) is available and can be used for free.

CFG sampling (best, but only cc12m_1_cfg supports it)

usage: cfg_sample.py [-h] [--images [IMAGE ...]] [--batch-size BATCH_SIZE]
                     [--checkpoint CHECKPOINT] [--device DEVICE] [--eta ETA] [--init INIT]
                     [--method {ddpm,ddim,prk,plms,pie,plms2,iplms}] [--model {cc12m_1_cfg}]
                     [-n N] [--seed SEED] [--size SIZE SIZE]
                     [--starting-timestep STARTING_TIMESTEP] [--steps STEPS]
                     [prompts ...]

prompts: the text prompts to use. A weight for a text prompt can be specified by putting it after a colon, for example: "the rise of consciousness:5". A weight of 1 samples images that match the prompt about as well as images in the training set typically match their captions. The default weight is 3.
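
For instance, two prompts can be given with different weights (the prompts here are illustrative):

./cfg_sample.py "a watercolor landscape":4 "sunset over mountains":2 -n 1 --seed 0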

--batch-size: sample this many images at a time (default 1)

--checkpoint: manually specify the model checkpoint file

--device: the PyTorch device name to use (default autodetects)

--eta: with --method ddim, set to 0 (the default) for deterministic (DDIM) sampling, 1 for stochastic (DDPM) sampling, or a value in between to interpolate between the two.

--images: the image prompts to use (local files or HTTP(S) URLs). Weights for image prompts can be specified by putting the weight after a colon, for example: "image_1.png:5". The default weight is 3.

--init: specify the init image (optional)

--method: specify the sampling method to use (DDPM, DDIM, PRK, PLMS, PIE, PLMS2, or IPLMS) (default PLMS). DDPM is the original SDE sampling method, DDIM integrates the probability flow ODE using a first-order method, PLMS is fourth-order pseudo Adams-Bashforth, and PLMS2 is second-order pseudo Adams-Bashforth. PRK (fourth-order Pseudo Runge-Kutta) and PIE (second-order Pseudo Improved Euler) are used to bootstrap PLMS and PLMS2 but can be used on their own if you desire (slow). IPLMS is the fourth-order "Improved PLMS" sampler from Fast Sampling of Diffusion Models with Exponential Integrator (https://arxiv.org/abs/2204.13902). A sketch of a single DDIM-style step under the v parameterization appears after this flag list.

--model: specify the model to use (default cc12m_1_cfg)

-n: sample until this many images have been generated (default 1)

--seed: specify the random seed (default 0)

--starting-timestep: specify the starting timestep if an init image is used (range 0-1, default 0.9)

--size: the output image size (default auto)

--steps: specify the number of diffusion timesteps (default 50; fewer steps are faster but lower quality; DDIM and especially DDPM need many more)
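
As noted under --method, here is a minimal sketch of one DDIM-style step under the v parameterization, including the --eta interpolation. This illustrates the idea rather than reproducing this repo's sampler, and the function name is hypothetical:

import math
import torch

def ddim_step(z, v, alpha, sigma, alpha_next, sigma_next, eta=0.0):
    """Step from noise level (alpha, sigma) down to (alpha_next, sigma_next)."""
    pred = alpha * z - sigma * v   # predicted clean image
    eps = sigma * z + alpha * v    # predicted noise
    # eta splits the next noise level between the predicted noise and fresh
    # Gaussian noise: eta=0 is deterministic DDIM, eta=1 is DDPM-like.
    ddim_sigma = eta * math.sqrt(sigma_next**2 / sigma**2) * math.sqrt(1 - alpha**2 / alpha_next**2)
    adjusted_sigma = math.sqrt(max(sigma_next**2 - ddim_sigma**2, 0.0))
    z_next = alpha_next * pred + adjusted_sigma * eps
    if eta > 0:
        z_next = z_next + ddim_sigma * torch.randn_like(z)
    return z_next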

CLIP guided sampling (all models)

usage: clip_sample.py [-h] [--images [IMAGE ...]] [--batch-size BATCH_SIZE]
                      [--checkpoint CHECKPOINT] [--clip-guidance-scale CLIP_GUIDANCE_SCALE]
                      [--cutn CUTN] [--cut-pow CUT_POW] [--device DEVICE] [--eta ETA]
                      [--init INIT] [--method {ddpm,ddim,prk,plms,pie,plms2,iplms}]
                      [--model {cc12m_1,cc12m_1_cfg,danbooru_128,imagenet_128,wikiart_128,wikiart_256,yfcc_1,yfcc_2}]
                      [-n N] [--seed SEED] [--size SIZE SIZE]
                      [--starting-timestep STARTING_TIMESTEP] [--steps STEPS]
                      [prompts ...]

prompts: the text prompts to use. Relative weights for text prompts can be specified by putting the weight after a colon, for example: "the rise of consciousness:0.5".

--batch-size: sample this many images at a time (default 1)

--checkpoint: manually specify the model checkpoint file

--clip-guidance-scale: how strongly the result should match the text prompt (default 500). If set to 0, the cc12m_1 model will still be CLIP conditioned and sampling will go faster and use less memory.

--cutn: the number of random crops to compute CLIP embeddings for (default 16)

--cut-pow: the random crop size power (default 1)

--device: the PyTorch device name to use (default autodetects)

--eta: with --method ddim, set to 0 (the default) for deterministic (DDIM) sampling, 1 for stochastic (DDPM) sampling, or a value in between to interpolate between the two.

--images: the image prompts to use (local files or HTTP(S) URLs). Relative weights for image prompts can be specified by putting the weight after a colon, for example: "image_1.png:0.5".

--init: specify the init image (optional)

--method: specify the sampling method to use (DDPM, DDIM, PRK, PLMS, PIE, PLMS2, or IPLMS) (default PLMS). DDPM is the original SDE sampling method, DDIM integrates the probability flow ODE using a first-order method, PLMS is fourth-order pseudo Adams-Bashforth, and PLMS2 is second-order pseudo Adams-Bashforth. PRK (fourth-order Pseudo Runge-Kutta) and PIE (second-order Pseudo Improved Euler) are used to bootstrap PLMS and PLMS2 but can be used on their own if you desire (slow). IPLMS is the fourth-order "Improved PLMS" sampler from Fast Sampling of Diffusion Models with Exponential Integrator (https://arxiv.org/abs/2204.13902).

--model: specify the model to use (default cc12m_1)

-n: sample until this many images have been generated (default 1)

--seed: specify the random seed (default 0)

--starting-timestep: specify the starting timestep if an init image is used (range 0-1, default 0.9)

--size: the output image size (default auto)

--steps: specify the number of diffusion timesteps (default 1000; can be lowered for faster but lower quality sampling)
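
An example invocation of CLIP guided sampling with one of the unconditional models, using only the flags documented above (it assumes the yfcc_2 checkpoint is in checkpoints/; the prompt is illustrative):

./clip_sample.py "a misty forest at dawn" --model yfcc_2 -n 4 --batch-size 4 --seed 0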

More Repositories

1. k-diffusion: Karras et al. (2022) diffusion models for PyTorch (Python, 1,263 stars)
2. style-transfer-pytorch: Neural style transfer in PyTorch (Python, 442 stars)
3. v-diffusion-jax: v objective diffusion inference code for JAX (Python, 207 stars)
4. simulacra-aesthetic-models (Python, 130 stars)
5. style_transfer: Data-parallel image stylization using Caffe (Python, 112 stars)
6. deep_dream: A parallel implementation of the Deep Dream image processing algorithm that can process arbitrarily large images (Jupyter Notebook, 100 stars)
7. consistency-models: A JAX implementation of the continuous time formulation of Consistency Models (Python, 74 stars)
8. LDLM: Latent Diffusion Language Models (Python, 66 stars)
9. cloob-training: CLOOB training (JAX) and inference (JAX and PyTorch) (Python, 63 stars)
10. esgd: ESGD-M, a stochastic non-convex second order optimizer for training deep learning models in PyTorch (Python, 58 stars)
11. mdmm: The Modified Differential Multiplier Method (MDMM) for PyTorch (Python, 41 stars)
12. vgg_loss: A VGG-based perceptual loss function for PyTorch (Python, 38 stars)
13. jax-wavelets: The 2D discrete wavelet transform for JAX (Python, 25 stars)
14. cond_transformer_2: A CLIP conditioned Decision Transformer (Python, 22 stars)
15. clip-guided-diffusion: CLIP Guided Diffusion (Python, 14 stars)
16. mdmm-jax: Gradient-based constrained optimization for JAX (Python, 14 stars)
17. tv-denoise: Total variation denoising for images (Python, 13 stars)
18. shared_ndarray: A pickleable wrapper for sharing NumPy ndarrays between processes using POSIX shared memory (Python, 13 stars)
19. pytorch-caffe-models: The original weights of some Caffe models, ported to PyTorch (Python, 10 stars)
20. pharmacokinetics: A Flask web application to calculate and plot drug concentration over time (Python, 10 stars)
21. pyparsing-highlighting: Syntax highlighting for prompt_toolkit and HTML with pyparsing (Python, 9 stars)
22. aiohttp_index: aiohttp.web middleware to serve index files (e.g. index.html) when static directories are requested (Python, 8 stars)
23. rope-flax: Rotary Position Embedding for Flax (Python, 4 stars)
24. ucs: Implements the CAM02-UCS (Luo et al. (2006)) forward transform (Python, 4 stars)
25. philips-hue: A CLI tool to interface with Philips Hue lights (Python, 4 stars)
26. dice-mc: DiCE: The Infinitely Differentiable Monte-Carlo Estimator (Python, 3 stars)
27. average: Exponentially weighted moving averages with initialization bias correction (Python, 3 stars)
28. synthraw: Synthesizes camera raw files (Python, 2 stars)
29. huething: A work in progress to control my four Philips Hue bulbs (Python, 2 stars)
30. base58: Implements base58 encoding as used in Bitcoin addresses (Go, 2 stars)
31. crowsonkb.github.io (HTML, 2 stars)
32. websynth: PNaCl-Csound based softsynth (JavaScript, 2 stars)
33. color_schemer: A web application to translate color schemes between dark and light backgrounds (Python, 1 star)
34. cluster: Performs hierarchical clustering of term vectors (Go, 1 star)
35. gradient-maker3: A web application to generate color gradients using the CAM02-UCS colorspace (Python, 1 star)
36. scihub-lookup: A Safari extension to look up the current page on Sci-Hub (JavaScript, 1 star)
37. fragments: Miscellaneous useful, reusable code fragments (Python, 1 star)
38. .zsh: My zsh configuration (Shell, 1 star)
39. randomness: Generates random secrets (passwords, etc.) (Python, 1 star)