  • Stars: 7,832
  • Rank 4,838 (Top 0.1 %)
  • Language: Python
  • License: MIT License
  • Created over 2 years ago
  • Updated 8 months ago

Repository Details

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch

Imagen - Pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.

Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.

It appears neither CLIP nor prior network is needed after all. And so research continues.
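The dynamic clipping mentioned above can be sketched in a few lines. This is a minimal illustration and not the repository's own code: the function name `dynamic_threshold` and the `percentile` value of 0.95 are assumptions chosen for the example. Each predicted image is clamped to a range set by a percentile of its own absolute pixel values and then rescaled back into [-1, 1].

```python
import torch

def dynamic_threshold(x0: torch.Tensor, percentile: float = 0.95) -> torch.Tensor:
    """Clamp each sample to its own percentile-based range, then rescale to [-1, 1].

    The function name and default percentile are illustrative assumptions.
    """
    batch = x0.shape[0]
    # per-sample threshold s: the chosen percentile of the absolute pixel values
    s = torch.quantile(x0.reshape(batch, -1).abs(), percentile, dim=1)
    # never use a threshold below 1, so predictions already in [-1, 1] pass through
    s = s.clamp(min=1.0).view(batch, *([1] * (x0.ndim - 1)))
    # clamp to [-s, s], then divide by s
    return torch.maximum(torch.minimum(x0, s), -s) / s

x0 = torch.randn(2, 3, 8, 8) * 3.0  # mock prediction that overshoots [-1, 1]
out = dynamic_threshold(x0)
```

With a large guidance scale the predicted image tends to leave the valid pixel range, which is the situation this kind of thresholding is meant to handle.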

AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher

Please join us on Discord if you are interested in helping out with the replication with the LAION community

Shoutouts

  • StabilityAI for the generous sponsorship, as well as my other sponsors out there

  • 🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them

  • Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper

  • Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training

  • Alex for einops, an indispensable tool for tensor manipulation

  • Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version

  • Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion

  • Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging

  • Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets

  • Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results

  • MalumaDev for proposing the use of pixel shuffle upsampler to fix checkerboard artifacts

  • Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix

  • BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time

  • Bingbing for identifying a bug with sampling and order of normalizing and noising with low resolution conditioning image

  • Kay for contributing one line command training of Imagen!

  • Hadrien Reynaud for testing out text-to-video on a medical dataset, sharing his results, and identifying issues!

Install

$ pip install imagen-pytorch

Usage

import torch
from imagen_pytorch import Unet, Imagen

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
    loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)

images.shape # (3, 3, 256, 256)

For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask)

The number of textual captions must match the batch size of the images if you go this route.

# mock images and text (get a lot of this)

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, texts = texts, unet_number = i)
    loss.backward()

With the ImagenTrainer wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling update

import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# unet for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()

# wrap imagen with the trainer class

trainer = ImagenTrainer(imagen)

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = trainer(
    images,
    text_embeds = text_embeds,
    unet_number = 1,            # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
    max_batch_size = 4          # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
)

trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)

images.shape # (2, 3, 256, 256)
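The `max_batch_size` argument in the example above splits a large batch into smaller chunks and accumulates gradients across them. The general idea can be sketched with a toy model as below. This is an assumption-laden illustration of gradient accumulation, not the trainer's internals, and `accumulate_gradients` is a hypothetical helper name.

```python
import torch
from torch import nn

def accumulate_gradients(model, loss_fn, inputs, targets, max_batch_size):
    """Run forward / backward chunk by chunk, accumulating gradients on the model."""
    total = inputs.shape[0]
    total_loss = 0.0
    chunks = zip(inputs.split(max_batch_size), targets.split(max_batch_size))
    for x, y in chunks:
        # weight each chunk by its share of the full batch so the summed
        # gradients match one pass over the whole batch (mean-reduced loss)
        weight = x.shape[0] / total
        loss = loss_fn(model(x), y) * weight
        loss.backward()  # gradients add up in .grad across chunks
        total_loss += loss.item()
    return total_loss

torch.manual_seed(0)
model = nn.Linear(8, 1)
inputs, targets = torch.randn(64, 8), torch.randn(64, 1)

model.zero_grad()
loss = accumulate_gradients(model, nn.functional.mse_loss, inputs, targets, max_batch_size=4)
chunked_grad = model.weight.grad.clone()
```

Only one chunk of activations is held in memory at a time, which is why a batch of 64 can fit on a device that only has room for 4 samples.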

You can also train Imagen without text (unconditional image generation) as follows

import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer

# unets for unconditional imagen

unet1 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = False,
    use_linear_attn = True
)

unet2 = SRUnet256(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = (2, 4, 8),
    layer_attns = (False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    condition_on_text = False,   # this must be set to False for unconditional Imagen
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000
)

trainer = ImagenTrainer(imagen).cuda()

# now get a ton of images and feed it through the Imagen trainer

training_images = torch.randn(4, 3, 256, 256).cuda()

# train each unet separately
# in this example, only training on unet number 1

loss = trainer(training_images, unet_number = 1)
trainer.update(unet_number = 1)

# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)

images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)

Or train only super-resoluting unets

import torch
from imagen_pytorch import Unet, NullUnet, Imagen

# unet for imagen

unet1 = NullUnet()  # add a placeholder "null" unet for the base unet

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 250,
    cond_drop_prob = 0.1
).cuda()

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings as well as low resolution images

lowres_images = torch.randn(3, 3, 64, 64).cuda()  # starting un-resoluted images

images = imagen.sample(
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles'
    ],
    start_at_unet_number = 2,              # start at unet number 2
    start_image_or_video = lowres_images,  # pass in low resolution images to be resoluted
    cond_scale = 3.)

images.shape # (3, 3, 256, 256)

At any time you can save and load the trainer and all associated states with the save and load methods. It is recommended you use these methods instead of manually saving with a state_dict call, as there is some device memory management being done under the hood within the trainer.

ex.

trainer.save('./path/to/checkpoint.pt')

trainer.load('./path/to/checkpoint.pt')

trainer.steps # (2,) step number for each of the unets, in this case 2

Dataloader

You can also rely on the ImagenTrainer to automatically train off DataLoader instances. You simply have to craft your DataLoader to return either images (for the unconditional case), or a tuple of ('images', 'text_embeds') for text-guided generation.

ex. unconditional training

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

# unets for unconditional imagen

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True # whether to split the validation dataset from the training
).cuda()

# instantiate your dataloader, which returns the necessary inputs to the DDPM as tuple in the order of images, text embeddings, then text masks. in this case, only images is returned as it is unconditional training

dataset = Dataset('/path/to/training/images', image_size = 128)

trainer.add_train_dataset(dataset, batch_size = 16)

# working training loop

for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')

    if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
        images[0].save(f'./sample-{i // 100}.png')
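For the text-guided case, the DataLoader needs to yield images together with text embeddings. A minimal sketch of such a dataset with mock tensors follows. The class name `MockImageTextEmbedDataset` and all shapes are assumptions made for illustration; a real dataset would load images from disk and look up precomputed T5 embeddings. How the loader is then registered with the trainer is left out here and should be checked against the library.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MockImageTextEmbedDataset(Dataset):
    """Hypothetical dataset returning (image, text_embed) pairs made of random tensors."""

    def __init__(self, num_samples=8, image_size=64, seq_len=16, embed_dim=768):
        self.images = torch.randn(num_samples, 3, image_size, image_size)
        self.text_embeds = torch.randn(num_samples, seq_len, embed_dim)

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        # order matters: images first, then text embeddings
        return self.images[idx], self.text_embeds[idx]

dataloader = DataLoader(MockImageTextEmbedDataset(), batch_size=4, shuffle=True)

# the default collate function stacks each field into a batched tensor
images, text_embeds = next(iter(dataloader))
```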

Multi GPU

Thanks to 🤗 Accelerate, you can do multi GPU training easily with two steps.

First you need to invoke accelerate config in the same directory as your training script (say it is named train.py)

$ accelerate config

Next, instead of calling python train.py as you would for single GPU, you would use the accelerate CLI like so

$ accelerate launch train.py

That's it!

Command-line

Imagen can also be used via CLI directly.

Configuration

ex.

$ imagen config

or

$ imagen config --path ./configs/config.json

In the config you are able to change settings for the trainer, dataset and the imagen config.

The Imagen config parameters can be found here

The Elucidated Imagen config parameters can be found here

The Imagen Trainer config parameters can be found here

For the dataset parameters all dataloader parameters can be used.

Training

This command allows you to train or resume training your model

ex.

$ imagen train

or

$ imagen train --unet 2 --epoches 10000 --valid 100

You can pass the following arguments to the training command.

  • --config specify the config file to use for training [default: ./imagen_config.json]
  • --unet the index of the unet to train [default: 1]
  • --epoches how many epochs to train for [default: 1000]
  • --text specify the text to sample with every 100 epochs
  • --valid enable validation and optionally specify how many epochs between each validation [default: false]

Sampling

Be aware that, when sampling, your checkpoint should have all unets trained to get a usable result.

ex.

$ imagen sample --model ./path/to/model/checkpoint.pt "a squirrel raiding the birdfeeder"
# image is saved to ./a_squirrel_raiding_the_birdfeeder.png

You can pass the following arguments to the sample command.

  • --model specify the model file to use for sampling
  • --cond_scale conditioning scale (classifier free guidance) in decoder
  • --load_ema load EMA version of unets if available

In order to use a saved checkpoint with this feature, you must either instantiate your Imagen instance using the config classes, ImagenConfig and ElucidatedImagenConfig, or create a checkpoint via the CLI directly.

For proper training, you'll likely want to set up config-driven training anyways.

ex.

import torch
from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer

# in this example, using elucidated imagen

imagen = ElucidatedImagenConfig(
    unets = [
        dict(dim = 32, dim_mults = (1, 2, 4, 8)),
        dict(dim = 32, dim_mults = (1, 2, 4, 8))
    ],
    image_sizes = (64, 128),
    cond_drop_prob = 0.5,
    num_sample_steps = 32
).create()

trainer = ImagenTrainer(imagen)

# do your training ...

# then save it

trainer.save('./checkpoint.pt')

# you should see a message informing you that ./checkpoint.pt is commandable from the terminal

It really should be as simple as that

You can also pass this checkpoint file around, and anyone can continue finetuning on their own data

from imagen_pytorch import load_imagen_from_checkpoint, ImagenTrainer

imagen = load_imagen_from_checkpoint('./checkpoint.pt')

trainer = ImagenTrainer(imagen)

# continue training / fine-tuning

Inpainting

Inpainting follows the formulation laid out by the recent Repaint paper. Simply pass in inpaint_images and inpaint_masks to the sample function on either Imagen or ElucidatedImagen

inpaint_images = torch.randn(4, 3, 512, 512).cuda()      # (batch, channels, height, width)
inpaint_masks = torch.ones((4, 512, 512)).bool().cuda()  # (batch, height, width)

inpainted_images = trainer.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, cond_scale = 5.)

inpainted_images # (4, 3, 512, 512)

For video, similarly pass in your videos to inpaint_videos keyword on .sample. Inpainting mask can either be the same across all frames (batch, height, width) or different (batch, frames, height, width)

inpaint_videos = torch.randn(4, 3, 8, 512, 512).cuda()   # (batch, channels, frames, height, width)
inpaint_masks = torch.ones((4, 8, 512, 512)).bool().cuda()  # (batch, frames, height, width)

inpainted_videos = trainer.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
], inpaint_videos = inpaint_videos, inpaint_masks = inpaint_masks, cond_scale = 5.)

inpainted_videos # (4, 3, 8, 512, 512)
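The two mask layouts described above can be built as follows. This is a small sketch with hypothetical helper names (`box_mask`, `expand_mask_over_frames`). Whether True marks the region to keep or the region to regenerate is not stated in this document, so check the library before relying on either convention.

```python
import torch

def box_mask(batch, height, width, top, left, box_h, box_w):
    """Boolean mask of shape (batch, height, width) that is True inside a rectangle."""
    mask = torch.zeros(batch, height, width, dtype=torch.bool)
    mask[:, top:top + box_h, left:left + box_w] = True
    return mask

def expand_mask_over_frames(mask, frames):
    """(batch, height, width) -> (batch, frames, height, width), as an editable copy."""
    return mask[:, None].expand(-1, frames, -1, -1).clone()

# one mask shared by all frames
shared = box_mask(batch=4, height=64, width=64, top=16, left=16, box_h=32, box_w=32)

# a per-frame mask: start from the shared one, then clear it on the last four frames
per_frame = expand_mask_over_frames(shared, frames=8)
per_frame[:, 4:] = False
```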

Experimental

Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of Imagen, the ElucidatedImagen, so that one can use the new elucidated DDPM for text-guided cascading generation.

Simply import ElucidatedImagen, and then instantiate the instance as you did before. The hyperparameters are different than the usual ones for discrete and continuous time gaussian diffusion, and can be individualized for each unet in the cascade.

Ex.

from imagen_pytorch import ElucidatedImagen

# instantiate your unets ...

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    cond_drop_prob = 0.1,
    num_sample_steps = (64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
    sigma_min = 0.002,           # min noise level
    sigma_max = (80, 160),       # max noise level, @crowsonkb recommends double the max noise level for upsampler
    sigma_data = 0.5,            # standard deviation of data distribution
    rho = 7,                     # controls the sampling schedule
    P_mean = -1.2,               # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                 # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# rest is the same as above
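To give a feel for what sigma_min, sigma_max, rho, P_mean and P_std control, here is a standalone sketch of the noise schedule and training noise distribution from the Karras et al. paper. It uses only the standard library and does not call into this repository. The function names are hypothetical, and the library's own implementation may differ in details such as how the final step is handled.

```python
import math
import random

def karras_sigmas(num_sample_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, spaced according to rho.

    Requires num_sample_steps >= 2. A larger rho concentrates steps near sigma_min.
    """
    inv_rho = 1.0 / rho
    hi, lo = sigma_max ** inv_rho, sigma_min ** inv_rho
    steps = [i / (num_sample_steps - 1) for i in range(num_sample_steps)]
    return [(hi + t * (lo - hi)) ** rho for t in steps]

def sample_training_sigma(p_mean=-1.2, p_std=1.2, rng=random):
    """Draw a training noise level: log(sigma) is normal with mean p_mean, std p_std."""
    return math.exp(p_mean + p_std * rng.gauss(0.0, 1.0))

sigmas = karras_sigmas(32)           # e.g. the 32 steps used for the upsampler above
train_sigma = sample_training_sigma()
```

Passing a tuple such as sigma_max = (80, 160) in the config above corresponds to calling a schedule like this once per unet with different arguments.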

Text to Video

This repository will also start accumulating new research around text guided video synthesis. For starters it will adopt the 3d unet architecture described by Jonathan Ho in Video Diffusion Models

Update: verified working by Hadrien Reynaud!

Ex.

import torch
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer

unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()

# elucidated imagen, which contains the unets above (base unet and super resoluting ones)

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 32),
    random_crop_sizes = (None, 16),
    temporal_downsample_factor = (2, 1),        # in this example, the first unet would receive the video temporally downsampled by 2x
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,                          # min noise level
    sigma_max = (80, 160),                      # max noise level, double the max noise level for upsampler
    sigma_data = 0.5,                           # standard deviation of data distribution
    rho = 7,                                    # controls the sampling schedule
    P_mean = -1.2,                              # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,                                # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,                               # parameters for stochastic sampling - depends on dataset, Table 5 in paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).cuda()

# mock videos (get a lot of this) and text encodings from large T5

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'dust motes swirling in the morning sunshine on the windowsill'
]

videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)

# feed images into imagen, training each unet in the cascade
# for this example, only training unet 1

trainer = ImagenTrainer(imagen)

# you can also ignore time when training on video initially, shown to improve results in video-ddpm paper. eventually will make the 3d unet trainable with either images or video. research shows it is essential (with current data regimes) to train first on text-to-image. probably won't be true in another decade. all big data becomes small data

trainer(videos, texts = texts, unet_number = 1, ignore_time = False)
trainer.update(unet_number = 1)

videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames

videos.shape # (4, 3, 20, 32, 32)

You can also train on text - image pairs first. The Unet3D will automatically convert it to single framed videos and learn without the temporal components (by automatically setting ignore_time = True), whether it be 1d convolutions or causal attention across time.

This is the current approach taken by all the big artificial intelligence labs (Brain, MetaAI, Bytedance)
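The conversion from images to single-framed videos amounts to adding a frame axis of length one. The sketch below shows the shape change only; the helper name is hypothetical and the library performs this step for you.

```python
import torch

def images_to_single_frame_videos(images):
    """(batch, channels, height, width) -> (batch, channels, 1, height, width)."""
    return images.unsqueeze(2)

images = torch.randn(4, 3, 32, 32)
videos = images_to_single_frame_videos(images)
```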

FAQ

  • Why are my generated images not aligning well with the text?

Imagen uses an algorithm called Classifier Free Guidance. When sampling, you apply a scale to the conditioning (text in this case) of greater than 1.0.

Researcher Netruk44 has reported 5-10 to be optimal, with anything greater than 10 breaking.

trainer.sample(texts = [
    'a cloud in the shape of a roman gladiator'
], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average
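For intuition, the guidance step combines a text-conditioned prediction with an unconditioned one. A minimal sketch of that combination is shown here, with a hypothetical function name and toy tensors standing in for the two model outputs:

```python
import torch

def classifier_free_guidance(cond_pred, uncond_pred, cond_scale):
    """Push the prediction away from the unconditional output, towards the conditional one."""
    return uncond_pred + (cond_pred - uncond_pred) * cond_scale

cond = torch.tensor([1.0, 2.0])    # mock prediction given the text
uncond = torch.tensor([0.0, 1.0])  # mock prediction with the text dropped

guided = classifier_free_guidance(cond, uncond, cond_scale=5.0)
```

A scale of 1.0 returns the conditional prediction unchanged, and larger values exaggerate the difference that the text makes. The model can produce the unconditional prediction because the text conditioning is randomly dropped during training, which is what cond_drop_prob controls.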
  • Are there any pretrained models yet?

Not at the moment but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at Laion (discord link is in the Readme above) and start collaborating.

  • Will this technology take my job?

All the more reason why you should start training your own model, starting today! The last thing we need is this technology being in the hands of an elite few. Hopefully this repository reduces the work to just finding the necessary compute, and augmenting with your own curated dataset.

  • What am I allowed to do with this repository?

Anything! It is MIT licensed. In other words, you can freely copy / paste for your own research, or remix it for whatever modality you can think of. Go train amazing models for profit, for science, or simply to satiate your own personal pleasure at witnessing something divine unravel in front of you.

Related Works

Todo

  • use huggingface transformers for T5-small text embeddings

  • add dynamic thresholding

  • add dynamic thresholding DALLE2 and video-diffusion repository as well

  • allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)

  • add the lowres noise level with the pseudocode in appendix, and figure out what is this sweep they do at inference time

  • port over some training code from DALLE2

  • need to be able to use a different noise schedule per unet (cosine was used for base, but linear for SR)

  • just make one master-configurable unet

  • complete resnet block (biggan inspired? but with groupnorm) - complete self attention

  • complete conditioning embedding block (and make it completely configurable, whether it be attention, film etc)

  • consider using perceiver-resampler from https://github.com/lucidrains/flamingo-pytorch in place of attention pooling

  • add attention pooling option, in addition to cross attention and film

  • add optional cosine decay schedule with warmup, for each unet, to trainer

  • switch to continuous timesteps instead of discretized, as it seems that is what they used for all stages - first figure out the linear noise schedule case from the variational ddpm paper https://openreview.net/forum?id=2LdBqxc1Yv

  • figure out log(snr) for alpha cosine noise schedule.

  • suppress the transformers warning because only T5encoder is used

  • allow setting for using linear attention on layers where full attention cannot be used

  • force unets in continuous time case to use non-fouriered conditions (just pass the log(snr) through an MLP with optional layernorms), as that is what i have working locally

  • removed learned variance

  • add p2 loss weighting for continuous time

  • make sure cascading ddpm can be trained without text condition, and make sure both continuous and discrete time gaussian diffusion works

  • use primer's depthwise convs on the qkv projections in linear attention (or use token shifting before projections) - also use new dropout proposed by bayesformer, as it seems to work well with linear attention

  • explore skip layer excitation in unet decoder

  • accelerate integration

  • build out CLI tool and one-line generation of image

  • knock out any issues that arose from accelerate

  • add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865

  • build a simple checkpointing system, backed by a folder

  • add skip connection from outputs of all upsample blocks, used in unet squared paper and some previous unet works

  • add fsspec, recommended by Romain @rom1504, for cloud / local file system agnostic persistence of checkpoints

  • test out persistence in gcs with https://github.com/fsspec/gcsfs

  • extend to video generation, using axial time attention as in Ho's video ddpm paper

  • allow elucidated imagen to generalize to any shape

  • allow for imagen to generalize to any shape

  • add dynamic positional bias for the best type of length extrapolation across video time

  • move video frames to sample function, as we will be attempting time extrapolation

  • attention bias to null key / values should be a learned scalar of head dimension

  • add self-conditioning from bit diffusion paper, already coded up at ddpm-pytorch

  • add v-parameterization (https://arxiv.org/abs/2202.00512) from imagen video paper, the only thing new

  • incorporate all learnings from make-a-video (https://makeavideo.studio/)

  • build out CLI tool for training, resuming training off config file

  • allow for temporal interpolation at specific stages

  • make sure temporal interpolation works with inpainting

  • make sure one can customize all interpolation modes (some researchers are finding better results with trilinear)

  • imagen-video : allow for conditioning on preceding (and possibly future) frames of videos. ignore time should not be allowed in that scenario

  • make sure to automatically take care of temporal down/upsampling for conditioning video frames, but allow for an option to turn it off

  • make sure inpainting works with video

  • make sure inpainting mask for video can be customized per frame

  • add flash attention

  • reread cogvideo and figure out how frame rate conditioning could be used

  • bring in attention expertise for self attention layers in unet3d

  • consider bringing in NUWA's 3d convolutional attention

  • consider transformer-xl memories in the temporal attention blocks

  • consider perceiver-ar approach to attending to past time

  • frame dropouts during attention for achieving both regularizing effect as well as shortened training time

  • investigate frank wood's claims https://github.com/lucidrains/flexible-diffusion-modeling-videos-pytorch and either add the hierarchical sampling technique, or let people know about its deficiencies

  • offer challenging moving mnist (with distractor objects) as a one-line trainable baseline for researchers to branch off of for text to video

  • preencoding of text to memmapped embeddings

  • be able to create dataloader iterators based on the old epoch style, also configure shuffling etc

  • be able to also pass in arguments (instead of requiring forward to be all keyword args on model)

  • bring in reversible blocks from revnets for 3d unet, to lessen memory burden

  • add ability to only train super-resolution network

  • read dpm-solver see if it is applicable to continuous time gaussian diffusion

  • allow for conditioning video frames with arbitrary absolute times (calculate RPE during temporal attention)

  • accommodate dream booth fine tuning

  • add textual inversion

  • cleanup self conditioning to be extracted at imagen instantiation

  • make sure eventual dreambooth works with imagen-video

  • add framerate conditioning for video diffusion

  • make sure one can simultaneously condition on video frames as a prompt, as well as some conditioning image across all frames

  • test and add distillation technique from consistency models

Citations

@inproceedings{Saharia2022PhotorealisticTD,
    title   = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
    author  = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
    year    = {2022}
}
@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}
@inproceedings{Sankararaman2022BayesFormerTW,
    title   = {BayesFormer: Transformer with Uncertainty Estimation},
    author  = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
    year    = {2022}
}
@article{So2021PrimerSF,
    title   = {Primer: Searching for Efficient Transformers for Language Modeling},
    author  = {David R. So and Wojciech Ma{\'n}ke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2109.08668}
}
@misc{cao2020global,
    title   = {Global Context Networks},
    author  = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year    = {2020},
    eprint  = {2012.13375},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@article{Karras2022ElucidatingTD,
    title   = {Elucidating the Design Space of Diffusion-Based Generative Models},
    author  = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2206.00364}
}
@inproceedings{NEURIPS2020_4c5bcfec,
    author      = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
    booktitle   = {Advances in Neural Information Processing Systems},
    editor      = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
    pages       = {6840--6851},
    publisher   = {Curran Associates, Inc.},
    title       = {Denoising Diffusion Probabilistic Models},
    url         = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
    volume      = {33},
    year        = {2020}
}
@article{Lugmayr2022RePaintIU,
    title   = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
    author  = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2201.09865}
}
@misc{ho2022video,
    title   = {Video Diffusion Models},
    author  = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
    year    = {2022},
    eprint  = {2204.03458},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{chen2022analog,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
    year    = {2022},
    eprint  = {2208.04202},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{Singer2022,
    author  = {Uriel Singer},
    url     = {https://makeavideo.studio/Make-A-Video.pdf}
}
@article{Sunkara2022NoMS,
    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
    author  = {Raja Sunkara and Tie Luo},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.03641}
}
@article{Salimans2022ProgressiveDF,
    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
    author  = {Tim Salimans and Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.00512}
}
@article{Ho2022ImagenVH,
    title   = {Imagen Video: High Definition Video Generation with Diffusion Models},
    author  = {Jonathan Ho and William Chan and Chitwan Saharia and Jay Whang and Ruiqi Gao and Alexey A. Gritsenko and Diederik P. Kingma and Ben Poole and Mohammad Norouzi and David J. Fleet and Tim Salimans},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.02303}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
@inproceedings{Hang2023EfficientDT,
    title   = {Efficient Diffusion Training via Min-SNR Weighting Strategy},
    author  = {Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
    year    = {2023}
}
@article{Zhang2021TokenST,
    title   = {Token Shift Transformer for Video Classification},
    author  = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
    journal = {Proceedings of the 29th ACM International Conference on Multimedia},
    year    = {2021}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}

More Repositories

1

vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
Python
13,633
star
2

DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
Python
11,068
star
3

PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning from Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
Python
7,611
star
4

DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
Python
5,132
star
5

deep-daze

Simple command line tool for text to image generation using OpenAI's CLIP and Siren (Implicit neural representation network). Technique was originally created by https://twitter.com/advadnoun
Python
4,387
star
6

denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
Python
3,959
star
7

stylegan2-pytorch

Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
Python
3,433
star
8

musiclm-pytorch

Implementation of MusicLM, Google's new SOTA model for music generation using attention networks, in Pytorch
Python
3,048
star
9

x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
Python
2,707
star
10

big-sleep

A simple command line tool for text to image generation, using OpenAI's CLIP and a BigGAN. Technique was originally created by https://twitter.com/advadnoun
Python
2,446
star
11

audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
Python
2,285
star
12

lion-pytorch

🦁 Lion, new optimizer discovered by Google Brain using genetic algorithms that is purportedly better than Adam(w), in Pytorch
Python
1,933
star
13

toolformer-pytorch

Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI
Python
1,905
star
14

reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
Python
1,870
star
15

make-a-video-pytorch

Implementation of Make-A-Video, new SOTA text to video generator from Meta AI, in Pytorch
Python
1,853
star
16

gigagan-pytorch

Implementation of GigaGAN, new SOTA GAN out of Adobe. Culmination of nearly a decade of research into GANs
Python
1,632
star
17

alphafold2

To eventually become an unofficial Pytorch implementation / replication of Alphafold2, as details of the architecture get released
Python
1,536
star
18

lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in Pytorch. High resolution image generations that can be trained within a day or two
Python
1,526
star
19

lambda-networks

Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute
Python
1,516
star
20

byol-pytorch

Usable Implementation of "Bootstrap Your Own Latent" self-supervised learning, from Deepmind, in Pytorch
Python
1,497
star
21

self-rewarding-lm-pytorch

Implementation of the training framework proposed in Self-Rewarding Language Model, from MetaAI
Python
1,253
star
22

naturalspeech2-pytorch

Implementation of Natural Speech 2, Zero-shot Speech and Singing Synthesizer, in Pytorch
Python
1,214
star
23

flamingo-pytorch

Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch
Python
1,155
star
24

video-diffusion-pytorch

Implementation of Video Diffusion Models, Jonathan Ho's new paper extending DDPMs to Video Generation - in Pytorch
Python
1,141
star
25

soundstorm-pytorch

Implementation of SoundStorm, Efficient Parallel Audio Generation from Google Deepmind, in Pytorch
Python
1,130
star
26

CoCa-pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch
Python
990
star
27

performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
Python
937
star
28

perceiver-pytorch

Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
Python
935
star
29

RETRO-pytorch

Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch
Python
835
star
30

mlp-mixer-pytorch

An All-MLP solution for Vision, from Google AI
Python
833
star
31

muse-maskgit-pytorch

Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch
Python
821
star
32

PaLM-pytorch

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways
Python
812
star
33

vector-quantize-pytorch

Vector Quantization, in Pytorch
Python
810
star
34

phenaki-pytorch

Implementation of Phenaki Video, which uses Mask GIT to produce text guided videos of up to 2 minutes in length, in Pytorch
Python
724
star
35

x-clip

A concise but complete implementation of CLIP with various experimental improvements from recent papers
Python
658
star
36

bottleneck-transformer-pytorch

Implementation of Bottleneck Transformer in Pytorch
Python
632
star
37

memorizing-transformers-pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
Python
614
star
38

TimeSformer-pytorch

Implementation of TimeSformer from Facebook AI, a pure attention-based solution for video classification
Python
613
star
39

MEGABYTE-pytorch

Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch
Python
594
star
40

meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
Python
564
star
41

nuwa-pytorch

Implementation of NÜWA, state of the art attention network for text to video synthesis, in Pytorch
Python
531
star
42

voicebox-pytorch

Implementation of Voicebox, new SOTA Text-to-speech network from MetaAI, in Pytorch
Python
521
star
43

point-transformer-pytorch

Implementation of the Point Transformer layer, in Pytorch
Python
518
star
44

parti-pytorch

Implementation of Parti, Google's pure attention-based text-to-image neural network, in Pytorch
Python
509
star
45

tab-transformer-pytorch

Implementation of TabTransformer, attention network for tabular data, in Pytorch
Python
485
star
46

alphafold3-pytorch

Implementation of Alphafold 3 in Pytorch
Python
483
star
47

linear-attention-transformer

Transformer based on a variant of attention that has linear complexity with respect to sequence length
Python
468
star
48

magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch
Python
436
star
49

ema-pytorch

A simple way to keep track of an Exponential Moving Average (EMA) version of your pytorch model
Python
408
star
50

egnn-pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch
Python
400
star
51

g-mlp-pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch
Python
391
star
52

recurrent-memory-transformer-pytorch

Implementation of Recurrent Memory Transformer, Neurips 2022 paper, in Pytorch
Python
384
star
53

ring-attention-pytorch

Implementation of 💍 Ring Attention, from Liu et al. at Berkeley AI, in Pytorch
Python
380
star
54

siren-pytorch

Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Function
Python
377
star
55

enformer-pytorch

Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch
Python
352
star
56

iTransformer

Unofficial implementation of iTransformer - SOTA Time Series Forecasting using Attention networks, out of Tsinghua / Ant group
Python
349
star
57

robotic-transformer-pytorch

Implementation of RT1 (Robotic Transformer) in Pytorch
Python
346
star
58

memory-efficient-attention-pytorch

Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"
Python
342
star
59

FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
Python
334
star
60

bit-diffusion

Implementation of Bit Diffusion, Hinton's group's attempt at discrete denoising diffusion, in Pytorch
Python
313
star
61

medical-chatgpt

Implementation of ChatGPT, but tailored towards primary care medicine, with the reward being able to collect patient histories in a thorough and efficient manner and come up with a reasonable differential diagnosis
Python
311
star
62

slot-attention

Implementation of Slot Attention from GoogleAI
Python
303
star
63

q-transformer

Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind
Python
293
star
64

BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
Python
289
star
65

classifier-free-guidance-pytorch

Implementation of Classifier Free Guidance in Pytorch, with emphasis on text conditioning, and flexibility to include multiple text embedding models
Python
282
star
66

transformer-in-transformer

Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch
Python
277
star
67

axial-attention

Implementation of Axial attention - attending to multi-dimensional data efficiently
Python
273
star
68

conformer

Implementation of the convolutional module from the Conformer paper, for use in Transformers
Python
272
star
69

mixture-of-experts

A Pytorch implementation of Sparsely-Gated Mixture of Experts, for massively increasing the parameter count of language models
Python
264
star
70

deformable-attention

Implementation of Deformable Attention in Pytorch from the paper "Vision Transformer with Deformable Attention"
Python
258
star
71

magic3d-pytorch

Implementation of Magic3D, Text to 3D content synthesis, in Pytorch
Python
258
star
72

x-unet

Implementation of a U-net complete with efficient attention as well as the latest research findings
Python
252
star
73

routing-transformer

Fully featured implementation of Routing Transformer
Python
251
star
74

Adan-pytorch

Implementation of the Adan (ADAptive Nesterov momentum algorithm) Optimizer in Pytorch
Python
245
star
75

spear-tts-pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch
Python
241
star
76

st-moe-pytorch

Implementation of ST-Moe, the latest incarnation of MoE after years of research at Brain, in Pytorch
Python
237
star
77

perfusion-pytorch

Implementation of Key-Locked Rank One Editing, from Nvidia AI
Python
229
star
78

equiformer-pytorch

Implementation of the Equiformer, SE3/E3 equivariant attention network that reaches new SOTA, and adopted for use by EquiFold for protein folding
Python
227
star
79

segformer-pytorch

Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch
Python
227
star
80

sinkhorn-transformer

Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention
Python
222
star
81

pixel-level-contrastive-learning

Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in Pytorch
Python
220
star
82

lumiere-pytorch

Implementation of Lumiere, SOTA text-to-video generation from Google Deepmind, in Pytorch
Python
216
star
83

local-attention

An implementation of local windowed attention for language modeling
Python
216
star
84

CoLT5-attention

Implementation of the conditionally routed attention in the CoLT5 architecture, in Pytorch
Python
216
star
85

natural-speech-pytorch

Implementation of the neural network proposed in Natural Speech, a text-to-speech generator that is indistinguishable from human recordings for the first time, from Microsoft Research
Python
215
star
86

soft-moe-pytorch

Implementation of Soft MoE, proposed by Brain's Vision team, in Pytorch
Python
211
star
87

se3-transformer-pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This specific repository is geared towards integration with eventual Alphafold2 replication.
Python
211
star
88

block-recurrent-transformer-pytorch

Implementation of Block Recurrent Transformer - Pytorch
Python
205
star
89

Mega-pytorch

Implementation of Mega, the Single-head Attention with Multi-headed EMA architecture that currently holds SOTA on Long Range Arena
Python
201
star
90

simple-hierarchical-transformer

Experiments around a simple idea for inducing multiple hierarchical predictive models within a GPT
Python
198
star
91

med-seg-diff-pytorch

Implementation of MedSegDiff in Pytorch - SOTA medical segmentation using DDPM and filtering of features in fourier space
Python
195
star
92

triton-transformer

Implementation of a Transformer, but completely in Triton
Python
195
star
93

jax2torch

Use Jax functions in Pytorch
Python
194
star
94

flash-cosine-sim-attention

Implementation of fused cosine similarity attention in the same style as Flash Attention
Cuda
194
star
95

halonet-pytorch

Implementation of the 😇 Attention layer from the paper, Scaling Local Self-Attention For Parameter Efficient Visual Backbones
Python
193
star
96

attention

This repository will house a visualization that will attempt to convey instant enlightenment of how Attention works to someone not working in artificial intelligence, with 3Blue1Brown as inspiration
HTML
189
star
97

recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch
Python
188
star
98

electra-pytorch

A simple and working implementation of Electra, the fastest way to pretrain language models from scratch, in Pytorch
Python
186
star
99

PaLM-jax

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways - in Jax (Equinox framework)
Python
184
star
100

unet-stylegan2

A Pytorch implementation of Stylegan2 with UNet Discriminator
Python
182
star