StyleDrop
This is an unofficial PyTorch implementation of StyleDrop: Text-to-Image Generation in Any Style.
Unlike the Round 1 parameters in the paper, we set d_prj=32 and is_shared=False, which we found to work better. These hyperparameters can be seen in configs/custom.py.
We release them to facilitate community research.
News
- [07/06/2023] Online Gradio Demo is available here
Todo List
- Release the code.
- Add gradio inference demo (runs in local).
- Add iterative training (Round 2).
Data & Weights Preparation
First, download VQGAN from this link (from MAGE, thanks!), and put the downloaded VQGAN in assets/vqgan_jax_strongaug.ckpt.
Then, download the pre-trained checkpoints from this link to assets/ckpts for evaluation or to continue training for more iterations.
Finally, prepare empty_feature by running the command python extract_empty_feature.py
The final directory structure is as follows:
.
├── assets
│   ├── ckpts
│   │   ├── cc3m-285000.ckpt
│   │   │   ├── lr_scheduler.pth
│   │   │   ├── nnet_ema.pth
│   │   │   ├── nnet.pth
│   │   │   ├── optimizer.pth
│   │   │   └── step.pth
│   │   └── imagenet256-450000.ckpt
│   │       ├── lr_scheduler.pth
│   │       ├── nnet_ema.pth
│   │       ├── nnet.pth
│   │       ├── optimizer.pth
│   │       └── step.pth
│   ├── fid_stats
│   │   ├── fid_stats_cc3m_val.npz
│   │   └── fid_stats_imagenet256_guided_diffusion.npz
│   ├── pipeline.png
│   ├── contexts
│   │   └── empty_context.npy
│   └── vqgan_jax_strongaug.ckpt
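To confirm the preparation steps produced the layout above, a small check like the following may help. It is a minimal sketch, not part of the repository: the function name missing_assets is hypothetical, the file list is copied from the directory tree, and pipeline.png is left out on the assumption that it is only a documentation figure. Some of the listed files (for example the FID statistics) may only be needed for evaluation.

```python
from pathlib import Path

# Relative paths copied from the directory tree in this README.
EXPECTED_FILES = [
    "assets/vqgan_jax_strongaug.ckpt",
    "assets/contexts/empty_context.npy",
    "assets/fid_stats/fid_stats_cc3m_val.npz",
    "assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz",
]
CKPT_DIRS = [
    "assets/ckpts/cc3m-285000.ckpt",
    "assets/ckpts/imagenet256-450000.ckpt",
]
CKPT_FILES = [
    "lr_scheduler.pth",
    "nnet_ema.pth",
    "nnet.pth",
    "optimizer.pth",
    "step.pth",
]


def missing_assets(root="."):
    """Return the expected asset paths that are not files under `root`.

    Hypothetical helper for illustration; the repository does not ship it.
    """
    root = Path(root)
    expected = list(EXPECTED_FILES)
    # Each checkpoint is a directory holding the same five .pth files.
    expected += [f"{d}/{f}" for d in CKPT_DIRS for f in CKPT_FILES]
    return [p for p in expected if not (root / p).is_file()]


if __name__ == "__main__":
    missing = missing_assets()
    if missing:
        print("Missing files:")
        for path in missing:
            print("  " + path)
    else:
        print("All expected assets are in place.")
```

Run it from the repository root; an empty result means every file listed in the tree was found.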
Dependencies
Same as MUSE-PyTorch.
conda install pytorch torchvision torchaudio cudatoolkit=11.3
pip install accelerate==0.12.0 absl-py ml_collections einops wandb ftfy==6.1.1 transformers==4.23.1 loguru webdataset==0.2.5 gradio
Train
All style data in the paper are placed in the data directory.
- Modify data/one_style.json (note that one_style.json and the style data must be in the same directory). The format is file_name:[object,style]
{"image_03_05.jpg":["A bear","in kid crayon drawing style"]}
- The training script is as follows.
#!/bin/bash
unset EVAL_CKPT
unset ADAPTER
export OUTPUT_DIR="output_dir/for/this/experiment"
accelerate launch --num_processes 8 --mixed_precision fp16 train_t2i_custom_v2.py --config=configs/custom.py
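Before launching training, the one_style.json format described above can be sanity-checked. The sketch below assumes only what this README states: each key is an image file name, each value is an [object, style] pair of strings, and the images sit in the same directory as the JSON file. The function name load_style_entries is hypothetical and is not part of the repository.

```python
import json
from pathlib import Path


def load_style_entries(json_path):
    """Load a style JSON file and return a list of (image_path, object, style).

    Hypothetical helper for illustration; the training script may parse
    the file differently.
    """
    json_path = Path(json_path)
    with json_path.open(encoding="utf-8") as f:
        data = json.load(f)

    entries = []
    for file_name, value in data.items():
        # Each value should be a two-item list of strings: [object, style].
        is_pair = (
            isinstance(value, list)
            and len(value) == 2
            and all(isinstance(v, str) for v in value)
        )
        if not is_pair:
            raise ValueError(
                f"{file_name}: expected [object, style], got {value!r}"
            )
        # The JSON file and the style images must share a directory.
        image_path = json_path.parent / file_name
        if not image_path.is_file():
            raise FileNotFoundError(
                f"{image_path} not found next to {json_path.name}"
            )
        entries.append((image_path, value[0], value[1]))
    return entries
```

Calling load_style_entries("data/one_style.json") from the repository root raises an error for a malformed entry or a missing image, and otherwise returns one tuple per style image.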
Inference
The pretrained style_adapter weights can be downloaded from 🤗 Hugging Face.
#!/bin/bash
export EVAL_CKPT="assets/ckpts/cc3m-285000.ckpt"
export ADAPTER="path/to/your/style_adapter"
export OUTPUT_DIR="output/for/this/experiment"
accelerate launch --num_processes 8 --mixed_precision fp16 train_t2i_custom_v2.py --config=configs/custom.py
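The training and inference scripts run the same command and differ only in the environment variables EVAL_CKPT, ADAPTER and OUTPUT_DIR. The sketch below mirrors that difference in Python. It is an illustration under that assumption: build_env is a hypothetical helper, and how train_t2i_custom_v2.py actually reads these variables is not shown in this README.

```python
import os

# Same command line as in the two shell scripts in this README.
BASE_CMD = [
    "accelerate", "launch",
    "--num_processes", "8",
    "--mixed_precision", "fp16",
    "train_t2i_custom_v2.py",
    "--config=configs/custom.py",
]


def build_env(output_dir, eval_ckpt=None, adapter=None, base_env=None):
    """Return an environment dict for training or inference.

    Leaving eval_ckpt and adapter as None corresponds to the training
    script; setting them corresponds to the inference script.
    Hypothetical helper for illustration only.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["OUTPUT_DIR"] = output_dir
    for name, value in (("EVAL_CKPT", eval_ckpt), ("ADAPTER", adapter)):
        if value is None:
            env.pop(name, None)  # mirrors `unset NAME`
        else:
            env[name] = value    # mirrors `export NAME=...`
    return env
```

To launch, pass the result to the standard library, for example subprocess.run(BASE_CMD, env=build_env("output/for/this/experiment", eval_ckpt="assets/ckpts/cc3m-285000.ckpt", adapter="path/to/your/style_adapter"), check=True).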
Gradio Demo
Put the style_adapter weights in the ./style_adapter folder and run the following command to launch the demo:
python gradio_demo.py
The demo is also hosted on HuggingFace.
Citation
@article{sohn2023styledrop,
title={StyleDrop: Text-to-Image Generation in Any Style},
author={Sohn, Kihyuk and Ruiz, Nataniel and Lee, Kimin and Chin, Daniel Castro and Blok, Irina and Chang, Huiwen and Barber, Jarred and Jiang, Lu and Entis, Glenn and Li, Yuanzhen and others},
journal={arXiv preprint arXiv:2306.00983},
year={2023}
}
Acknowledgment
- The implementation is based on MUSE-PyTorch
- Many thanks for the generous help from Zanlin Ni