Language: Python | License: MIT License | Created over 2 years ago | Updated over 2 years ago

CLOOB Conditioned Latent Diffusion training and inference code

CLOOB Conditioned Latent Diffusion: Convenient High Quality Diffusion Models

Introduction

This repository contains the training and inference code for CLOOB conditioned latent diffusion (CCLD). CCLD is similar in approach to the CLIP conditioned diffusion trained by Katherine Crowson, with two key differences:

  • The use of latent diffusion cuts training costs by roughly a factor of ten, allowing a high quality 1.2 billion parameter model to converge in as few as 5 days on a single 8x A100 pod.

  • CLOOB conditioning can take advantage of CLOOB's unified latent space. CLOOB text and image embeddings of matching image/text pairs have a high similarity, somewhere around 0.9. This makes it possible to train the model without captions by using image embeddings in the training loop and text embeddings during inference.
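To make the train/inference swap concrete, here is a minimal sketch in plain Python. `embed_image` and `embed_text` are hypothetical stand-ins for a CLOOB image and text encoder (they are not this repo's API), and we assume the ~0.9 figure above refers to the cosine similarity between the two embeddings.

```python
import math
from typing import Callable, Optional, Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def conditioning_embed(
    training: bool,
    image: Optional[object],
    prompt: Optional[str],
    embed_image: Callable[[object], Vector],  # hypothetical CLOOB image encoder
    embed_text: Callable[[str], Vector],      # hypothetical CLOOB text encoder
) -> Vector:
    """Choose the conditioning vector for the diffusion model.

    Training uses the image embedding, so no captions are needed.
    Inference uses the text embedding, which lands close to the image
    embedding in CLOOB's shared space.
    """
    if training:
        return embed_image(image)
    return embed_text(prompt)
```

The design choice is simply that the diffusion model only ever sees a vector from the shared space, so it does not need to know which encoder produced it.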

This combination of traits makes the CCLD training approach extremely attractive to hobbyists, academics, and newcomers due to its high quality results, low finetune/training costs, and easy setup. It is the StyleGAN of diffusion models.

Pretrained Models

We currently have two models and plan to release more in the near future: a 1.2 billion parameter classifier-free guidance model trained on YFCC100M, and a classifier-free guidance model trained on Danbooru:

yfcc_cfg (ViT-B/16 CLOOB 16 epochs, 192 base channels, 4-4-8-8 resolution multipliers) - CLOOB checkpoint | Autoencoder | Autoencoder Config | Model Mirror

danbooru_cfg (ViT-B/16 CLOOB 32 epochs, 128 base channels, 4-4-8-8 resolution multipliers)

There is also a stage one LAION 5b autoencoder, which makes a good general base to train your latent diffusion model on top of if you can't train your own. Because LAION 5b contains a wide variety of images, its distribution should include textures like those in your dataset:

LAION 5b Autoencoder (autoencoder scale 8.0779) - Config

Note: The LAION 5b autoencoder was trained only on the laion2b-en and laion1b-nolang subsets, not on all of LAION 5b.

Danbooru Autoencoder (autoencoder scale 9.3154) - Config

Setup

First, clone this repo recursively to get it and its submodules:

git clone --recursive https://github.com/JD-P/cloob-latent-diffusion

If you don't already have PyTorch, you'll need to install it. For most datacenter GPUs the command looks like:

pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

Then pip install our other dependencies:

pip3 install omegaconf pillow pytorch-lightning einops wandb ftfy regex pycocotools ./CLIP
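Before moving on, a quick way to confirm the installs worked is to ask Python whether each package can be found. This is a minimal sketch using only the standard library; the import names are our assumption about how the pip packages above are imported (for example pillow as `PIL`, pytorch-lightning as `pytorch_lightning`, and the ./CLIP submodule as `clip`).

```python
import importlib.util

# Assumed import names for the packages installed above.
REQUIRED_MODULES = [
    "torch", "torchvision", "torchaudio", "omegaconf", "PIL",
    "pytorch_lightning", "einops", "wandb", "ftfy", "regex",
    "pycocotools", "clip",
]


def missing_modules(names):
    """Return the module names that cannot be found on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All dependencies found.")
```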

You are now ready to sample or prepare your training run.

Sampling

It is possible to sample from a model like so:

rm -f out*.png; ./cfg_sample.py "A photorealist detailed snarling goblin" --autoencoder kl_f8 --checkpoint yfcc-latent-diffusion-f8-e2-s250k.ckpt -n 128 --seed 4485 && v-diffusion-pytorch/make_grid.py out_*.png

Or, for the Danbooru latent diffusion model, which also needs its CLOOB checkpoint and architecture flags:

rm -f out*.png; ./cfg_sample.py "anime portrait of a man in a flight jacket leaning against a biplane" --autoencoder danbooru-kl-f8 --checkpoint danbooru-latent-diffusion-e88.ckpt --cloob-checkpoint cloob_laion_400m_vit_b_16_32_epochs --base-channels 128 --channel-multipliers 4,4,8,8 -n 16 --seed 4485 && v-diffusion-pytorch/make_grid.py out_*.png
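If you want to script several sampling runs (for example a sweep over seeds), you can assemble the same command lines from Python. This sketch only uses the flags that appear in the two examples above; the helper name `build_sample_cmd` is hypothetical and it does not cover any other options cfg_sample.py may accept.

```python
import shlex
from typing import List, Optional


def build_sample_cmd(
    prompt: str,
    checkpoint: str,
    autoencoder: str = "kl_f8",
    n: int = 16,
    seed: int = 4485,
    cloob_checkpoint: Optional[str] = None,
    base_channels: Optional[int] = None,
    channel_multipliers: Optional[str] = None,
) -> List[str]:
    """Build a cfg_sample.py argument list from the flags used in the README examples."""
    cmd = ["./cfg_sample.py", prompt,
           "--autoencoder", autoencoder,
           "--checkpoint", checkpoint]
    if cloob_checkpoint is not None:
        cmd += ["--cloob-checkpoint", cloob_checkpoint]
    if base_channels is not None:
        cmd += ["--base-channels", str(base_channels)]
    if channel_multipliers is not None:
        cmd += ["--channel-multipliers", channel_multipliers]
    cmd += ["-n", str(n), "--seed", str(seed)]
    return cmd


if __name__ == "__main__":
    # Print one Danbooru command per seed; pass each list to
    # subprocess.run(cmd, check=True) to actually execute it.
    for seed in (4485, 4486, 4487):
        cmd = build_sample_cmd(
            "anime portrait of a man in a flight jacket leaning against a biplane",
            checkpoint="danbooru-latent-diffusion-e88.ckpt",
            autoencoder="danbooru-kl-f8",
            seed=seed,
            cloob_checkpoint="cloob_laion_400m_vit_b_16_32_epochs",
            base_channels=128,
            channel_multipliers="4,4,8,8",
        )
        print(shlex.join(cmd))
```

The examples above start with rm -f out*.png, which suggests cfg_sample.py writes its images to out_*.png in the working directory; if so, move them aside between seeds so one run does not overwrite the last.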

Training

Preparing The Dataset

First, prepare your training set by creating a .txt file listing the paths of the images to train on. For example, this is how you make such a list for the MS COCO dataset:

find /datasets/coco/train2017/ -type f >> train_paths.txt
find /datasets/coco/val2017/ -type f >> train_paths.txt
shuf train_paths.txt > train_paths_2.txt
mv train_paths_2.txt train_paths.txt 

The find command is run over the top-level directory where the dataset's images are stored. The -type f flag restricts the results to files, so if the images are stored only alongside other images, this is equivalent to listing the path of every image in the dataset. If the data is not conveniently organized this way, you can filter further by piping the output of find into utilities like grep.

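Training Tip: It's important to shuffle your dataset so that the net generalizes during training. find lists files directory by directory, so without shuffling, images from the same folder would arrive grouped together. This is why the shuf utility is used on the training paths.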
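If you prefer to build the list in Python, the find, grep, and shuf steps can be sketched as follows. The set of image extensions is our assumption (adjust it for your data), and the fixed shuffle seed is an illustrative choice that makes the list reproducible.

```python
import os
import random
from typing import Iterable, List

# Extensions treated as images: an assumption, adjust for your dataset.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}


def collect_image_paths(roots: Iterable[str]) -> List[str]:
    """Recursively list image files under each root (find -type f plus an extension filter)."""
    paths = []
    for root in roots:
        for dirpath, _dirnames, filenames in os.walk(root):
            for name in filenames:
                if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS:
                    paths.append(os.path.join(dirpath, name))
    return paths


def write_shuffled(paths: List[str], out_file: str, seed: int = 0) -> None:
    """Shuffle the paths (like shuf) and write one path per line."""
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)
    with open(out_file, "w") as f:
        for p in shuffled:
            f.write(p + "\n")


if __name__ == "__main__":
    paths = collect_image_paths(["/datasets/coco/train2017/", "/datasets/coco/val2017/"])
    write_shuffled(paths, "train_paths.txt")
```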

Demo Prompts

You will also need demo prompts for the grids displayed in wandb during your training run. These grids are cheap to generate with PLMS sampling and greatly improve your ability to diagnose problems with your run. Here are some we wrote:

A portrait of Friedrich Nietzsche wearing an open double breasted suit with a bowtie
A portrait of a man in a flight jacket leaning against a biplane
a vision of paradise. unreal engine
the gateway between dreams, trending on ArtStation
A fantasy painting of a city in a deep valley by Ivan Aivazovsky
a rainy city street in the style of cyberpunk noir, trending on ArtStation
An oil painting of A Vase Of Flowers
oil painting of a candy dish of glass candies, mints, and other assorted sweets
The Human Utility Function
the Tower of Babel by J.M.W. Turner
sketch of a 3D printer by Leonardo da Vinci
The US Capitol Building in the style of Kandinsky
Metaphysics in the style of WPAP
a watercolor painting of a Christmas tree
control room monitors televisions screens computers hacker lab, concept art, matte painting, trending on artstation
illustration of airship zeppelins in the skies, trending on artstation

Training Tip: You may want to modify these prompts if you're training on a photorealistic dataset, as they are geared more toward models that produce illustrations and paintings.

Autoencoder

In order to train latent diffusion you need a latent space to train in. The training code is set up for an f=8 KL autoencoder. You can get a photorealistic autoencoder here, along with others, in the CompVis latent diffusion repo. You will also need its configuration file, which can be found in the latent-diffusion repo recursively cloned along with cloob-latent-diffusion. The config should have the same name as your autoencoder, with the file extension changed. For example:

cp latent-diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml ./2022_04_04_wikiart_kl_f8.yaml

Before training you must get the scale for your autoencoder like so:

python3 autoencoder_scale.py 2022_04_04_wikiart_kl_f8 train.txt

Write down the number you obtain from this and use it in your training run; the same number must be used at inference for the model to work. The model checkpoint retains a copy of the autoencoder scale, but it's best to keep your own record of it in your lab notes.
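To show why a single number has to match between training and inference, here is a sketch of the general idea in plain Python. It assumes the scale plays the role of the standard deviation of the encoder's latents and that latents are divided by it before diffusion (and multiplied back before decoding); we have not checked exactly how autoencoder_scale.py computes its value, so treat `estimate_latent_scale` and `normalize_latent` as hypothetical helpers.

```python
import math
from typing import Iterable, List, Sequence


def estimate_latent_scale(latent_batches: Iterable[Sequence[float]]) -> float:
    """Pooled standard deviation of every latent value seen (Welford's online algorithm).

    Assumption: each batch is a flattened sequence of latent floats produced by
    the autoencoder's encoder, e.g. 32x32x4 values per image.
    """
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the mean
    for batch in latent_batches:
        for value in batch:
            count += 1
            delta = value - mean
            mean += delta / count
            m2 += delta * (value - mean)
    if count == 0:
        raise ValueError("no latents provided")
    return math.sqrt(m2 / count)


def normalize_latent(latent: Sequence[float], scale: float) -> List[float]:
    """Divide latents by the scale so the diffusion model sees roughly unit variance."""
    return [v / scale for v in latent]
```

Under this assumption, a model trained on latents divided by one scale and sampled with a different one would see inputs at the wrong magnitude, which is why the number has to be recorded and reused.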

If you're not training on a photorealistic dataset, you will either need to find an appropriate pretrained KL autoencoder or train your own. The training repo for these models is unfortunately hard for a beginner to use and requires modification before you can easily train on an arbitrary dataset. We plan to release some pretrained models of our own, along with a friendlier fork of that repo, in the future.

Training Tip: From a compute perspective, if you only have an A6000 or 3090, your best bet is probably to finetune an existing KL f=8 autoencoder on the dataset you want to train on. This still requires working autoencoder training code, however.

Training Tip: You probably need to use a low-dimensional autoencoder for latent diffusion to work; our experiments with higher-dimensional autoencoders did not work well.

Training The Model

Once you have the setup, training set, autoencoder, demo prompts, and wandb project ready, starting the training run is as simple as:

python3 train_latent_diffusion.py --train-set train.txt --vqgan-model kl_f8 --autoencoder-scale 8.0779 --demo-prompts demo_prompts.txt --wandb-project jdp-latent-diffusion --batch-size 128 --num-gpus 8

For the YFCC CLOOB conditioned latent diffusion model, training took about five and a half days to reach the 250k checkpoint with a base channel count of 192 and channel multipliers of 4,4,8,8. You can analyze the logs from these runs at the following links:

0-150k step training run

150k-250k step training run

Training Tip: Your model is likely to overfit/memorize the training set if it's too big relative to your dataset. The rule of thumb is that the parameter count shouldn't exceed about 2/3 (0.66) of the number of datapoints in the set. You can calculate datapoints (floats) as the size of your latents times the size of your dataset. For the f=8 KL autoencoder used by this training repo that's 32x32x4xDataSetSize. So, for example, WikiArt, which has 80k training items, should be trained on a model no larger than 0.66 * 32 * 32 * 4 * 80000 parameters, or 216.2688 million. Pick your base channel count and channel multipliers to respect this rule. The base channel count must be a multiple of 64 for this architecture.
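The arithmetic in this tip is easy to script when sizing a new run. This sketch hard-codes the 32x32x4 latent shape and the 0.66 ratio from the text; the helper names are hypothetical, and estimating the actual parameter count of a given channel configuration is left to the training code.

```python
from typing import Sequence


def max_parameters(num_images: int,
                   latent_shape: Sequence[int] = (32, 32, 4),
                   ratio: float = 0.66) -> int:
    """Upper bound on model size from the rule of thumb: params <= ratio * datapoints.

    Datapoints are counted as latent floats per image times the number of images.
    """
    floats_per_image = 1
    for dim in latent_shape:
        floats_per_image *= dim
    return round(ratio * floats_per_image * num_images)


def check_base_channels(base_channels: int) -> None:
    """Raise if the base channel count is not a multiple of 64, as this architecture requires."""
    if base_channels % 64 != 0:
        raise ValueError(f"base channels must be a multiple of 64, got {base_channels}")


if __name__ == "__main__":
    # WikiArt example from the text: about 80k training images.
    print(f"{max_parameters(80_000) / 1e6:.4f} million")
    check_base_channels(192)  # the YFCC model's base channel count passes
```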

Training Tip: Past the initial warmup, the loss curve moves on a small scale, so if it seems stuck in the same loss regime, that doesn't necessarily mean it isn't improving. Use your demo grids to monitor progress.

Training Tip: It's possible to train in fp16 and then resume in fp32 once the run begins to explode or diverge. This is especially useful if you're VRAM constrained and would like to use a higher batch size in the early training. It also makes early training go faster if you're compute constrained or impatient.

Training Tip: Once the loss converges, it is often possible to get it lower by restarting the run with a lower learning rate. Resuming restores the optimizer state saved in the checkpoint, so you need to change the learning rate stored there; otherwise the old value comes back when you resume. You can do that from a Python prompt like so:

Python 3.8.10 (default, Nov 26 2021, 20:14:08) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> ckpt = torch.load("jdp-latent-diffusion/1dv7xxrg/checkpoints/epoch=1-step=149999.ckpt")
>>> ckpt['optimizer_states'][0]['param_groups'][0]['lr']
3e-05
>>> ckpt['optimizer_states'][0]['param_groups'][0]['lr'] = 3e-06
>>> torch.save(ckpt, "yfcc_resume.ckpt")
>>>
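The same edit can be saved as a small script. This sketch (the file name set_lr.py is hypothetical) changes every param group of every saved optimizer rather than only the first one, and reports the old values for your lab notes. If the run uses a learning-rate scheduler, its saved state may also need updating; we have not checked whether this training script uses one.

```python
import argparse


def set_checkpoint_lr(ckpt: dict, new_lr: float) -> list:
    """Overwrite the learning rate of every param group of every saved optimizer.

    Returns the old learning rates in the order they were found.
    """
    old_lrs = []
    for optimizer_state in ckpt["optimizer_states"]:
        for group in optimizer_state["param_groups"]:
            old_lrs.append(group["lr"])
            group["lr"] = new_lr
    return old_lrs


def main() -> None:
    import torch  # imported here so set_checkpoint_lr works without torch installed

    parser = argparse.ArgumentParser(description="Change the learning rate stored in a checkpoint.")
    parser.add_argument("checkpoint")
    parser.add_argument("output")
    parser.add_argument("--lr", type=float, required=True)
    args = parser.parse_args()

    # On recent PyTorch versions you may need weights_only=False to load a full checkpoint.
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    old = set_checkpoint_lr(ckpt, args.lr)
    print(f"old learning rates: {old}, new: {args.lr}")
    torch.save(ckpt, args.output)


if __name__ == "__main__":
    main()
```

Usage, mirroring the session above: python3 set_lr.py jdp-latent-diffusion/1dv7xxrg/checkpoints/epoch=1-step=149999.ckpt yfcc_resume.ckpt --lr 3e-06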

Finetuning

It's possible to save time (and money) by retraining an existing model on a new dataset rather than starting from scratch. This is called finetuning. To finetune an existing model, use the --resume-from flag:

python3 train_latent_diffusion.py --train-set train_paths.txt --vqgan-model kl_f8 --autoencoder-scale 8.0779 --demo-prompts coco_demo_prompts.txt --resume-from to_finetune.ckpt --wandb-project jdp-latent-diffusion

Training Tip: As a rule of thumb, finetunes tend to take 10-20% of the compute time of the original training run.
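To turn this rule of thumb into a budget, multiply it against the original run. This sketch assumes the finetune runs on the same hardware as the original; the example plugs in the roughly 5.5-day YFCC run and the 8x A100 pod mentioned earlier, and the helper name is hypothetical.

```python
from typing import Tuple


def finetune_estimate(original_days: float, num_gpus: int,
                      low: float = 0.10, high: float = 0.20) -> Tuple[float, float, float, float]:
    """Rough finetune budget from the 10-20% rule of thumb.

    Returns (low_days, high_days, low_gpu_hours, high_gpu_hours), assuming the
    finetune uses the same hardware as the original run.
    """
    low_days = original_days * low
    high_days = original_days * high
    return low_days, high_days, low_days * 24 * num_gpus, high_days * 24 * num_gpus


if __name__ == "__main__":
    lo_d, hi_d, lo_h, hi_h = finetune_estimate(5.5, num_gpus=8)
    print(f"{lo_d:.2f}-{hi_d:.2f} days, {lo_h:.0f}-{hi_h:.0f} GPU-hours")
```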
