• Stars: 434
• Rank: 100,274 (top 2%)
• Language: Python
• License: MIT License
• Created: almost 4 years ago
• Updated: almost 3 years ago


Repository Details

OpenAI's DALL-E for large-scale training in mesh-tensorflow.

DALL-E in Mesh-Tensorflow [WIP]

OpenAI's DALL-E in Mesh-Tensorflow.

If this proves similarly efficient to GPT-Neo, this repo should be able to train models up to, and larger than, the size of OpenAI's DALL-E (12B params).

No pretrained models... Yet.

Thanks to Ben Wang for the TF VAE implementation as well as for getting the MTF version working, and to Aran Komatsuzaki for help building the MTF VAE and input pipeline.

Setup

git clone https://github.com/EleutherAI/GPTNeo
cd GPTNeo
pip3 install -r requirements.txt

Training Setup

Runs on TPUs; untested on GPUs, but it should work in theory. The example configs are designed to run on a TPU v3-32 pod.

To set up TPUs, sign up for Google Cloud Platform, and create a storage bucket.

Create your VM through a Google Cloud Shell (https://ssh.cloud.google.com/) with ctpu up --vm-only so that it can connect to your Google bucket and TPUs, and set up the repo as above.

VAE pretraining

DALL-E needs a pretrained VAE to compress images to tokens. To run the VAE pretraining, set the dataset paths in configs/vae_example.json to glob paths pointing to a dataset of jpgs, and set image_size to the appropriate value.

  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg",
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg",
    "image_size": 32
  }
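
Before launching a run, it can help to confirm the glob actually matches files in your bucket. A minimal sketch, assuming TensorFlow is installed locally and you have access to the bucket (the path below is just the example value from the config above):

import tensorflow as tf

# Sanity-check that the glob in your config matches some files.
# tf.io.gfile.glob works for both local paths and gs:// buckets;
# support for the recursive "**" pattern can vary by backend.
train_glob = "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg"
matches = tf.io.gfile.glob(train_glob)
print(len(matches), "files matched; first few:", matches[:3])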

Once this is all set up, create your TPU, then run:

python train_vae_tf.py --tpu your_tpu_name --model vae_example

Training logs image tensors and loss values; to check progress, you can run:

tensorboard --logdir your_model_dir

Dataset Creation [DALL-E]

Once the VAE is pretrained, you can move on to DALL-E.

Currently we are training on a dummy dataset. A public, large-scale dataset for DALL-E is in the works. In the meantime, to generate some dummy data, run:

python src/data/create_tfrecords.py

This should download CIFAR-10, and generate some random captions to act as text inputs.

Custom datasets should be formatted in a folder, with a jsonl file in the root folder containing caption data and paths to the respective images, as follows:

Folder structure:

        data_folder
            jsonl_file
            folder_1
                img1
                img2
                ...
            folder_2
                img1
                img2
                ...
            ...

jsonl structure:
    {"image_path": "folder_1/img1", "caption": "some words"}
    {"image_path": "folder_2/img2", "caption": "more words"}
    ...
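
If you need to produce such a jsonl file yourself, here is a minimal sketch. The captions mapping and the output filename data.jsonl are illustrative, not something the repo prescribes:

import json
from pathlib import Path

data_folder = Path("data_folder")  # assumed to already contain the image subfolders

# Hypothetical mapping from image path (relative to data_folder) to caption.
captions = {
    "folder_1/img1.jpg": "some words",
    "folder_2/img2.jpg": "more words",
}

with open(data_folder / "data.jsonl", "w") as f:
    for image_path, caption in captions.items():
        f.write(json.dumps({"image_path": image_path, "caption": caption}) + "\n")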

You can then use the create_paired_dataset function in src/data/create_tfrecords.py to encode the dataset into tfrecords for use in training.

Once the dataset is created, copy it over to your bucket with gsutil:

gsutil cp -r DALLE-tfrecords gs://neo-datasets/

And finally, run training with:

python train_dalle.py --tpu your_tpu_name --model dalle_example

Config Guide

VAE:

{
  "model_type": "vae",
  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg", # glob path to training images
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg", # glob path to eval images
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, 
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000, # how often to save a checkpoint
  "iterations": 500, # number of batches to infeed to the tpu at a time. Must be < steps_per_checkpoint
  "train_steps": 100000, # total training steps
  "eval_steps": 0, # run evaluation for this many steps every steps_per_checkpoint
  "model_path": "gs://neo-models/vae_test2/", # directory in which to save the model
  "mesh_shape": "data:16,model:2", # mapping of processors to named dimensions - see mesh-tensorflow repo for more info
  "layout": "batch_dim:data", # which named dimensions of the model to split across the mesh - see mesh-tensorflow repo for more info
  "num_tokens": 512, # vocab size
  "dim": 512, 
  "hidden_dim": 64, # size of hidden dim
  "n_channels": 3, # number of input channels
  "bf_16": false, # if true, the model is trained with bfloat16 precision
  "lr": 0.001, # learning rate [by default learning rate starts at this value, then decays to 10% of this value over the course of the training]
  "num_layers": 3, # number of blocks in the encoder / decoder
  "train_gumbel_hard": true, # whether to use hard or soft gumbel_softmax
  "eval_gumbel_hard": true
}
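
The bracketed note on lr above describes a decay to 10% of the initial value over training. As a rough illustration only (this assumes the decay is linear, which the config comment does not actually specify):

def decayed_lr(step, lr=1e-3, train_steps=100_000, final_fraction=0.1):
    # Linearly interpolate from lr at step 0 down to final_fraction * lr at train_steps.
    frac = min(step / train_steps, 1.0)
    return lr * (1.0 - frac * (1.0 - final_fraction))

# Approximately 0.001, 0.00055, 0.0001 for the example config above.
print(decayed_lr(0), decayed_lr(50_000), decayed_lr(100_000))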

DALL-E:

{
  "model_type": "dalle",
  "dataset": {
    "train_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords", # glob path to tfrecords data
    "eval_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords",
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, # see above
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000,
  "iterations": 500,
  "train_steps": 100000,
  "predict_steps": 0,
  "eval_steps": 0,
  "n_channels": 3,
  "bf_16": false,
  "lr": 0.001,
  "model_path": "gs://neo-models/dalle_test/",
  "mesh_shape": "data:16,model:2",
  "layout": "batch_dim:data",
  "n_embd": 512, # size of embedding dim
  "text_vocab_size": 50258, # vocabulary size of the text tokenizer
  "image_vocab_size": 512, # vocabulary size of the vae - should equal num_tokens above
  "text_seq_len": 256, # length of text inputs (all inputs longer / shorter will be truncated / padded)
  "n_layers": 6, 
  "n_heads": 4, # number of attention heads. For best performance, n_embd / n_heads should equal 128
  "vae_model": "vae_example" # path to or name of vae model config
}
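
Several of these fields have to agree with each other: image_vocab_size should match the VAE's num_tokens, n_embd should divide evenly by n_heads (ideally 128 per head), and the product of the mesh_shape dimensions should equal the number of TPU cores. A small sanity-check sketch along those lines, assuming the on-disk configs are plain JSON without the explanatory # comments shown in this guide (file names are illustrative):

import json

# Illustrative paths; point these at your actual config files.
with open("configs/vae_example.json") as f:
    vae = json.load(f)
with open("configs/dalle_example.json") as f:
    dalle = json.load(f)

assert dalle["image_vocab_size"] == vae["num_tokens"], "image_vocab_size should equal the VAE's num_tokens"
assert dalle["n_embd"] % dalle["n_heads"] == 0, "n_embd must divide evenly by n_heads"

# e.g. "data:16,model:2" -> 16 * 2 = 32 processors, i.e. a TPU v3-32.
mesh_size = 1
for dim in dalle["mesh_shape"].split(","):
    mesh_size *= int(dim.split(":")[1])
print("mesh uses", mesh_size, "cores;", dalle["n_embd"] // dalle["n_heads"], "dims per head")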

More Repositories

1. gpt-neo: An implementation of model parallel GPT-2 and GPT-3-style models using the mesh-tensorflow library. (Python, 8,224 stars)
2. gpt-neox: An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries (Python, 6,829 stars)
3. lm-evaluation-harness: A framework for few-shot evaluation of language models. (Python, 6,268 stars)
4. pythia: The hub for EleutherAI's work on interpretability and learning dynamics (Jupyter Notebook, 2,193 stars)
5. the-pile (Python, 1,459 stars)
6. math-lm (Python, 1,035 stars)
7. cookbook: Deep learning for dummies. All the practical details and useful utilities that go into working with real models. (Python, 635 stars)
8. polyglot: Polyglot: Large Language Models of Well-balanced Competence in Multi-languages (471 stars)
9. vqgan-clip (Jupyter Notebook, 345 stars)
10. sae: Sparse autoencoders (Python, 274 stars)
11. concept-erasure: Erasing concepts from neural representations with provable guarantees (Python, 207 stars)
12. elk: Keeping language models honest by directly eliciting knowledge encoded in their activations. (Python, 186 stars)
13. oslo: OSLO: Open Source for Large-scale Optimization (Python, 173 stars)
14. lm_perplexity (Python, 144 stars)
15. knowledge-neurons: A library for finding knowledge neurons in pretrained transformer models. (Python, 142 stars)
16. pyfra: Python Research Framework (Python, 107 stars)
17. dps: Data processing system for polyglot (Python, 88 stars)
18. openwebtext2 (Python, 86 stars)
19. info: (Deprecated) A hub for onboarding & other information. (78 stars)
20. improved-t5: Experiments for efforts to train a new and improved t5 (Python, 76 stars)
21. stackexchange-dataset: Python tools for processing the stackexchange data dumps into a text dataset for Language Models (Python, 73 stars)
22. project-menu: See the issue board for the current status of active and prospective projects! (65 stars)
23. magiCARP: One stop shop for all things carp (Python, 58 stars)
24. sae-auto-interp (Python, 53 stars)
25. semantic-memorization (Jupyter Notebook, 44 stars)
26. tqdm-multiprocess: Using queues, tqdm-multiprocess supports multiple worker processes, each with multiple tqdm progress bars, displaying them cleanly through the main process. It offers similar functionality for python logging. (Python, 41 stars)
27. aria (Python, 37 stars)
28. hae-rae (32 stars)
29. rnngineering: Engineering the state of RNN language models (Mamba, RWKV, etc.) (Jupyter Notebook, 31 stars)
30. features-across-time: Understanding how features learned by neural networks evolve throughout training (Python, 30 stars)
31. mp_nerf: Massively-Parallel Natural Extension of Reference Frame (Jupyter Notebook, 29 stars)
32. elk-generalization: Investigating the generalization behavior of LM probes trained to predict truth labels: (1) from one annotator to another, and (2) from easy questions to hard (Python, 23 stars)
33. pile-pubmedcentral: A script for collecting the PubMed Central dataset in a language modelling friendly format. (Python, 22 stars)
34. best-download: URL downloader supporting checkpointing and continuous checksumming. (Python, 19 stars)
35. polyglot-data: data related codebase for polyglot project (Python, 19 stars)
36. aria-amt: Efficient and robust implementation of seq-to-seq automatic piano transcription. (Python, 18 stars)
37. text-generation-testing-ui: Web app for demoing the EAI models (JavaScript, 16 stars)
38. exploring-contrastive-topology (Jupyter Notebook, 16 stars)
39. mdl: Minimum Description Length probing for neural network representations (Python, 15 stars)
40. pile_dedupe: Pile Deduplication Code (Python, 15 stars)
41. w2s (Python, 15 stars)
42. pilev2 (Python, 13 stars)
43. distilling: Experiments with distilling large language models. (Python, 13 stars)
44. tokengrams: Efficiently computing & storing token n-grams from large corpora (Rust, 13 stars)
45. lm-eval2 (Python, 11 stars)
46. equivariance: A framework for implementing equivariant DL (Jupyter Notebook, 10 stars)
47. radioactive-lab: Adapting the "Radioactive Data" paper to work for text models (Python, 9 stars)
48. pile-literotica: Download, parse, and filter data from Literotica. Data-ready for The-Pile. (Python, 8 stars)
49. hn-scraper (Python, 8 stars)
50. tagged-pile: Part-of-Speech Tagging for the Pile and RedPajama (Python, 8 stars)
51. multimodal-fid (Python, 7 stars)
52. pile-uspto: A script for collecting the USPTO Backgrounds dataset in a language modelling friendly format. (Python, 7 stars)
53. pile-cc-filtering: The code used to filter CC data for The Pile (Python, 6 stars)
54. minetest-baselines: Baseline agents for Minetest tasks. (Python, 6 stars)
55. CodeCARP: Data collection pipeline for CodeCARP. Includes PyCharm plugins. (6 stars)
56. pile-enron-emails: A script for collecting the Enron Emails dataset in a language modelling friendly format. (Python, 6 stars)
57. pile-explorer: For exploring the data and documenting its limitations (Python, 5 stars)
58. minetest-interpretabilty-notebook: Jupyter notebook for the interpretablity section of the minetester blog post (Jupyter Notebook, 5 stars)
59. thonkenizers: yes (5 stars)
60. eleutherai.github.io: This is the Hugo generated website for eleuther.ai. The source of this build is new-website repo. (HTML, 5 stars)
61. visual-grounding: Visually ground GPT-Neo 1.3b and 2.7b (Python, 5 stars)
62. LLM-Markov-Chains: Project github for LLM Markov Chains Project (5 stars)
63. architecture-experiments: Repository to host architecture experiments and development using Paxml and Praxis (Python, 5 stars)
64. llemma-sample-explorer: Sample explorer tool for the Llemma models. (HTML, 5 stars)
65. lm-scope (Jupyter Notebook, 4 stars)
66. latent-video-diffusion: Latent video diffusion (Python, 4 stars)
67. megatron-3d (Python, 4 stars)
68. website: New website for EleutherAI based on Hugo static site generator (HTML, 4 stars)
69. Unpaired-Image-Generation: Project Repo for Unpaired Image Generation project (4 stars)
70. ccs (Python, 4 stars)
71. isaac-mchorse: EleutherAI's discord bot (Python, 3 stars)
72. pile-allpoetry: Scraper to gather poems from allpoetry.com (Python, 3 stars)
73. EvilModel: A replication of "EvilModel 2.0: Bringing Neural Network Models into Malware Attacks" (3 stars)
74. eai-prompt-gallery: Library of interesting prompt generations (JavaScript, 3 stars)
75. variance-across-time: Studying the variance in neural net predictions across training time (Python, 3 stars)
76. pile-ubuntu-irc: A script for collecting the Ubuntu IRC dataset in a language modelling friendly format. (Python, 3 stars)
77. reddit-comment-processing (Python, 2 stars)
78. eleutherai-instruct-dataset: A large instruct dataset for open-source models (WIP). (2 stars)
79. bucket-cleaner: A small utility to clear out old model checkpoints in Google Cloud Buckets whilst keeping tensorboard event files (Python, 2 stars)
80. groupoid-rl (Jupyter Notebook, 2 stars)
81. equinox-llama: Equinox implementation of llama3 and llama3.1 (Python, 2 stars)
82. optax-galore: Adds GaLore style projection wrappers to optax optimizers (Python, 2 stars)
83. lang-filter: Filter text files or archives by language (Python, 1 star)
84. eleuther-blog: here is the generated content for the EleutherAI blog. Source is from new-website repo (HTML, 1 star)
85. prefix-free-tokenizer: A prefix free tokenizer (Python, 1 star)
86. alignment-reader: Search and filter through alignment literature (JavaScript, 1 star)
87. grouch (HTML, 1 star)
88. language-adaptation (1 star)
89. perceptors: central location for access to pretrained models for CLIP and variants, with common API and out-of-the-box differentiable weighted multi-perceptor (1 star)
90. pd-books (Jupyter Notebook, 1 star)
91. classifier-latent-diffusion (Python, 1 star)
92. common-llm-settings: Common LLM Settings App (JavaScript, 1 star)
93. bayesian-adam: Exactly what it says on the tin (Python, 1 star)
94. pile-cord19: A script for collecting the CORD-19 dataset in a language modelling friendly format. (Python, 1 star)
95. conceptual-constraints: Applying LEACE to models during training (Jupyter Notebook, 1 star)
96. ngrams-across-time (Jupyter Notebook, 1 star)
97. steering-llama3 (Python, 1 star)
98. truncated-gaussian: Method-of-moments estimation and sampling for truncated multivariate Gaussian distributions (Python, 1 star)