  • Stars: 171
  • Rank: 222,266 (Top 5 %)
  • Language: Python
  • License: MIT License
  • Created over 1 year ago
  • Updated 9 months ago

Repository Details

Diffusion Reinforcement Learning Library

Diffusion Reinforcement Learning X

DRLX is a library for distributed training of diffusion models via reinforcement learning (RL). It wraps 🤗 Hugging Face's Diffusers library and uses Accelerate for multi-GPU and multi-node (as yet untested) training.

News (09/27/2023): Check out our blog post on some recent experiments!

📖 Documentation

Setup

First, make sure you have installed OpenCLIP. Afterwards, you can install the library from PyPI:

pip install drlx

or from source:

pip install git+https://github.com/CarperAI/DRLX.git

How to use

So far, the library has only been tested with Stable Diffusion 1.4, 1.5, and 2.1, but because it is plug-and-play, the denoiser from most pipelines should realistically be usable. Models saved with DRLX remain compatible with the pipeline they originated from and can be loaded like any other pretrained model. DDPO is currently the only supported training algorithm.

# Import a reward model, a prompt pipeline, the trainer, and the config class
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

# Prompt pipeline that supplies the training prompts
pipe = PickAPicPrompts()
# Load the training configuration from a YAML file
config = DRLXConfig.load_yaml("configs/my_cfg.yml")
trainer = DDPOTrainer(config)

# Train, scoring generated images with the aesthetics reward model
trainer.train(pipe, Aesthetics())

And then to use a trained model for inference:

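from diffusers import StableDiffusionPipeline

# Load the model saved by DRLX like any other pretrained pipeline
pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp")
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]
image.save("test.jpeg")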
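
For readers unfamiliar with DDPO, the training algorithm mentioned above, the following is a minimal sketch of its clipped, importance-weighted objective, written in plain Python. It is not DRLX's implementation: the function names normalize_rewards and ddpo_clipped_loss, the list-based data layout, and the default clip_range are all assumptions made for illustration. The idea is that each denoising step is treated as an action, rewards are standardized into advantages, and a PPO-style clipped ratio keeps the updated model close to the one that generated the samples.

```python
import math


def normalize_rewards(rewards, eps=1e-8):
    """Standardize raw rewards across the batch to obtain advantages (hypothetical helper)."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in rewards]


def ddpo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """PPO-style clipped surrogate loss, averaged over samples and denoising steps.

    new_log_probs / old_log_probs: one list per sample, holding the log-probability
    of each denoising step under the current model / the model used for sampling.
    advantages: one scalar per sample, shared by all of that sample's steps.
    """
    total, count = 0.0, 0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        for lp_new, lp_old in zip(new_lp, old_lp):
            # Importance ratio between the current and the sampling policy
            ratio = math.exp(lp_new - lp_old)
            # Restrict the ratio to [1 - clip_range, 1 + clip_range]
            clipped = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
            # Maximizing the pessimistic objective equals minimizing its negative
            total += -min(ratio * adv, clipped * adv)
            count += 1
    return total / count


# Example: two samples with two denoising steps each
advantages = normalize_rewards([1.0, 3.0])
loss = ddpo_clipped_loss(
    new_log_probs=[[-1.0, -1.1], [-0.9, -1.0]],
    old_log_probs=[[-1.0, -1.0], [-1.0, -1.0]],
    advantages=advantages,
)
```

In a real training loop the log-probabilities would come from the denoiser's Gaussian step distribution and the loss would be backpropagated through the current model; this sketch only shows the arithmetic of the objective.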

Accelerated Training

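Run accelerate config once to describe your hardware setup, then launch your training module through Accelerate:

accelerate config
accelerate launch -m [your module]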

Roadmap

  • Initial launch and DDPO
  • PickScore Tuned Models
  • DPO
  • SDXL support