Controlled Text Generation via Language Model Arithmetic

Note: this code contains updates compared to the version released alongside the paper. Reproducing our results should still be possible on this branch, but please refer to the v1.0 branch for full reproduction. Updates include bug fixes related to getting the logits from the models, fixes for model loading in some special cases, and changes to keep up with the newest versions of LM-eval and Transformers, among others. Note that the interface has remained almost entirely the same.

This repo contains the code for model arithmetic, a comprehensive framework where arithmetic formulas express combinations of LMs and classifiers, thereby biasing the generated text towards or away from desired attributes.

Overview

To install model arithmetic with Python 3, run

python -m pip install -e .

Getting Started

Model arithmetic allows you to combine prompts, models, and classifiers to create new, precisely controlled LLMs that combine aspects of each component.

For instance, you can easily interpolate between two differently-prompted models as follows:

from model_arithmetic import ModelArithmetic, PromptedLLM

# define model prompt template
prompt_template = lambda formula_string, input_string: f"<s>[INST]<<SYS>>\n{formula_string}\n<</SYS>>\n\n{input_string} [/INST]"

# define two differently-prompted models
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template)

# model arithmetic expression
formula1 = M_child - 0.6 * M_adult

# generate text as usual
ma0 = ModelArithmetic(formula1, default_model="meta-llama/Llama-2-13b-chat-hf")
print(ma0.generate_text("Write a one-sentence fairy tale."))
# -> ["  Oh my gumdrops! Let me tell you a super special fairy tale 'bout a teeny tiny princess who lived in a fluffy white castle with her sparkly unicorn and they had the most amazing adventures together!</s>"]

Note that the generate_text function can also take a list of input sentences, and it accepts standard arguments such as temperature, top_p, top_k, batch_size, num_return_sequences, and stop_texts (a list of strings at which generation should be stopped); a short sketch follows the save/load example below. You can also save and load a ModelArithmetic object:

ma0.to_pretrained('model')
ma0 = ModelArithmetic.from_pretrained('model')
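
As an illustration of the generation arguments mentioned above, here is a minimal sketch of batched generation (the argument values are illustrative):

# a sketch of batched generation; argument values are illustrative
outputs = ma0.generate_text(
    ["Write a one-sentence fairy tale.", "Describe your day in one sentence."],
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    batch_size=2,
    num_return_sequences=1,
    stop_texts=["</s>"],
)
print(outputs)  # one list entry per returned sequence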

Integrating Classifiers

You can integrate classifiers into your model arithmetic expressions. For instance, you can use a classifier to control the formality of your output:

from model_arithmetic import ModelArithmetic, PromptedLLM, Classifier

# define model prompt template
prompt_template = lambda formula_string, input_string: f"<s>[INST]<<SYS>>\n{formula_string}\n<</SYS>>\n\n{input_string} [/INST]"

# define two differently-prompted models
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template)

# construct model arithmetic expression
formula1 = M_child - 0.6 * M_adult

# Initialize the classifier. The first and third arguments determine on which
# completion tokens the classifier is run (here, the 50 most likely tokens of
# formula1). This prompt template ensures that the input sentence is ignored
# for the classifier guidance.
C_formal = Classifier(formula1, "s-nlp/roberta-base-formality-ranker", n_runs_per_sample=50,
                      prompt_template=lambda e, f: "")

# integrate classifier into model arithmetic expression
formula2 = formula1 + C_formal

# generate text as usual
ma = ModelArithmetic(formula2, default_model="meta-llama/Llama-2-13b-chat-hf")
print(ma.generate_text("Write a one-sentence fairy tale.", max_length=128))
# -> ['  "Once upon a time, in a magical land filled with fluffy clouds and sparkly rainbows, there was a little girl named me who went on a fun adventure with my stuffed unicorn named Mr. Snuggles!"</s>']

Union and intersection

You can also use our custom operators to generate text. For instance, you can use the Union operator to add a touch of magic to the fairy tale:

from model_arithmetic import ModelArithmetic, PromptedLLM, Union, Classifier

# define model prompt template
prompt_template = lambda formula_string, input_string: f"<s>[INST]<<SYS>>\n{formula_string}\n<</SYS>>\n\n{input_string} [/INST]"

# define three differently-prompted models
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template)
M_magic = PromptedLLM("You are a person who is always talking about magic.", prompt_template=prompt_template)

# construct model arithmetic expression
formula_part1 = M_child - 0.6 * M_adult + 2 * Union(M_child, M_magic)

# integrate classifier in the expression
C_formal = Classifier(formula_part1, "s-nlp/roberta-base-formality-ranker", n_runs_per_sample=50, 
                      prompt_template=lambda e, f: "")

formula = formula_part1 + C_formal

# generate text as usual
ma = ModelArithmetic(formula, default_model="meta-llama/Llama-2-13b-chat-hf")
print(ma.generate_text("Write a one-sentence fairy tale."))
# -> ['  "Once upon a time, in a magical forest filled with sparkling flowers and talking animals, there lived a little girl named Lily who had a special gift for conjuring delicious rainbow-colored cupcakes that made everyone who ate them feel happy and dance with joy!"</s>']

About models

A formula can have terms using different models, as long as all models have the same tokenizer. One can select a specific model for a term by setting the model parameter:

M_child = PromptedLLM("You are a child.", prompt_template=prompt_template, model="meta-llama/Llama-2-7b-chat-hf")

The selected model can also be a PreTrainedModel instead of a string.
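
For example, a model loaded with Transformers can be passed in directly; a minimal sketch (the model choice is illustrative):

import torch
from transformers import AutoModelForCausalLM

# load the model once and reuse it across several terms
llama_7b_chat = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
)
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template, model=llama_7b_chat)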

Models are by default loaded in bfloat16 format. You can change this by specifying the dtype parameter in the ModelArithmetic constructor:

import torch

ma = ModelArithmetic(formula, default_model="meta-llama/Llama-2-13b-chat-hf", dtype=torch.float32)

Speculative sampling

Speculative sampling can be performed by initializing the prompted models with the extra speculative_factor parameter and setting the do_speculation parameter in the generation function to True:

...
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template, speculative_factor=4)
...
print(ma0.generate_text("Write a one-sentence fairy tale.", do_speculation=True))

Note that one prompted model should always have speculative_factor=1 (the default value).

Eager mode

By default, we process the key-value cache stored by models, since this is required for speculative sampling. Because different models use key-value caching differently, this can result in errors. We therefore included the run_eager parameter in the initialization of the prompted model; it disables all speculative sampling, which should fix this issue if it occurs:

M_child = PromptedLLM("You are a child.", prompt_template=prompt_template, run_eager=True)

Other Operators

Finally, the library provides some other operators that can be used in formulas; we present a few here. The TopPTopK operator allows the use of nucleus and top-k sampling within a formula. The following ensures that the output token is always among the top 10 tokens of model1:

formula = TopPTopK(model1, top_k=10) + model2

The Superseded operator implements speculative sampling directly:

formula = Superseded(small_model, large_model)
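
For instance, here is a self-contained sketch of speculative sampling via Superseded, assuming the operator is importable from the package's top level like the other classes (the model choices are illustrative; both share the Llama-2 tokenizer):

from model_arithmetic import ModelArithmetic, PromptedLLM, Superseded

prompt_template = lambda formula_string, input_string: f"<s>[INST]<<SYS>>\n{formula_string}\n<</SYS>>\n\n{input_string} [/INST]"

# a small draft model and a large model that validates its proposals
small_model = PromptedLLM("You are a storyteller.", prompt_template=prompt_template, model="meta-llama/Llama-2-7b-chat-hf")
large_model = PromptedLLM("You are a storyteller.", prompt_template=prompt_template, model="meta-llama/Llama-2-13b-chat-hf")

formula = Superseded(small_model, large_model)
ma = ModelArithmetic(formula, default_model="meta-llama/Llama-2-13b-chat-hf")
print(ma.generate_text("Write a one-sentence fairy tale."))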

LM Evaluation Harness

Model arithmetic is compatible with the LM Evaluation Harness. To run benchmarks from the harness, you need to install the package as described on its GitHub page. An example of how to use our tool with the LM Evaluation Harness is shown in scripts/evaluate_lm_eval.py.

Reproducing results

To reproduce the results presented in our paper, Controlled Text Generation via Language Model Arithmetic, we advise running the code with the exact environment we used (an Nvidia H100 80GB GPU on a Linux machine). To do so, install Conda and run

conda create -n model_arithmetic python=3.10
conda activate model_arithmetic
python -m pip install -r requirements.txt
python -m pip install -e .

API keys for both the Perspective API and OpenAI need to be available as environment variables. Alternatively, they can be placed in the file src/.env as

PERSPECTIVE_API_KEY="[YOUR API KEY]"
OPENAI_API_KEY="[YOUR API KEY]"
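
A sketch of loading these keys in Python, assuming the python-dotenv package is used to read src/.env (the repository's own loading mechanism may differ):

import os
from dotenv import load_dotenv

# read the API keys from src/.env into the environment
load_dotenv("src/.env")
assert os.environ.get("PERSPECTIVE_API_KEY") and os.environ.get("OPENAI_API_KEY")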

The processed datasets are in the data/datasets folder. You can reproduce our results using these datasets by running

bash scripts/main.sh

This will finetune a classifier for the toxicity and sentiment control tasks, and reproduce the results from all sections of our paper. Results in CSV format can afterwards be found in eval/processed and our figures in eval/plots.

Alternatively, you can download the raw datasets and put them in the data/datasets folder.

You can then reproduce the results using

bash scripts/main_preprocess.sh

We note that part of our preprocessing code was lost, specifically the code for preparing the dataset used for finetuning the toxicity classifier. Running the code without the preprocessed datasets might therefore result in slightly different numbers where the finetuned classifier is involved.

Cite this work

@article{dekoninck-2023-controlled,
  author       = {Jasper Dekoninck and
                  Marc Fischer and
                  Luca Beurer{-}Kellner and
                  Martin T. Vechev},
  title        = {Controlled Text Generation via Language Model Arithmetic},
  journal      = {CoRR},
  volume       = {abs/2311.14479},
  year         = {2023},
}
