  • Stars: 134
  • Rank 270,967 (Top 6 %)
  • Language: Python
  • License: Apache License 2.0
  • Created almost 4 years ago
  • Updated about 2 years ago

DLAS - A configuration-driven trainer for generative models

Deep Learning Art School

Send your PyTorch model to art class!

This repository is both a framework and a set of tools for training deep neural networks that create images. It started as a branch of the open-mmlab project developed by the Multimedia Laboratory, CUHK, but has since been almost completely rewritten at every level.

Why do we need another training framework?

These are a dime a dozen, no doubt. DL Art School (DLAS) differentiates itself by being configuration driven. You write the model code (specifically, a torch.nn.Module) and possibly some losses, then you cobble together a YAML config file that tells DLAS how to train it. Swapping model architectures and tuning hyper-parameters is simple and often requires no changes to actual code. You also don't need to remember complex command-line incantations. This lets you run multiple concurrent experiments from the same codebase while retaining backwards compatibility with past experiments.
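As an illustration of the configuration-driven idea, here is a minimal sketch of how a name in a config file can select and construct a model class. It assumes a simple decorator-based registry; `MODEL_REGISTRY`, `register_model`, `build_model`, `TinyGenerator` and the `which_model` key are hypothetical names, not DLAS's actual API, and plain classes stand in for torch.nn.Module so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Dict, Type

# Hypothetical registry mapping names used in a config file to model classes.
# In DLAS the registered classes would be torch.nn.Module subclasses; plain
# classes are used here so the sketch runs without PyTorch installed.
MODEL_REGISTRY: Dict[str, Type] = {}


def register_model(name: str) -> Callable[[Type], Type]:
    """Decorator that records a class under `name` in MODEL_REGISTRY."""
    def decorator(cls: Type) -> Type:
        if name in MODEL_REGISTRY:
            raise ValueError(f"model '{name}' is already registered")
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


@register_model("tiny_generator")
class TinyGenerator:
    def __init__(self, channels: int = 64, blocks: int = 4) -> None:
        self.channels = channels
        self.blocks = blocks


def build_model(net_opt: Dict[str, Any]) -> Any:
    """Instantiate the class named by net_opt['which_model'] (hypothetical key),
    passing the remaining keys as constructor arguments."""
    opt = dict(net_opt)  # copy so the caller's config is not mutated
    cls = MODEL_REGISTRY[opt.pop("which_model")]
    return cls(**opt)


# This dict stands in for one network entry of a parsed YAML config.
config = {"which_model": "tiny_generator", "channels": 32, "blocks": 8}
model = build_model(config)
print(type(model).__name__, model.channels, model.blocks)  # TinyGenerator 32 8
```

Because the constructor arguments come straight from the config, changing `blocks` or swapping `which_model` needs no code change, which is the property described above.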

Training effective generators often means juggling multiple loss functions. As a result, DLAS' configuration language is specifically designed to make it easy to support a large number of losses and networks that interact with each other. For example, some GANs I have trained in this framework use more than 15 losses and 2 separate discriminators, yet require no bespoke code.
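A rough sketch of how many named losses can be folded into one training objective with per-loss weights taken from a config. The loss names (including the two discriminator terms) and the weights are hypothetical, and this is not DLAS's actual loss-combination code. With PyTorch the values would be scalar tensors, and the same weighted sum would produce a single tensor to call backward() on.

```python
from typing import Dict


def total_loss(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Combine named loss values into one scalar using per-loss weights.

    Losses without an entry in `weights` are skipped, so a loss can be
    switched off from the config by omitting it (or giving it weight 0).
    """
    return sum(weights[name] * value for name, value in losses.items() if name in weights)


# Hypothetical weights, as they might be read from a config file.
weights = {"pixel": 1.0, "perceptual": 0.5, "gan_d1": 0.1, "gan_d2": 0.1}
# Hypothetical per-step loss values; "unused" has no weight and is ignored.
losses = {"pixel": 0.2, "perceptual": 0.4, "gan_d1": 1.0, "gan_d2": 3.0, "unused": 9.0}
print(round(total_loss(losses, weights), 6))  # 0.8
```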

Generators are also notorious GPU memory hogs. I have spent substantial time streamlining the training framework to support gradient checkpointing and FP16. DLAS also supports "mega batching", where multiple forward passes contribute to a single backward pass. Most models can be trained on midrange GPUs with 8-11GB of memory.
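The README does not spell out how mega batching is implemented. Assuming it behaves like standard gradient accumulation (split the batch into chunks, run forward and backward on each chunk, let the gradients add up, then take one optimizer step), the sketch below checks on a one-parameter linear model that scaling each chunk's gradient by 1/chunks reproduces the full-batch gradient. It is plain Python; in PyTorch the equivalent is calling loss.backward() once per chunk, which adds into each parameter's .grad, followed by a single optimizer.step().

```python
from typing import Sequence


def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: Sequence[float], ys: Sequence[float], chunks: int) -> float:
    """Split the batch into `chunks` equal pieces, compute each piece's gradient
    separately, scale it by 1/chunks and add it to a running total. Only one
    optimizer step would follow, as if the whole batch had been processed at once."""
    if len(xs) % chunks != 0:
        raise ValueError("batch size must divide evenly into chunks")
    size = len(xs) // chunks
    grad = 0.0
    for i in range(chunks):
        part = slice(i * size, (i + 1) * size)
        grad += mse_grad(w, xs[part], ys[part]) / chunks
    return grad


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 3.0, 5.0, 9.0]
print(mse_grad(1.5, xs, ys), accumulated_grad(1.5, xs, ys, chunks=2))  # -7.0 -7.0
```

Only one chunk's activations need to be held in memory at a time, which is why this kind of accumulation lowers peak GPU memory for a given effective batch size.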

The final value-added feature is interpretability. TensorBoard logging works out of the box with no custom code. Intermediate images from within the training pipeline can be periodically saved as normal PNG files so you can see what your network is up to. Validation passes are also saved as images so you can view how your network improves over time.

Modeling Capabilities

DLAS was built with extensibility in mind. One of the reasons I'm putting in the effort to better document this code is the incredible ease with which I have been able to train entirely new model types with no changes to the core training code.

I intend to fill out the sections below with sample configurations which can be used to train different architectures. You will need to bring your own data.

Super-resolution

Style Transfer

  • StyleGAN2 (documentation TBC)

Latent development

  • BYOL
  • iGPT (documentation TBC)

Dependencies and Installation

  • Python 3
  • PyTorch >= 1.6
  • NVIDIA GPU + CUDA
  • Python packages: pip install -r requirements.txt
  • Some video utilities require FFMPEG
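A small, hypothetical helper for checking the requirements above before installing. It is not part of DLAS; `version_at_least` and `check_environment` are illustrative names. PyTorch is imported lazily so the version helper works even where PyTorch is missing.

```python
import shutil
import sys
from typing import List, Tuple


def version_at_least(version: str, minimum: Tuple[int, ...]) -> bool:
    """Compare a version string such as '1.10.2+cu113' against a tuple like (1, 6).

    Any local build tag after '+' is ignored, and only the leading numeric
    part of each dotted component is used.
    """
    parts: List[int] = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts[: len(minimum)]) >= minimum


def check_environment() -> None:
    """Print a short report on the listed requirements."""
    print("Python 3:", sys.version_info.major == 3)
    try:
        import torch  # imported lazily so version_at_least works without PyTorch
    except ImportError:
        print("PyTorch: not installed")
    else:
        print("PyTorch >= 1.6:", version_at_least(torch.__version__, (1, 6)))
        print("CUDA available:", torch.cuda.is_available())
    print("ffmpeg on PATH:", shutil.which("ffmpeg") is not None)


if __name__ == "__main__":
    check_environment()
```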

User Guide

TBC

Development Environment

If you aren't already using PyCharm, now is the time to try it out. This project was built in PyCharm and comes with an IDEA project for you to get started with. I've done all of my development on this repo in this IDE and lean heavily on its incredible debugger. It's free. Try it out. You won't be sorry.

Dataset Preparation

DLAS comes with some Dataset instances that I have created for my own use. Unless you want to use one of the recipes above, you'll need to provide your own. Here is how to add your own Dataset:

  1. Create a Dataset in codes/data/ whose constructor takes a single Python dict and extracts its options from that dict.
  2. Register your Dataset in codes/data/__init__.py
  3. Your Dataset should return a dict of tensors. The keys of the dict are injected directly into the training state, which you can interact with from your configuration file.
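The three steps above can be sketched as follows. Everything here is hypothetical: the class, the option keys (`mode`, `length`, `size`, `noise`), the `DATASETS` mapping and `create_dataset` stand in for whatever the real registration code in codes/data/ looks like. Plain lists replace tensors so the sketch runs without PyTorch; PyTorch's DataLoader accepts any map-style object that implements `__len__` and `__getitem__`, though a real dataset would normally subclass torch.utils.data.Dataset and return tensors.

```python
import random
from typing import Any, Dict, List


class NoisyRampDataset:
    """Toy map-style dataset; the option keys are made up for illustration."""

    def __init__(self, opt: Dict[str, Any]) -> None:
        # Step 1: the constructor takes a single dict and pulls its options from it.
        self.length = opt["length"]
        self.size = opt.get("size", 4)
        self.noise = opt.get("noise", 0.1)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> Dict[str, List[float]]:
        # Step 3: return a dict; its keys ('hq', 'lq') become names in the training state.
        rng = random.Random(index)  # per-index seed keeps each sample reproducible
        hq = [index + i / self.size for i in range(self.size)]
        lq = [v + rng.uniform(-self.noise, self.noise) for v in hq]
        return {"hq": hq, "lq": lq}


# Step 2: register the class under a name a config can refer to.
# DATASETS and create_dataset() are hypothetical stand-ins for the real registration code.
DATASETS = {"noisy_ramp": NoisyRampDataset}


def create_dataset(dataset_opt: Dict[str, Any]) -> Any:
    return DATASETS[dataset_opt["mode"]](dataset_opt)


ds = create_dataset({"mode": "noisy_ramp", "length": 10, "size": 4})
sample = ds[3]
print(len(ds), sorted(sample))  # 10 ['hq', 'lq']
```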

Training and Testing

There are currently 3 base scripts for interacting with models. They all take a single parameter, -opt, which specifies the configuration file that controls how they work. Configuration files will be documented in the User Guide above.

train.py

Start (or continue) a training session: python train.py -opt <your_config.yml>

Start a distributed training session: python -m torch.distributed.launch --nproc_per_node=<gpus> --master_port=1234 train.py -opt <your_config.yml> --launcher=pytorch
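A hypothetical sketch of an argument parser that accepts the train.py command lines above, using only the standard library; it is not the actual train.py code. The `--launcher` choices and the `--local_rank` flag are assumptions: torch.distributed.launch appends a local-rank argument to each worker's command line (spelled `--local_rank` in older PyTorch releases and `--local-rank` in newer ones), so a script launched that way has to accept it. After parsing, the script would read the file at `args.opt` with a YAML parser such as PyYAML's yaml.safe_load.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical parser mirroring the train.py command lines."""
    parser = argparse.ArgumentParser(description="Train a model from a YAML config.")
    # Single-dash long option; argparse derives the destination name 'opt' from it.
    parser.add_argument("-opt", type=str, required=True, help="Path to the YAML config file.")
    parser.add_argument("--launcher", choices=["none", "pytorch"], default="none",
                        help="Job launcher; 'pytorch' when started via torch.distributed.launch.")
    # Added by torch.distributed.launch to each worker; both spellings map to args.local_rank.
    parser.add_argument("--local_rank", "--local-rank", type=int, default=0)
    return parser.parse_args(argv)


args = parse_args(["-opt", "my_config.yml", "--launcher=pytorch"])
print(args.opt, args.launcher)  # my_config.yml pytorch
```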

test.py

Runs a model against a validation or test set of data and reports metrics (for now, just PSNR and a custom perceptual metric): python test.py -opt <your_config.yml>
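PSNR has a standard definition, 10 * log10(MAX^2 / MSE) in decibels. The sketch below implements that textbook formula on flat lists of pixel values and assumes pixels scaled to [0, 1] by default; the README does not say which value range, color space or cropping test.py uses, and the custom perceptual metric is not reproduced here.

```python
import math
from typing import Sequence


def psnr(prediction: Sequence[float], target: Sequence[float], max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(max_value ** 2 / MSE)."""
    if len(prediction) != len(target) or len(prediction) == 0:
        raise ValueError("inputs must be non-empty and the same length")
    mse = sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(prediction)
    if mse == 0:
        return math.inf  # identical inputs have unbounded PSNR
    return 10 * math.log10(max_value ** 2 / mse)


# Every pixel off by 0.1 gives MSE = 0.01, i.e. 20 dB on a [0, 1] scale.
print(round(psnr([0.0] * 4, [0.1] * 4), 6))  # 20.0
```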

process_video.py

Breaks a video into individual frames, uses a network to process them, then reassembles the output back into video form: python process_video.py -opt <your_config.yml>
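The internals of process_video.py are not documented here. As a sketch of the split and reassemble steps only, the functions below build standard ffmpeg command lines (frames written as zero-padded PNGs, then re-encoded with libx264). The function names, frame pattern and codec choice are assumptions, the per-frame network step is omitted, and the original audio track is not carried over. Each command would be run with subprocess.run(cmd, check=True).

```python
import shlex
from typing import List


def split_cmd(video: str, frame_dir: str) -> List[str]:
    """ffmpeg command that writes every frame of `video` as a numbered PNG."""
    return ["ffmpeg", "-i", video, f"{frame_dir}/%06d.png"]


def join_cmd(frame_dir: str, output: str, fps: float) -> List[str]:
    """ffmpeg command that encodes numbered PNGs back into an H.264 video."""
    return [
        "ffmpeg",
        "-framerate", str(fps),  # input frame rate for the image sequence
        "-i", f"{frame_dir}/%06d.png",
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",  # widely compatible pixel format for H.264
        output,
    ]


# Print the commands as shell strings (shlex.join needs Python 3.8+).
print(shlex.join(split_cmd("in.mp4", "frames")))
print(shlex.join(join_cmd("processed", "out.mp4", 30)))
```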

Contributing

At this time I am not taking feature requests or bug reports, but I appreciate all contributions.

License

This project is released under the Apache 2.0 license.