Note: this code contains updates compared to the version released alongside the paper. Reproducing our results should still be possible on this branch, but please refer to the v1.0 branch for full reproduction. Updates include bug fixes related to getting the logits from the models, fixes for model loading in some special cases, and updates to keep up with the newest versions of LM-eval and Transformers. The interface has remained almost entirely the same.
This repo contains the code for model arithmetic, a framework in which arithmetic formulas express combinations of LMs and classifiers, biasing the generated text towards or away from desired attributes.
To install model arithmetic with Python 3, run
python -m pip install -e .
Model arithmetic allows you to combine prompts, models, and classifiers to create new, precisely controlled LLMs that combine aspects of each component.
For instance, you can easily interpolate between two differently-prompted models as follows:
from model_arithmetic import ModelArithmetic, PromptedLLM
# define model prompt template
prompt_template = lambda formula_string, input_string: f"<s>[INST]<<SYS>>\n{formula_string}\n<</SYS>>\n\n{input_string} [/INST]"
# define two differently-prompted models
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template)
# model arithmetic expression
formula1 = M_child - 0.6 * M_adult
# generate text as usual
ma0 = ModelArithmetic(formula1, default_model="meta-llama/Llama-2-13b-chat-hf")
print(ma0.generate_text("Write a one-sentence fairy tale."))
# -> [" Oh my gumdrops! Let me tell you a super special fairy tale 'bout a teeny tiny princess who lived in a fluffy white castle with her sparkly unicorn and they had the most amazing adventures together!</s>"]
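Conceptually, a formula like `M_child - 0.6 * M_adult` is evaluated on next-token log-probabilities: each term contributes its weighted scores and the result is renormalized before sampling. A minimal NumPy sketch of this combination on toy values (an illustration, not the library's internals):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# toy next-token log-probabilities over a 4-token vocabulary
logp_child = np.log(softmax(np.array([2.0, 1.0, 0.5, 0.1])))
logp_adult = np.log(softmax(np.array([0.5, 2.0, 1.0, 0.1])))

# evaluate the formula M_child - 0.6 * M_adult on the log-probabilities
combined = logp_child - 0.6 * logp_adult

# renormalize into a proper distribution before sampling
p = softmax(combined)
print(p.round(3))
```

Subtracting the adult model pushes probability away from tokens the adult model prefers, so the child model's favorite token is boosted even further.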
Note that the `generate_text` function can also take a list of input sentences, and it supports standard arguments such as `temperature`, `top_p`, `top_k`, `batch_size`, `num_return_sequences`, and `stop_texts` (a list of strings at which generation should stop). You can also save and load a `ModelArithmetic` object:
ma0.to_pretrained('model')
ma0 = ModelArithmetic.from_pretrained('model')
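Conceptually, `stop_texts` cuts each generated sequence at the first occurrence of any stop string. A pure-Python sketch of that truncation idea (whether the library keeps or drops the stop string itself is an assumption here; we drop it):

```python
def truncate_at_stop(text, stop_texts):
    """Cut text at the earliest occurrence of any stop string (illustrative)."""
    cut = len(text)
    for stop in stop_texts:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]

print(truncate_at_stop("Once upon a time.\nThe end.", ["\n", "The end"]))
```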
You can integrate classifiers into your model arithmetic expressions. For instance, you can use a classifier to control the formality of your output:
from model_arithmetic import ModelArithmetic, PromptedLLM, Classifier
# define model prompt template
prompt_template = lambda formula_string, input_string: f"<s>[INST]<<SYS>>\n{formula_string}\n<</SYS>>\n\n{input_string} [/INST]"
# define two differently-prompted models
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template)
# construct model arithmetic expression
formula1 = M_child - 0.6 * M_adult
# initialize the classifier; the first and third arguments determine on which completion
# tokens the classifier is run (here, the 50 most likely tokens of formula1); this prompt
# template makes the classifier guidance ignore the input sentence
C_formal = Classifier(formula1, "s-nlp/roberta-base-formality-ranker", n_runs_per_sample=50, prompt_template=lambda e, f: "")
# integrate classifier into model arithmetic expression
formula2 = formula1 + C_formal
# generate text as usual
ma = ModelArithmetic(formula2, default_model="meta-llama/Llama-2-13b-chat-hf")
print(ma.generate_text("Write a one-sentence fairy tale.", max_length=128))
# -> [' "Once upon a time, in a magical land filled with fluffy clouds and sparkly rainbows, there was a little girl named me who went on a fun adventure with my stuffed unicorn named Mr. Snuggles!"</s>']
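Classifier guidance of this kind scores only the most likely completion tokens (cf. `n_runs_per_sample`) and shifts their log-probabilities by the classifier's output. A toy sketch of that re-ranking idea, with made-up scores; restricting sampling to the scored tokens is an assumption of this sketch, not a statement about the library's internals:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

vocab = ["hello", "hi", "greetings", "yo"]
logits = np.array([2.0, 1.8, 0.2, 1.5])      # formula1's next-token logits (toy)
formality = np.array([0.05, 0.9, 0.8, 0.1])  # hypothetical classifier P(formal | token)

# score only the k most likely tokens (cf. n_runs_per_sample=50 above)
k = 2
top = np.argsort(logits)[-k:]
guided = np.full_like(logits, -np.inf)
guided[top] = logits[top] + np.log(formality[top])  # shift by classifier log-prob

p = softmax(guided)
print(vocab[int(np.argmax(p))])
```

Here the guidance flips the most likely token from "hello" to the token the classifier rates as more formal among the scored candidates.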
You can also use our custom operators to generate text. For instance, you can use the `Union` operator to add a magical touch to the fairy tale:
from model_arithmetic import ModelArithmetic, PromptedLLM, Union, Classifier
# define model prompt template
prompt_template = lambda formula_string, input_string: f"<s>[INST]<<SYS>>\n{formula_string}\n<</SYS>>\n\n{input_string} [/INST]"
# define three differently-prompted models
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template)
M_magic = PromptedLLM("You are a person who is always talking about magic.", prompt_template=prompt_template)
# construct model arithmetic expression
formula_part1 = M_child - 0.6 * M_adult + 2 * Union(M_child, M_magic)
# integrate classifier in the expression
C_formal = Classifier(formula_part1, "s-nlp/roberta-base-formality-ranker", n_runs_per_sample=50,
prompt_template=lambda e, f: "")
formula = formula_part1 + C_formal
# generate text as usual
ma = ModelArithmetic(formula, default_model="meta-llama/Llama-2-13b-chat-hf")
print(ma.generate_text("Write a one-sentence fairy tale."))
# -> [' "Once upon a time, in a magical forest filled with sparkling flowers and talking animals, there lived a little girl named Lily who had a special gift for conjuring delicious rainbow-colored cupcakes that made everyone who ate them feel happy and dance with joy!"</s>']
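Intuitively, `Union` keeps tokens that are likely under *either* of its operands, which can roughly be pictured as a token-wise maximum over the two distributions, renormalized (a simplification of the paper's definition, shown here on toy values):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

p_child = softmax(np.array([2.0, 0.1, 0.1, 0.1]))  # toy distributions
p_magic = softmax(np.array([0.1, 0.1, 2.0, 0.1]))

# roughly: keep tokens likely under either model, then renormalize
p_union = np.maximum(p_child, p_magic)
p_union /= p_union.sum()
print(p_union.round(3))
```

Both models' favorite tokens end up equally likely, rather than averaging each other away as a plain sum of probabilities would.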
A formula can have terms using different models, as long as all models share the same tokenizer. You can specify a specific model for a certain term by setting the `model` parameter:
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template, model="meta-llama/Llama-2-7b-chat-hf")
The selected model can also be a `PreTrainedModel` instead of a string.
Models are loaded in bfloat16 format by default. You can change this by specifying the `dtype` parameter in the `ModelArithmetic` constructor:
ma = ModelArithmetic(formula, default_model="meta-llama/Llama-2-13b-chat-hf", dtype=torch.float32)
Speculative sampling can be performed by initializing the prompted models with the extra `speculative_factor` parameter and setting the `do_speculation` parameter of the generation function to `True`:
...
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template)
M_adult = PromptedLLM("You are an adult.", prompt_template=prompt_template, speculative_factor=4)
...
print(ma0.generate_text("Write a one-sentence fairy tale.", do_speculation=True))
Note that one prompted model should always have `speculative_factor=1` (the default value).
By default, we process the key-value cache stored by the models, since this is required for speculative sampling. Because different models handle key-value caching differently, this can cause errors. We therefore provide the `run_eager` parameter in the initialization of a prompted model; it disables all speculative sampling for that model and should fix the issue if it occurs:
M_child = PromptedLLM("You are a child.", prompt_template=prompt_template, run_eager=True)
Finally, the library provides some other operators that can be used in formulas, a few of which we present here. The `TopPTopK` operator allows the use of nucleus (top-p) and top-k sampling within a formula. The following ensures that the output token is always among the 10 most likely tokens of `model1`:
formula = TopPTopK(model1, top_k=10) + model2
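Top-k restriction of this kind can be pictured as masking out every logit outside the k most likely tokens before the term enters the formula. A toy sketch of that masking, not the library's implementation:

```python
import numpy as np

def top_k_mask(logits, k):
    """Keep the k largest logits, set the rest to -inf (illustrative)."""
    out = np.full_like(logits, -np.inf)
    top = np.argsort(logits)[-k:]
    out[top] = logits[top]
    return out

logits = np.array([3.0, 1.0, 2.5, 0.2, 2.9])
masked = top_k_mask(logits, k=2)
print(int(np.isfinite(masked).sum()))  # only 2 tokens remain eligible
```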
The `Superseded` operator implements speculative sampling directly:
formula = Superseded(small_model, large_model)
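In speculative sampling, the small model drafts tokens that the large model then verifies, with a draft token kept with probability min(1, p_large / p_small). A greatly simplified sketch of that standard accept rule (illustrative, not the library's implementation):

```python
import random

def accept_draft(p_small, p_large, rng=random.random):
    """Speculative-sampling accept rule: keep the draft token with
    probability min(1, p_large / p_small). Illustrative only."""
    return rng() < min(1.0, p_large / p_small)

# a token the large model likes more than the draft model is always kept
print(accept_draft(0.2, 0.5, rng=lambda: 0.99))
```

When the large model assigns lower probability than the draft model, the token is only kept with probability p_large / p_small, and otherwise resampled.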
Model arithmetic is compatible with the LM Evaluation Harness. To run benchmarks from the harness, you need to install the package as described on their GitHub page. An example of how to use our tool with the LM Evaluation Harness is shown in `scripts/evaluate_lm_eval.py`.
To reproduce the results of our paper, Controlled Text Generation via Language Model Arithmetic, we advise running the code in the exact environment we used (an NVIDIA H100 80GB GPU on a Linux machine). To do so, install Conda and run
conda create -n model_arithmetic python=3.10
conda activate model_arithmetic
python -m pip install -r requirements.txt
python -m pip install -e .
API keys for both the Perspective API and OpenAI need to be available as environment variables. Alternatively, they can be placed in the file `src/.env` as
PERSPECTIVE_API_KEY="[YOUR API KEY]"
OPENAI_API_KEY="[YOUR API KEY]"
The processed datasets are in the `data/datasets` folder. You can reproduce our results using these datasets by running
bash scripts/main.sh
This will finetune a classifier for the toxicity and sentiment control tasks, and reproduce the results from all sections of our paper. The results in CSV format can afterwards be found in `eval/processed` and our figures in `eval/plots`.
Alternatively, you can download the raw datasets and put them in the `data/datasets` folder:
- Alpaca Data
- Jigsaw Toxicity Dataset (the `all_data.csv` file should be downloaded and extracted in the `data/datasets` folder)
- Politically Incorrect 4chan Messages (the file should be unzipped and placed in the top level of the `data/datasets` folder)
- IMDB Dataset
You can then reproduce the results using
bash scripts/main_preprocess.sh
We note that part of our preprocessing code was lost, specifically the code for preparing the dataset used to finetune the toxicity classifier. Running the code without the preprocessed datasets may therefore yield slightly different numbers wherever the finetuned classifier is involved.
@article{dekoninck-2023-controlled,
author = {Jasper Dekoninck and
Marc Fischer and
Luca Beurer{-}Kellner and
Martin T. Vechev},
title = {Controlled Text Generation via Language Model Arithmetic},
journal = {CoRR},
volume = {abs/2311.14479},
year = {2023},
}