  • Language: Python
  • License: MIT License
  • Created almost 3 years ago; updated over 2 years ago

v-diffusion-jax

v objective diffusion inference code for JAX, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman).

The models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process; this lets them generate samples from the learned data distribution starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. The models are trained on continuous timesteps and use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI).
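As a rough illustration of the 'v' objective, the NumPy sketch below assumes a cosine schedule on continuous t in [0, 1], alpha = cos(πt/2) and sigma = sin(πt/2), so that alpha² + sigma² = 1. The schedule choice and all helper names are assumptions for this example, not this repository's API. With x_t = alpha·x0 + sigma·eps, the regression target is v = alpha·eps - sigma·x0, and both the clean image and the noise can be recovered from x_t and a v prediction.

```python
import math

import numpy as np


def get_alphas_sigmas(t):
    """Assumed cosine schedule on continuous t in [0, 1]; alpha**2 + sigma**2 == 1."""
    return np.cos(t * math.pi / 2), np.sin(t * math.pi / 2)


def add_noise(x0, eps, t):
    """Forward (noising) process: x_t = alpha * x0 + sigma * eps."""
    alpha, sigma = get_alphas_sigmas(t)
    return alpha * x0 + sigma * eps


def v_target(x0, eps, t):
    """The 'v' regression target: v = alpha * eps - sigma * x0."""
    alpha, sigma = get_alphas_sigmas(t)
    return alpha * eps - sigma * x0


def v_to_x0_eps(x_t, v, t):
    """Invert a v prediction into predictions of the clean image and the noise."""
    alpha, sigma = get_alphas_sigmas(t)
    pred_x0 = alpha * x_t - sigma * v
    pred_eps = sigma * x_t + alpha * v
    return pred_x0, pred_eps


rng = np.random.default_rng(0)
x0 = rng.standard_normal((3, 8, 8))
eps = rng.standard_normal((3, 8, 8))
t = 0.3
x_t = add_noise(x0, eps, t)
pred_x0, pred_eps = v_to_x0_eps(x_t, v_target(x0, eps, t), t)
print(np.allclose(pred_x0, x0), np.allclose(pred_eps, eps))  # True True
```

Because alpha² + sigma² = 1, a single v prediction determines both pred_x0 and pred_eps, which is what lets one network output drive either a DDIM or a DDPM update.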

Thank you to Google's TPU Research Cloud and stability.ai for compute to train these models!

Dependencies

  • JAX (installation instructions)

  • dm-haiku, einops, numpy, optax, Pillow, tqdm (install with pip install)

  • CLIP_JAX (https://github.com/kingoflolz/CLIP_JAX), and its additional pip-installable dependencies: ftfy, regex, torch, torchvision (it does not need GPU PyTorch). If you git clone --recursive this repo, it should fetch CLIP_JAX automatically.
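A quick way to check that the dependencies listed above are importable is sketched below using only the standard library. The pip-name to import-name mapping (for example dm-haiku → haiku, Pillow → PIL) is standard, but the clip_jax module name for CLIP_JAX is an assumption; adjust it to match how the submodule is laid out on your machine.

```python
import importlib.util

# pip package name -> top-level import name. The CLIP_JAX entry is an assumption.
REQUIRED_MODULES = {
    "jax": "jax",
    "dm-haiku": "haiku",
    "einops": "einops",
    "numpy": "numpy",
    "optax": "optax",
    "Pillow": "PIL",
    "tqdm": "tqdm",
    "ftfy": "ftfy",
    "regex": "regex",
    "torch": "torch",
    "torchvision": "torchvision",
    "CLIP_JAX": "clip_jax",
}


def missing_modules(modules):
    """Return the package names whose top-level import name cannot be found."""
    return [pkg for pkg, mod in modules.items() if importlib.util.find_spec(mod) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing:", ", ".join(missing) if missing else "none")
```

find_spec only locates modules without importing them, so the check is fast even for heavy packages like torch.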

Model checkpoints

  • Danbooru SFW 128x128, SHA-256 8551fe663dae988e619444efd99995775c7618af2f15ab5d8caf6b123513c334

  • ImageNet 128x128, SHA-256 4fc7c817b9aaa9018c6dbcbf5cd444a42f4a01856b34c49039f57fe48e090530

  • WikiArt 128x128, SHA-256 8fbe4e0206262996ff76d3f82a18dc67d3edd28631d4725e0154b51d00b9f91a

  • WikiArt 256x256, SHA-256 ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81
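Each checkpoint above is published with a SHA-256 digest. A minimal sketch for checking a downloaded file against it with Python's standard hashlib module follows; the dictionary keys mirror the --model choices, and the helper names are hypothetical.

```python
import hashlib

# Digests copied from the checkpoint list; keys match the --model choices.
EXPECTED_SHA256 = {
    "danbooru_128": "8551fe663dae988e619444efd99995775c7618af2f15ab5d8caf6b123513c334",
    "imagenet_128": "4fc7c817b9aaa9018c6dbcbf5cd444a42f4a01856b34c49039f57fe48e090530",
    "wikiart_128": "8fbe4e0206262996ff76d3f82a18dc67d3edd28631d4725e0154b51d00b9f91a",
    "wikiart_256": "ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81",
}


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large checkpoints are never read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path, model):
    """Return True if the file at `path` matches the published digest for `model`."""
    return sha256_of_file(path) == EXPECTED_SHA256[model]
```

For example, verify_checkpoint("checkpoints/wikiart_256.pkl", "wikiart_256"), where the .pkl filename is only a guess at how the checkpoint file is named.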

Sampling

Example

If the model checkpoints are stored in checkpoints/, the following will generate an image:

./clip_sample.py "a friendly robot, watercolor by James Gurney" --model wikiart_256 --seed 0

If they are somewhere else, you need to specify the path to the checkpoint with --checkpoint.

Unconditional sampling

usage: sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT] [--eta ETA] --model
                 {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED]
                 [--steps STEPS]

--batch-size: sample this many images at a time (default 1)

--checkpoint: manually specify the model checkpoint file

--eta: set to 0 for deterministic (DDIM) sampling, 1 (the default) for stochastic (DDPM) sampling, and in between to interpolate between the two. DDIM is preferred for low numbers of timesteps.

--init: specify the init image (optional)

--model: specify the model to use

-n: the total number of images to sample (default 1)

--seed: specify the random seed (default 0)

--starting-timestep: specify the starting timestep if an init image is used (range 0-1, default 0.9)

--steps: specify the number of diffusion timesteps (default 1000; fewer steps sample faster at lower quality)
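To show how --eta, --steps, --init and --starting-timestep can fit together, here is a NumPy sketch of a DDIM/DDPM-style sampling loop for a v-predicting model. It is not this repository's code: `model(x, t)` is a hypothetical callable returning the v prediction, and the cosine schedule alpha = cos(πt/2), sigma = sin(πt/2) is assumed. With eta = 0 no fresh noise is injected (deterministic DDIM); with eta = 1 the injected noise matches ancestral DDPM sampling.

```python
import math

import numpy as np


def get_alphas_sigmas(t):
    """Assumed cosine schedule on continuous t in [0, 1]."""
    return np.cos(t * math.pi / 2), np.sin(t * math.pi / 2)


def sample(model, x, steps, eta, rng):
    """Run the reverse process over `steps`, a decreasing array of timesteps ending at 0.

    `model(x, t)` is a hypothetical v-prediction callable.
    eta = 0 gives deterministic DDIM, eta = 1 stochastic DDPM, values in between interpolate.
    """
    alphas, sigmas = get_alphas_sigmas(np.asarray(steps, dtype=np.float64))
    pred = x
    for i in range(len(steps) - 1):
        v = model(x, steps[i])
        # Convert the v prediction into a predicted clean image and predicted noise.
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]
        # Standard deviation of the fresh noise for this step (zero when eta == 0).
        ddim_sigma = eta * np.sqrt(sigmas[i + 1] ** 2 / sigmas[i] ** 2) * np.sqrt(
            1 - alphas[i] ** 2 / alphas[i + 1] ** 2
        )
        # The rest of the next step's noise level is carried by the predicted eps.
        adjusted_sigma = np.sqrt(max(sigmas[i + 1] ** 2 - ddim_sigma ** 2, 0.0))
        x = pred * alphas[i + 1] + eps * adjusted_sigma
        if ddim_sigma > 0:
            x = x + rng.standard_normal(x.shape) * ddim_sigma
    return pred


def starting_point(shape, rng, init=None, starting_timestep=0.9, n_steps=1000):
    """Build the initial x and timestep grid; an init image is noised to starting_timestep."""
    if init is None:
        t_start = 1.0
        x = rng.standard_normal(shape)
    else:
        t_start = starting_timestep
        alpha, sigma = get_alphas_sigmas(t_start)
        x = init * alpha + rng.standard_normal(shape) * sigma
    return x, np.linspace(t_start, 0.0, n_steps + 1)
```

Because every step converts v into both pred and eps, the same loop serves DDIM and DDPM; eta only changes how the next noise level is split between the deterministic eps term and freshly drawn noise. Starting from a partially noised init image simply means running the same loop over a shorter timestep range.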

CLIP guided sampling

CLIP guided sampling lets you generate images with diffusion models, conditioned on the output matching a text prompt.

usage: clip_sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT]
                      [--clip-guidance-scale CLIP_GUIDANCE_SCALE] [--eta ETA] --model
                      {danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED]
                      [--steps STEPS]
                      prompt

clip_sample.py has the same options as sample.py and these additional ones:

prompt: the text prompt to use

--clip-guidance-scale: how strongly the result should match the text prompt (default 1000)
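The --clip-guidance-scale option weights a gradient that pushes samples toward the prompt. The sketch below is modeled on the general classifier-guidance idea rather than taken from this repository. It assumes a spherical-distance loss between L2-normalized CLIP image and text embeddings, and assumes the guidance gradient g (for example the negated gradient of that loss with respect to the current image, multiplied by the guidance scale; in JAX this would come from jax.grad) is applied by shifting the v prediction so that the implied noise estimate becomes eps - sigma·g. The cosine schedule and all function names are assumptions.

```python
import math

import numpy as np


def spherical_dist_loss(x, y):
    """Squared great-circle distance between L2-normalized embeddings (last axis)."""
    x = x / np.linalg.norm(x, axis=-1, keepdims=True)
    y = y / np.linalg.norm(y, axis=-1, keepdims=True)
    return 2 * np.arcsin(np.linalg.norm(x - y, axis=-1) / 2) ** 2


def guided_v(v, cond_grad, t):
    """Shift a v prediction by a guidance gradient under the assumed cosine schedule.

    With alpha = cos(pi*t/2) and sigma = sin(pi*t/2), the implied noise estimate
    sigma * x + alpha * v changes by -sigma * cond_grad, as in classifier guidance.
    """
    alpha, sigma = np.cos(t * math.pi / 2), np.sin(t * math.pi / 2)
    return v - cond_grad * (sigma / alpha)
```

A larger guidance scale makes g larger, which moves each step's predicted image further toward lower CLIP loss at the cost of sample diversity.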
