  • Stars: 541
  • Rank: 82,114 (Top 2 %)
  • Language: Python
  • License: MIT License
  • Created almost 2 years ago
  • Updated about 1 year ago

Erasing Concepts from Diffusion Models

Project Website | Arxiv Preprint | Fine-tuned Weights | Demo

Motivated by recent advancements in text-to-image diffusion, we study erasure of specific concepts from the model's weights. While Stable Diffusion has shown promise in producing explicit or realistic artwork, it has raised concerns regarding its potential for misuse. We propose a fine-tuning method that can erase a visual concept from a pre-trained diffusion model, given only the name of the style and using negative guidance as a teacher. We benchmark our method against previous approaches that remove sexually explicit content and demonstrate its effectiveness, performing on par with Safe Latent Diffusion and censored training.

To evaluate artistic style removal, we conduct experiments erasing five modern artists from the network and run a user study to assess the human perception of the removed styles. Unlike previous methods, our approach removes concepts from a diffusion model permanently rather than modifying the output at inference time, so it cannot be circumvented even if a user has access to the model weights.

Given only a short text description of an undesired visual concept and no additional data, our method fine-tunes model weights to erase the targeted concept. Our method can avoid NSFW content, stop imitation of a specific artist's style, or even erase a whole object class from model output, while preserving the model's behavior and capabilities on other topics.

Fine-tuned Weights

The fine-tuned weights for both NSFW and art style erasures are available on our project page.

Running Gradio Demo Locally

To run the interactive Gradio demo locally, clone the files from the demo repository, then:

  • Create an environment using the packages included in the requirements.txt file
  • Run python app.py
  • Open the application in browser at http://127.0.0.1:7860/
  • Train, evaluate, and save models using our method

Installation Guide

  • To get started, clone the original Stable Diffusion repository (Link)
  • Then download the files from our repository into the main stable-diffusion directory. This replaces the ldm folder of the original repo with our custom ldm directory
  • Download the weights from here and move them to stable-diffusion/models/ldm/ (this will be the ckpt_path variable in train-scripts/train-esd.py)
  • [Only for training] To convert your trained models to diffusers, download the diffusers Unet config from here (this will be the diffusers_config_path variable in train-scripts/train-esd.py)

Training Guide

After installation, follow these instructions to train a custom ESD model:

  • cd stable-diffusion to enter the main stable-diffusion repository
  • [IMPORTANT] Edit train-scripts/train-esd.py and change the default argparser values to match your setup (especially the config paths)
  • To choose train_method, pick one of the following: 'xattn', 'noxattn', 'selfattn', 'full'
  • python train-scripts/train-esd.py --prompt 'your prompt' --train_method 'your choice of training' --devices '0,1'

Note that the default argparser values must be changed!
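The four train_method options plausibly correspond to different subsets of UNet parameters being fine-tuned. The sketch below illustrates one way such a selection could work by filtering parameter names. It is an assumption, not the repository's code: the helper select_trainable, the example parameter names, and the matching rules (attn1 for self-attention, attn2 for cross-attention, as in the usual Stable Diffusion UNet naming) are hypothetical, and the actual script may include or exclude additional layers.

```python
from typing import Callable, Dict, List

# Hypothetical name-matching rules; the real script may differ.
RULES: Dict[str, Callable[[str], bool]] = {
    "xattn": lambda name: "attn2" in name,        # cross-attention layers only
    "noxattn": lambda name: "attn2" not in name,  # everything except cross-attention
    "selfattn": lambda name: "attn1" in name,     # self-attention layers only
    "full": lambda name: True,                    # every parameter
}


def select_trainable(param_names: List[str], train_method: str) -> List[str]:
    """Return the parameter names that would be fine-tuned for a given method."""
    if train_method not in RULES:
        raise ValueError(f"unknown train_method: {train_method!r}")
    keep = RULES[train_method]
    return [name for name in param_names if keep(name)]


# Illustrative (made-up) parameter names in the style of a Stable Diffusion UNet.
example_names = [
    "input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight",
    "input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight",
    "input_blocks.1.0.in_layers.2.weight",
]

for method in ("xattn", "noxattn", "selfattn", "full"):
    print(method, len(select_trainable(example_names, method)))
```

With a real model, the names would come from iterating over the UNet's named parameters, and only the selected ones would be passed to the optimizer.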

The optimization process erases an undesired visual concept from pre-trained diffusion model weights, using a short text description of the concept as guidance. The ESD model is fine-tuned with the conditioned and unconditioned scores obtained from the frozen SD model to guide the output away from the concept being erased. The model learns from its own knowledge to steer the diffusion process away from the undesired concept.
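As a rough illustration of this objective, the sketch below builds the training target from the frozen model's two noise predictions and measures how far the fine-tuned model's concept-conditioned prediction is from it. It follows the negative-guidance formulation as we understand it from the preprint: target = uncond - negative_guidance * (cond - uncond). The function names, the negative_guidance parameter name, and the use of flat Python lists in place of tensors are assumptions made to keep the example self-contained.

```python
from typing import List, Sequence


def esd_target(
    eps_uncond: Sequence[float],
    eps_cond: Sequence[float],
    negative_guidance: float = 1.0,
) -> List[float]:
    """Target noise: the frozen model's unconditional prediction pushed
    away from its concept-conditioned prediction."""
    return [u - negative_guidance * (c - u) for u, c in zip(eps_uncond, eps_cond)]


def esd_loss(
    eps_student_cond: Sequence[float],
    eps_uncond: Sequence[float],
    eps_cond: Sequence[float],
    negative_guidance: float = 1.0,
) -> float:
    """Mean squared error between the fine-tuned model's conditioned
    prediction and the negatively guided target."""
    target = esd_target(eps_uncond, eps_cond, negative_guidance)
    return sum((s - t) ** 2 for s, t in zip(eps_student_cond, target)) / len(target)


# Toy values standing in for the frozen model's noise predictions.
frozen_uncond = [0.0, 1.0]
frozen_cond = [1.0, 3.0]

target = esd_target(frozen_uncond, frozen_cond, negative_guidance=1.0)
# Before any fine-tuning, the student predicts the same as the frozen model.
initial_loss = esd_loss(frozen_cond, frozen_uncond, frozen_cond, negative_guidance=1.0)
print(target, initial_loss)
```

With negative_guidance set to 0 the target reduces to the unconditional prediction; larger values push the concept-conditioned output further away from the concept.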

Generating Images

To generate images from one of the custom models use the following instructions:

  • To use eval-scripts/generate-images.py you need a CSV file with the columns prompt, evaluation_seed and case_number (sample data in data/)
  • To generate multiple images per prompt, use the argument num_samples. It defaults to 10.
  • The path to the model can be customised in the script.
  • Note that the current version requires the model to be saved in stable-diffusion/compvis-<based on hyperparameters>/diffusers-<based on hyperparameters>.pt
  • python eval-scripts/generate-images.py --model_name='compvis-word_VanGogh-method_xattn-sg_3-ng_1-iter_1000-lr_1e-05' --prompts_path 'stable-diffusion/art_prompts.csv' --save_path 'evaluation_folder' --num_samples 10
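A prompts file with the three columns named above could be produced along these lines. This is a minimal sketch using Python's standard csv module: the column order, the seed scheme (base_seed plus the row index), the example prompts, and the helper names build_rows and write_prompts_csv are all assumptions, so compare the result against the sample files in data/ before relying on it.

```python
import csv
import io
from typing import Dict, List, TextIO, Union

# Column names taken from the instructions above; the order is an assumption.
FIELDNAMES = ["case_number", "prompt", "evaluation_seed"]

Row = Dict[str, Union[int, str]]


def build_rows(prompts: List[str], base_seed: int = 0) -> List[Row]:
    """Give every prompt a case number and a deterministic seed."""
    return [
        {"case_number": i, "prompt": p, "evaluation_seed": base_seed + i}
        for i, p in enumerate(prompts)
    ]


def write_prompts_csv(fileobj: TextIO, rows: List[Row]) -> None:
    """Write the rows with a header line to any text file object."""
    writer = csv.DictWriter(fileobj, fieldnames=FIELDNAMES)
    writer.writeheader()
    writer.writerows(rows)


# Made-up prompts for illustration only.
rows = build_rows(
    ["a painting of a starry night", "a portrait of a woman"], base_seed=100
)

# Written to an in-memory buffer here; to write a real file, use
# open(path, "w", newline="") in place of the StringIO buffer.
buffer = io.StringIO()
write_prompts_csv(buffer, rows)
csv_text = buffer.getvalue()
print(csv_text)
```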

Citing our work

The preprint can be cited as follows:

@inproceedings{gandikota2023erasing,
  title={Erasing Concepts from Diffusion Models},
  author={Rohit Gandikota and Joanna Materzy\'nska and Jaden Fiotto-Kaufman and David Bau},
  booktitle={Proceedings of the 2023 IEEE International Conference on Computer Vision},
  year={2023}
}
