Locating and editing factual associations in GPT (NeurIPS 2022)

Rank-One Model Editing (ROME)

This repository provides an implementation of Rank-One Model Editing (ROME) on auto-regressive transformers (GPU-only). We currently support OpenAI's GPT-2 XL (1.5B) and EleutherAI's GPT-J (6B). The release of a 20B GPT-like model from EleutherAI is expected soon; we hope to support it ASAP.

Feel free to open an issue if you find any problems; we are actively developing this repository and will monitor tickets closely.

Table of Contents

  1. Installation
  2. Causal Tracing
  3. Rank-One Model Editing (ROME)
  4. CounterFact
  5. Evaluation
  6. How to Cite

Installation

We recommend conda for managing Python, CUDA, and PyTorch-related dependencies, and pip for everything else. To get started, simply install conda and run:

./scripts/setup_conda.sh

Causal Tracing

notebooks/causal_trace.ipynb demonstrates Causal Tracing. The notebook can be modified to trace how the model processes any other statement.
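The quantities that Causal Tracing compares can be summarized in a few lines. The sketch below follows the total-effect and indirect-effect definitions from the ROME paper, but the function names are hypothetical and the probabilities are made-up numbers. The notebook obtains the real values by running the model with corrupted subject embeddings and then restoring individual hidden states.

```python
from typing import Dict, Tuple


def total_effect(p_clean: float, p_corrupted: float) -> float:
    """TE = P[o] - P*[o]: how much corrupting the subject lowers the probability of the true object o."""
    return p_clean - p_corrupted


def indirect_effect(p_restored: float, p_corrupted: float) -> float:
    """IE = P*_{clean h}[o] - P*[o]: probability recovered by restoring one clean
    hidden state while the subject embeddings stay corrupted."""
    return p_restored - p_corrupted


# Made-up probabilities of the object token, for illustration only.
p_clean, p_corrupted = 0.75, 0.25
restored: Dict[Tuple[int, str], float] = {
    (15, "James"): 0.625,  # (layer, token) -> P[o] with that single state restored
    (40, "of"): 0.5,
}

print(total_effect(p_clean, p_corrupted))  # 0.5
best = max(restored, key=lambda state: indirect_effect(restored[state], p_corrupted))
print(best, indirect_effect(restored[best], p_corrupted))  # (15, 'James') 0.375
```

A state with a large indirect effect is one whose clean value alone is enough to bring back much of the correct prediction; the heatmaps in the notebook plot this quantity across layers and tokens.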

Rank-One Model Editing (ROME)

notebooks/rome.ipynb demonstrates ROME. The API is simple: you only need to specify a requested rewrite of the following form:

request = {
    "prompt": "{} plays the sport of",
    "subject": "LeBron James",
    "target_new": {
        "str": "football"
    }
}
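The "{}" in the prompt is a placeholder for the subject, which is passed separately so the editing method knows which tokens belong to the subject. The small helper below is a hypothetical illustration (not part of the repository's API) that renders the full prompt from a request and checks, as an assumption, that the template contains exactly one placeholder.

```python
def render_prompt(request: dict) -> str:
    """Substitute the subject into the '{}' placeholder of the prompt template."""
    prompt = request["prompt"]
    # Assumption: each template holds exactly one slot for the subject.
    if prompt.count("{}") != 1:
        raise ValueError("prompt must contain exactly one '{}' placeholder for the subject")
    return prompt.format(request["subject"])


request = {
    "prompt": "{} plays the sport of",
    "subject": "LeBron James",
    "target_new": {"str": "football"},
}

print(render_prompt(request))        # LeBron James plays the sport of
print(request["target_new"]["str"])  # football
```

After a successful edit, the model should continue the rendered prompt with the target_new string rather than the original fact.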

Several similar examples are included in the notebook.

CounterFact

Details coming soon!

Evaluation

See baselines/ for a description of the available baselines.

Running the Full Evaluation Suite

experiments/evaluate.py can be used to evaluate any method in baselines/. To get started (e.g. using ROME on GPT-2 XL), run:

python3 -m experiments.evaluate \
    --alg_name=ROME \
    --model_name=gpt2-xl \
    --hparams_fname=gpt2-xl.json

Results from each run are stored in results/<method_name>/run_<run_id> with the following layout:

results/
|__ ROME/
    |__ run_<run_id>/
        |__ params.json
        |__ case_0.json
        |__ case_1.json
        |__ ...
        |__ case_10000.json
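To inspect individual cases outside of experiments/summarize.py, a run directory can be loaded with the standard library alone. This is a minimal sketch: load_cases is a hypothetical helper, and it treats each case file as an opaque JSON object because the exact fields inside case_<n>.json are not described here.

```python
import json
import re
from pathlib import Path
from typing import Dict


def load_cases(run_dir: str) -> Dict[int, dict]:
    """Load every case_<n>.json in a run directory, keyed and ordered by case number."""
    cases = {}
    for path in Path(run_dir).glob("case_*.json"):
        match = re.fullmatch(r"case_(\d+)\.json", path.name)
        if match is None:
            continue  # skip files that only resemble the pattern
        with path.open() as f:
            cases[int(match.group(1))] = json.load(f)
    # Sort numerically; a plain string sort would place case_10 before case_2.
    return dict(sorted(cases.items()))
```

Call it with the path of an actual run, e.g. load_cases("results/ROME/run_<run_id>") with <run_id> filled in; params.json is ignored because it does not match the case_*.json pattern.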

To summarize the results, you can use experiments/summarize.py:

python3 -m experiments.summarize --dir_name=ROME --runs=run_<run_id>

Running python3 -m experiments.evaluate -h or python3 -m experiments.summarize -h provides details about command-line flags.

Integrating New Editing Methods

Say you have a new method X and want to benchmark it on CounterFact. To integrate X with our runner:

  • Subclass HyperParams into XHyperParams and specify all hyperparameter fields. See ROMEHyperParams for an example implementation.
  • Create a hyperparameters file at hparams/X/gpt2-xl.json and specify some default values. See hparams/ROME/gpt2-xl.json for an example.
  • Define a function apply_X_to_model which accepts several parameters and returns (i) the rewritten model and (ii) the original weight values for parameters that were edited (in the dictionary format {weight_name: original_weight_value}). See rome/rome_main.py for an example.
  • Add X to ALG_DICT in experiments/evaluate.py by inserting the line "X": (XHyperParams, apply_X_to_model).
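The steps above can be sketched end to end with toy stand-ins. Everything here is hypothetical: the HyperParams class is a stand-in that assumes the real base class can be built from a JSON file, the "model" is a plain dict of weight lists rather than a HuggingFace model, and the argument list of apply_X_to_model is simplified. Mirror the real signature of apply_rome_to_model in rome/rome_main.py when integrating an actual method.

```python
import copy
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple

Weights = Dict[str, List[float]]  # toy stand-in for a model's named parameters


class HyperParams:
    """Stand-in for the repository's HyperParams base class (assumed JSON-loadable)."""

    @classmethod
    def from_json(cls, fpath: str):
        with open(fpath) as f:
            return cls(**json.load(f))


@dataclass
class XHyperParams(HyperParams):
    layer: int   # hypothetical: which layer X edits
    step: float  # hypothetical: size of X's update


def apply_X_to_model(
    model: Weights, requests: List[dict], hparams: XHyperParams
) -> Tuple[Weights, Weights]:
    """Edit `model` in place; return (edited model, {weight_name: original_weight_value})."""
    name = f"layer_{hparams.layer}.weight"
    orig_weights = {name: copy.deepcopy(model[name])}  # save before editing
    for _ in requests:
        # Toy update standing in for X's real edit rule.
        model[name] = [w + hparams.step for w in model[name]]
    return model, orig_weights


# Registration, mirroring the ALG_DICT entry described in the last step.
ALG_DICT = {"X": (XHyperParams, apply_X_to_model)}

model = {"layer_0.weight": [1.0, 2.0]}
edited, orig = apply_X_to_model(model, [{"subject": "LeBron James"}], XHyperParams(layer=0, step=0.5))
print(edited["layer_0.weight"])  # [1.5, 2.5]
for name, value in orig.items():  # undo the edit using the returned originals
    edited[name] = value
print(edited["layer_0.weight"])  # [1.0, 2.0]
```

Returning the original weights lets the runner restore the model after each case, so successive evaluations start from the unedited model.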

Finally, run the main scripts:

python3 -m experiments.evaluate \
    --alg_name=X \
    --model_name=gpt2-xl \
    --hparams_fname=gpt2-xl.json

python3 -m experiments.summarize --dir_name=X --runs=run_<run_id>

Note on Cross-Platform Compatibility

We currently only support methods that edit autoregressive HuggingFace models using the PyTorch backend. We are working on a set of general-purpose methods (usable on e.g. TensorFlow and without HuggingFace) that will be released soon.

How to Cite

@article{meng2022locating,
  title={Locating and Editing Factual Associations in {GPT}},
  author={Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  year={2022}
}