

PyTorch implementation of a 1.3B text-to-image generation model trained on 14 million image-text pairs

minDALL-E on Conceptual Captions

minDALL-E, named after minGPT [8], is a 1.3B-parameter text-to-image generation model trained on 14 million image-text pairs for non-commercial purposes.

Sample images (prompts): "a painting of a bird in the style of asian painting"; "a photo of san francisco's golden gate bridge in black and white tone"

Environment Setup

  • Basic setup
PyTorch == 1.8.0
CUDA >= 10.1
  • Other packages
pip install -r requirements.txt
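
The version constraints above can be checked from Python before installing the remaining packages. The sketch below is a hypothetical helper, not part of the repository: it compares version strings such as those reported by torch.__version__ and torch.version.cuda, and the function names are assumptions made for illustration.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn a string such as '1.8.0+cu111' into a tuple of ints, e.g. (1, 8, 0)."""
    # Drop any local build suffix such as '+cu111' before extracting the numbers.
    core = version.split("+")[0]
    return tuple(int(part) for part in re.findall(r"\d+", core))


def meets_requirements(torch_version: str, cuda_version: str) -> bool:
    """Check PyTorch == 1.8.0 and CUDA >= 10.1, as listed in the basic setup."""
    torch_ok = parse_version(torch_version)[:3] == (1, 8, 0)
    cuda_ok = parse_version(cuda_version)[:2] >= (10, 1)
    return torch_ok and cuda_ok


if __name__ == "__main__":
    # Hypothetical example values; in practice pass torch.__version__ and torch.version.cuda.
    print(meets_requirements("1.8.0+cu111", "11.1"))
    print(meets_requirements("1.7.1", "10.2"))
```

Taking plain strings rather than importing torch keeps the check usable in an environment where PyTorch is not installed yet.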

Model Checkpoint

  • Model structure (two-stage autoregressive model)
    • Stage1: Unlike the original DALL-E [1], we replace Discrete VAE with VQGAN [2] to generate high-quality samples effectively. We slightly fine-tune vqgan_imagenet_f16_16384, provided by the official VQGAN repository, on FFHQ [3] as well as ImageNet.
    • Stage2: We train our 1.3B transformer from scratch on 14 million image-text pairs from CC3M [4] and CC12M [5]. For the more detailed model spec, please see configs/dalle-1.3B.yaml.
  • You can download the pretrained models, including the tokenizer, from this link. The download requires about 5 GB of disk space.
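
To make the two-stage structure concrete, the sketch below works out how many discrete image tokens the second-stage transformer has to predict per image. It assumes that the checkpoint name vqgan_imagenet_f16_16384 follows the VQGAN naming convention (downsampling factor 16, codebook of 16384 entries) and that images are 256x256, the resolution given for the ImageNet examples; the function name is hypothetical.

```python
import math


def image_token_stats(resolution: int, downsample_factor: int, codebook_size: int):
    """Return (tokens_per_side, tokens_per_image, bits_per_token) for a VQ tokenizer."""
    tokens_per_side = resolution // downsample_factor
    tokens_per_image = tokens_per_side ** 2
    bits_per_token = math.log2(codebook_size)
    return tokens_per_side, tokens_per_image, bits_per_token


# Assumed values: 256x256 input, f16 downsampling, 16384-entry codebook.
side, total, bits = image_token_stats(256, 16, 16384)
print(side, total, bits)  # 16 256 14.0
```

Under these assumptions, each image becomes a 16x16 grid of 256 tokens, which the transformer generates one at a time after the text tokens.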

Sampling

  • Given a text prompt, the code snippet below generates candidate images and re-ranks them using OpenAI's CLIP [6].
  • This has been tested on a single V100 GPU with 32 GB of memory. On GPUs with less memory, lower num_candidates to avoid out-of-memory (OOM) errors.
import numpy as np  # needed for np.transpose below
from matplotlib import pyplot as plt
import clip
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score

device = 'cuda:0'
set_seed(0)

prompt = "A painting of a monkey with sunglasses in the frame"
model = Dalle.from_pretrained('minDALL-E/1.3B')  # This will automatically download the pretrained model.
model.to(device=device)

# Sampling
images = model.sampling(prompt=prompt,
                        top_k=256, # It is recommended that top_k is set lower than 256.
                        top_p=None,
                        softmax_temperature=1.0,
                        num_candidates=96,
                        device=device).cpu().numpy()
images = np.transpose(images, (0, 2, 3, 1))

# CLIP Re-ranking
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.to(device=device)
rank = clip_score(prompt=prompt,
                  images=images,
                  model_clip=model_clip,
                  preprocess_clip=preprocess_clip,
                  device=device)

# Plot images
images = images[rank]
plt.imshow(images[0])
plt.show()
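
The top_k and softmax_temperature arguments above control how each image token is drawn. The following is a generic, dependency-free sketch of top-k filtering with temperature scaling, not the repository's implementation; the function name and the toy logits are assumptions for illustration.

```python
import math
import random


def sample_top_k(logits, top_k, temperature=1.0, rng=random):
    """Sample one index from `logits` after temperature scaling and top-k filtering."""
    # Keep only the indices of the top_k largest logits.
    kept = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # A temperature below 1.0 sharpens the distribution; above 1.0 flattens it.
    scaled = [logits[i] / temperature for i in kept]
    # Numerically stable softmax weights over the surviving logits.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # Draw one index in proportion to its weight.
    return rng.choices(kept, weights=weights, k=1)[0]


if __name__ == "__main__":
    toy_logits = [2.0, 0.5, -1.0, 3.0, 0.0]  # hypothetical scores for a 5-token vocabulary
    rng = random.Random(0)
    draws = [sample_top_k(toy_logits, top_k=2, rng=rng) for _ in range(1000)]
    # With top_k=2 only the two highest-scoring tokens (indices 3 and 0) can appear.
    print(sorted(set(draws)))
```

A smaller top_k restricts sampling to fewer high-probability tokens, trading diversity for fidelity.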

Samples (Top-K=256, Temperature=1.0)

  • "a painting of a {cat, dog} with sunglasses in the frame"

  • "a large {pink, black} elephant walking on the beach"

  • "Eiffel tower on a {desert, mountain}"

Quantitative Results

  • We have validated minDALL-E on the CC3M validation set (in-distribution evaluation) and MS-COCO (zero-shot evaluation).
  • For CC3M, we measure the cosine similarity between image and text representations from the pretrained CLIP model (ViT-B/32), referred to as CLIP-score.
  • For MS-COCO, we compute FID between 30K generated and real samples from MS-COCO 2017, where we randomly choose 30K captions from COCO as in DALL-E. We select the best out of 32 candidates by CLIP re-ranking.
Model         | CC3M: CLIP-score (higher is better) | MS-COCO: FID-30K (lower is better)
VQGAN [2]     | 0.20                                | -
ImageBART [7] | 0.23                                | -
DALL-E [1]    | -                                   | 27.5
minDALL-E     | 0.26                                | 14.7
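
The CLIP-score and the best-of-32 selection described above both reduce to cosine similarity between embedding vectors. The sketch below uses plain Python lists in place of real CLIP embeddings; the helper names and the toy vectors are hypothetical, and the repository's own clip_score may differ in detail.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rerank(text_embedding, image_embeddings):
    """Return candidate indices sorted from most to least similar to the text."""
    scores = [cosine_similarity(text_embedding, img) for img in image_embeddings]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


# Toy 3-dimensional stand-ins for CLIP text and image embeddings.
text = [1.0, 0.0, 0.0]
candidates = [[0.0, 1.0, 0.0], [2.0, 0.1, 0.0], [1.0, 1.0, 0.0]]
order = rerank(text, candidates)
print(order)     # [1, 2, 0]
print(order[0])  # index of the best candidate: 1
```

Because cosine similarity ignores vector length, the second candidate ranks first even though its norm differs from that of the text vector.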

Transfer Learning Examples

  • minDALL-E, which is pre-trained on noisy text supervision, can be transferred to class-conditional and unconditional generation tasks. To validate this, we fine-tune it on ImageNet for 8 epochs, once for class-conditional generation and once for unconditional generation.
  • The commands below fine-tune the pretrained minDALL-E model. Fine-tuning takes about 36 hours on 8 V100 GPUs.
# Unconditional image generation for ImageNet (256x256)
python examples/transfer_learning_ex.py -d=configs/transfer-imagenet-uncond-gen.yaml \
                                        -u=[MODEL_CKPT] \
                                        -r=[RESULT_PATH] \
                                        --n-gpus=[NUM_GPUS]

# Class-conditional image generation for ImageNet (256x256)
python examples/transfer_learning_ex.py -d=configs/transfer-imagenet-clscond-gen.yaml \
                                        -u=[MODEL_CKPT] \
                                        -r=[RESULT_PATH] \
                                        --n-gpus=[NUM_GPUS]
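
When both runs are launched from a script, the commands above can be assembled programmatically. The sketch below only builds the argument list from the flags shown above (-d, -u, -r, --n-gpus); the helper name and the example paths are hypothetical placeholders.

```python
import shlex


def build_finetune_command(config, model_ckpt, result_path, n_gpus):
    """Build the argument list for examples/transfer_learning_ex.py."""
    return [
        "python", "examples/transfer_learning_ex.py",
        f"-d={config}",
        f"-u={model_ckpt}",
        f"-r={result_path}",
        f"--n-gpus={n_gpus}",
    ]


if __name__ == "__main__":
    configs = [
        "configs/transfer-imagenet-uncond-gen.yaml",
        "configs/transfer-imagenet-clscond-gen.yaml",
    ]
    for config in configs:
        # "path/to/ckpt" and "path/to/results" are placeholders, not real paths.
        cmd = build_finetune_command(config, "path/to/ckpt", "path/to/results", 8)
        # Print a shell-ready line; pass `cmd` to subprocess.run(cmd, check=True) to execute it.
        print(shlex.join(cmd))
```
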
  • We compute FID-50K between 50K generated samples and all ImageNet training samples, using top-k=256 and softmax temperature=1.0 for generation. All results are obtained without rejection sampling. Our model performs competitively with the baselines, even though minDALL-E is fine-tuned for only a few epochs.
Model     | Params | FID-50K (class-cond.) | FID-50K (uncond.)
VQ-GAN    | 1.4B   | 15.78                 | -
ImageBART | 3.5B   | 21.19                 | -
minDALL-E | 1.3B   | 15.55                 | 37.58
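
For reference, the FID values reported here follow the standard Fréchet Inception Distance definition, which compares Gaussian fits to Inception features of real and generated images. The symbols are introduced here for illustration and do not appear in the original text: $\mu_r, \Sigma_r$ are the feature mean and covariance of the real images, and $\mu_g, \Sigma_g$ those of the generated images.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Lower values mean that the two feature distributions are closer.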

BibTeX

If you find this repository useful in your research, please cite:

@misc{kakaobrain2021minDALL-E,
  title         = {minDALL-E on Conceptual Captions},
  author        = {Saehoon Kim and Sanghun Cho and Chiheon Kim and Doyup Lee and Woonhyuk Baek},
  year          = {2021},
  howpublished  = {\url{https://github.com/kakaobrain/minDALL-E}},
}

References

  • [1] Ramesh et al. Zero-Shot Text-to-Image Generation, ICML 2021.
  • [2] Esser et al. Taming Transformers for High-Resolution Image Synthesis, CVPR 2021.
  • [3] Karras et al. A Style-Based Generator Architecture for Generative Adversarial Networks, CVPR 2019.
  • [4] Sharma et al. Conceptual Captions: A Cleaned, Hypernymed, Image Alt-text Dataset For Automatic Image Captioning, ACL 2018.
  • [5] Changpinyo et al. Conceptual 12M: Pushing Web-Scale Image-Text Pre-Training To Recognize Long-Tail Visual Concepts, CVPR 2021.
  • [6] Radford et al. Learning Transferable Visual Models From Natural Language Supervision, ICML 2021.
  • [7] Esser et al. ImageBART: Bidirectional Context with Multinomial Diffusion for Autoregressive Image Synthesis, NeurIPS 2021.
  • [8] https://github.com/karpathy/minGPT

Licenses

  • The source code is licensed under the Apache 2.0 License.
  • The stage2 pretrained weights are licensed under the CC-BY-NC-SA 4.0 License.

Contact

We hope that minDALL-E helps various projects in research-oriented institutes and startups. If you would like to collaborate with us or share feedback, please e-mail us at [email protected]

Limitations

Although minDALL-E is trained on a relatively small dataset (14M image-text pairs), it may still be vulnerable to malicious prompt engineering that produces socially unacceptable images. If you observe such images, please report the prompt and the generated images to us.
