  • Stars: 625
  • Rank: 71,862 (Top 2 %)
  • Language: Python
  • License: Apache License 2.0
  • Created over 1 year ago
  • Updated 10 days ago


Repository Details

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions


This is a beta release (public testing).

A Library for Understanding and Improving PyTorch Models via Interventions

Interventions on model-internal states are fundamental operations in many areas of AI, including model editing, steering, robustness, and interpretability. To facilitate such research, we introduce pyvene, an open-source Python library that supports customizable interventions on a range of different PyTorch modules. pyvene supports complex intervention schemes with an intuitive configuration format, and its interventions can be static or include trainable parameters.

Getting Started: [Main pyvene 101]

Installation

Since pyvene is currently in beta testing, we recommend installing it from source:

git clone git@github.com:stanfordnlp/pyvene.git

and then adding pyvene to your system path in Python:

import sys
sys.path.append("<Your Path to Pyvene>")

import pyvene as pv

Alternatively, you can install it with pip, either directly from GitHub:

pip install git+https://github.com/stanfordnlp/pyvene.git

or from PyPI:

pip install pyvene

Wrap, Intervene and Share

You can intervene on supported models as follows:

import torch
import pyvene as pv

_, tokenizer, gpt2 = pv.create_gpt2()

pv_gpt2 = pv.IntervenableModel({
    "layer": 0, "component": "block_output",
    "source_representation": torch.zeros(gpt2.config.n_embd)
}, model=gpt2)

orig_outputs, intervened_outputs = pv_gpt2(
    base = tokenizer("The capital of Spain is", return_tensors="pt"), 
    unit_locations={"base": 3}
)
print(intervened_outputs.last_hidden_state - orig_outputs.last_hidden_state)

which returns,

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0483, -0.1212, -0.2816,  ...,  0.1958,  0.0830,  0.0784],
         [ 0.0519,  0.2547, -0.1631,  ...,  0.0050, -0.0453, -0.1624]]])
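The zero rows in this difference are expected: GPT-2 uses causal attention, so replacing the layer-0 block output at position 3 can only change hidden states at position 3 and later. The toy sketch below illustrates the same pattern in plain Python. It does not use pyvene, and `causal_prefix_sum` and `intervene` are hypothetical stand-ins for a causal model and a zero-out intervention.

```python
def causal_prefix_sum(xs):
    """Toy causal 'model': output t depends only on inputs 0..t."""
    out, running = [], 0.0
    for x in xs:
        running += x
        out.append(running)
    return out


def intervene(xs, position, value=0.0):
    """Return a copy of xs with one position overwritten (zero-out intervention)."""
    patched = list(xs)
    patched[position] = value
    return patched


base = [1.0, 2.0, 3.0, 4.0, 5.0]
orig = causal_prefix_sum(base)
intervened = causal_prefix_sum(intervene(base, position=3))

# Positions before the intervention site are unchanged.
diff = [i - o for i, o in zip(intervened, orig)]
print(diff)  # [0.0, 0.0, 0.0, -4.0, -4.0]
```

As in the GPT-2 output above, only the intervened position and the positions after it differ from the original run.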

IntervenableModel Loaded from HuggingFace Directly

The following code block reproduces the honest_llama-2 chat model from the paper Inference-Time Intervention: Eliciting Truthful Answers from a Language Model. The added activations take only ~0.14MB on disk!

# others can download from huggingface and use it directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pyvene as pv

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")

pv_model = pv.IntervenableModel.load(
    "zhengxuanzenwu/intervenable_honest_llama2_chat_7B", # the activation diff ~0.14MB
    model,
)

print("llama-2-chat loaded with interventions:")
q = "What's a cure for insomnia that always works?"
prompt = tokenizer(q, return_tensors="pt").to("cuda")
_, iti_response_shared = pv_model.generate(prompt, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(iti_response_shared[0], skip_special_tokens=True))

With this, once you discover a clever intervention scheme, you can quickly share it with others without sharing the actual base LM or the intervention code!

IntervenableModel as Regular nn.Module

You can also use pv_gpt2 just like a regular torch module, as a component inside another model or pipeline:

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union, Dict

class ModelWithIntervenables(nn.Module):
    def __init__(self):
        super(ModelWithIntervenables, self).__init__()
        self.pv_gpt2 = pv_gpt2
        self.relu = nn.ReLU()
        self.fc = nn.Linear(768, 1)
        # Your other downstream components go here

    def forward(
        self, 
        base,
        sources: Optional[List] = None,
        unit_locations: Optional[Dict] = None,
        activations_sources: Optional[Dict] = None,
        subspaces: Optional[List] = None,
    ):
        _, counterfactual_x = self.pv_gpt2(
            base,
            sources,
            unit_locations,
            activations_sources,
            subspaces
        )
        return self.fc(self.relu(counterfactual_x.last_hidden_state))

Complex Intervention Schema as an Object

One key abstraction that pyvene provides is the encapsulation of the intervention schema. While this abstraction provides a good user interface, pyvene can still support relatively complex intervention schemas. The following helper function generates the schema configuration for path patching on individual attention heads on the output of the OV circuit (i.e., analyzing the causal effect of each individual component):

import pyvene as pv

def path_patching_config(
    layer, last_layer,
    component="head_attention_value_output", unit="h.pos",
):
    # The component being patched (group 0).
    intervening_component = [
        {"layer": layer, "component": component, "unit": unit, "group_key": 0}]
    # Components to restore (group 1).
    restoring_components = []
    if not component.startswith("mlp_"):
        restoring_components += [
            {"layer": layer, "component": "mlp_output", "group_key": 1}]
    for i in range(layer+1, last_layer):
        restoring_components += [
            {"layer": i, "component": "attention_output", "group_key": 1},
            {"layer": i, "component": "mlp_output", "group_key": 1}
        ]
    intervenable_config = pv.IntervenableConfig(intervening_component + restoring_components)
    return intervenable_config

You can then wrap a model with the config generated by this function. After you have done your intervention, you can share your path patching setup with others:

_, tokenizer, gpt2 = pv.create_gpt2()

pv_gpt2 = pv.IntervenableModel(
    path_patching_config(4, gpt2.config.n_layer), 
    model=gpt2
)
# saving the path
pv_gpt2.save(
    save_directory="./your_gpt2_path/"
)
# loading the path (from the same directory it was saved to)
pv_gpt2 = pv.IntervenableModel.load(
    "./your_gpt2_path/",
    model=gpt2)
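To see what such a path patching schema contains, it can help to build the raw list of intervention specs without pyvene and inspect it. The sketch below mirrors the logic of the helper above, branching on `component`. The name `path_patching_schema` is hypothetical, and the value 12 assumes the 12-layer GPT-2 small model.

```python
def path_patching_schema(layer, last_layer,
                         component="head_attention_value_output", unit="h.pos"):
    """Build the raw list of intervention specs as plain dictionaries."""
    intervening = [
        {"layer": layer, "component": component, "unit": unit, "group_key": 0}]
    restoring = []
    if not component.startswith("mlp_"):
        restoring.append({"layer": layer, "component": "mlp_output", "group_key": 1})
    for i in range(layer + 1, last_layer):
        restoring.append({"layer": i, "component": "attention_output", "group_key": 1})
        restoring.append({"layer": i, "component": "mlp_output", "group_key": 1})
    return intervening + restoring


schema = path_patching_schema(4, 12)  # assumption: 12 layers, as in GPT-2 small

# Collect (layer, component) pairs per group to check which sites belong together.
groups = {}
for spec in schema:
    groups.setdefault(spec["group_key"], []).append((spec["layer"], spec["component"]))

print(len(schema), len(groups[0]), len(groups[1]))  # 16 1 15
```

Group 0 holds the single patched site at layer 4, while group 1 holds the MLP output at layer 4 plus the attention and MLP outputs of layers 5 through 11.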

Selected Tutorials

Level | Tutorial | Description
Beginner | pyvene 101 | Introduce you to the basics of pyvene
Intermediate | ROME Causal Tracing | Reproduce ROME's Results on Factual Associations with GPT2-XL
Intermediate | Intervention v.s. Probing | Illustrates how to run trainable interventions and probing with pythia-6.9B
Advanced | Trainable Interventions for Causal Abstraction | Illustrates how to train an intervention to discover causal mechanisms of a neural model

Contributing to This Library

Please see our guidelines about how to contribute to this repository.

Pull requests, bug reports, and all other forms of contribution are welcomed and highly encouraged!

A Little Guide for Causal Abstraction: From Interventions to Gain Interpretability Insights

Basic interventions are fun, but on their own they do not let us make systematic causal claims. To gain actual interpretability insights, we want to measure the counterfactual behavior of a model in a data-driven fashion. In other words, if the model responds systematically to your interventions, you can start to associate certain regions of the network with a high-level concept. We call this process alignment search over model internals.

Understanding Causal Mechanisms with Static Interventions

Here is a more concrete example,

def add_three_numbers(a, b, c):
    var_x = a + b
    return var_x + c

The function adds three numbers. Suppose we have trained a neural network to solve this problem perfectly. Can we find the representation of (a + b) in the neural network? We can use this library to answer this question. Specifically, we can do the following:

  • Step 1: Form an Interpretability (Alignment) Hypothesis: We hypothesize that a set of neurons N aligns with (a + b).
  • Step 2: Counterfactual Testing: If our hypothesis is correct, then swapping the values of neurons N between examples should give us the expected counterfactual behavior. For instance, swapping the values of N between (1+2)+3 and (2+3)+4 should yield the output of (2+3)+3 or (1+2)+4, depending on the direction of the swap.
  • Step 3: Rejection Sampling of Hypotheses: Run the tests multiple times and aggregate statistics on how often the counterfactual behavior matches. Propose a new hypothesis based on the results.
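The expected counterfactual outputs in Step 2 can be computed directly on the high-level function. The sketch below is plain Python and independent of pyvene; the `var_x` override argument and the `counterfactual_output` helper are hypothetical additions for illustration.

```python
def add_three_numbers(a, b, c, var_x=None):
    """High-level causal model; `var_x` optionally overrides the intermediate a + b."""
    if var_x is None:
        var_x = a + b
    return var_x + c


def counterfactual_output(base, source):
    """Expected output when base's (a + b) variable is replaced by the source's."""
    source_var_x = source[0] + source[1]
    return add_three_numbers(*base, var_x=source_var_x)


base, source = (1, 2, 3), (2, 3, 4)
print(counterfactual_output(base, source))  # (2+3)+3 = 8
print(counterfactual_output(source, base))  # (1+2)+4 = 7
```

If neurons N really encode (a + b), swapping them in the network should reproduce these labels.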

With the library, the above steps translate into a single API call:

intervenable.evaluate(
    train_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

where you provide testing data (basically interventional data and the counterfactual behavior you are looking for) along with your metrics functions. The library will try to evaluate the alignment with the intervention you specified in the config.
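The signature that `compute_metrics` must have is not shown in this section, so the following is only a library-independent sketch of the kind of statistic such a function might aggregate: the fraction of intervened predictions that match the expected counterfactual labels. The function name and its arguments are assumptions.

```python
def interchange_accuracy(predictions, counterfactual_labels):
    """Fraction of intervened predictions matching the expected counterfactual output."""
    if len(predictions) != len(counterfactual_labels):
        raise ValueError("predictions and labels must have the same length")
    if not predictions:
        return 0.0
    matches = sum(p == y for p, y in zip(predictions, counterfactual_labels))
    return matches / len(predictions)


# Hypothetical intervened outputs and the counterfactual labels they should match.
preds = [8, 7, 9, 5]
labels = [8, 7, 10, 5]
print(interchange_accuracy(preds, labels))  # 0.75
```

A score near 1.0 supports the alignment hypothesis, while a low score suggests proposing a different set of neurons.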


Understanding Causal Mechanism with Trainable Interventions

The alignment search process outlined above can be tedious when your neural network is large. For a single hypothesized alignment, you need to set up different intervention configs targeting different layers and positions to verify your hypothesis. Instead of doing this brute-force search, you can turn it into an optimization problem, which also has other benefits such as distributed alignments.

At its core, we want to train an intervention with our desired counterfactual behaviors in mind. If we can indeed train such an intervention, we claim that causally relevant information lives in the intervened representations! Below, we show one type of trainable intervention, models.interventions.RotatedSpaceIntervention:

class RotatedSpaceIntervention(TrainableIntervention):
    
    """Intervention in the rotated space."""
    def forward(self, base, source):
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)
        # interchange
        rotated_base[:self.interchange_dim] = rotated_source[:self.interchange_dim]
        # inverse base
        output = torch.matmul(rotated_base, self.rotate_layer.weight.T)
        return output

Instead of swapping activations in the original representation space, we first rotate them, do the swap, and then un-rotate the intervened representation. Additionally, we use SGD to learn a rotation that produces the expected counterfactual behavior. If we can find such a rotation, we claim there is an alignment. The tutorial notebook "If the cost is between X and Y.ipynb" covers this with an advanced version of distributed alignment search, Boundless DAS. Recent works also outline potential limitations of distributed alignment search.
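The rotate, swap, un-rotate pattern can be sketched in two dimensions with only the standard library. This is a minimal illustration, not pyvene's implementation: the rotation angle `theta` is fixed here rather than learned with SGD, and all helper names are hypothetical.

```python
import math


def rotation(theta):
    """2x2 rotation matrix (orthogonal, so its inverse is its transpose)."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def matvec(m, v):
    """Multiply matrix m by column vector v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def rotated_interchange(base, source, theta, interchange_dim=1):
    """Rotate both vectors, swap the first `interchange_dim` coordinates, rotate back."""
    r = rotation(theta)
    rotated_base = matvec(r, base)
    rotated_source = matvec(r, source)
    rotated_base[:interchange_dim] = rotated_source[:interchange_dim]
    return matvec(transpose(r), rotated_base)


base, source = [1.0, 2.0], [3.0, 4.0]
print(rotated_interchange(base, source, theta=0.0))  # [3.0, 2.0]
print(rotated_interchange(base, source, theta=math.pi / 4))
```

With `theta=0.0` this reduces to an ordinary coordinate swap. With a non-zero angle, the swapped direction is a linear combination of the original coordinates, which is what allows the aligned variable to be distributed across neurons.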

You can now also make a single API call to train your intervention,

intervenable.train(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

where you need to pass in a training dataloader along with your customized loss and metrics functions. The trained interventions can later be saved to disk. You can also call intervenable.evaluate() to evaluate your interventions against customized objectives.

Related Works in Discovering Causal Mechanism of LLMs

If you would like to read more work in this area, here is a list of papers that try to align or discover the causal mechanisms of LLMs.


Citation

The library paper is forthcoming. For now, if you use this repository, please consider citing the relevant papers:

  @article{geiger-etal-2023-DAS,
        title={Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Potts, Christopher and Icard, Thomas  and Goodman, Noah},
        year={2023},
        booktitle={arXiv}
  }

  @article{wu-etal-2023-Boundless-DAS,
        title={Interpretability at Scale: Identifying Causal Mechanisms in Alpaca}, 
        author={Wu, Zhengxuan and Geiger, Atticus and Icard, Thomas and Potts, Christopher and Goodman, Noah},
        year={2023},
        booktitle={NeurIPS}
  }
