Skip to content

Commit

Permalink
use random template in context
Browse files Browse the repository at this point in the history
  • Loading branch information
constanzafierro committed Dec 6, 2023
1 parent 8ace77b commit 69a6dcf
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions inference_updates.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import argparse
import collections
import json
import os

import numpy as np
import torch
import wandb
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
import collections
from inference import prepare_prompt, get_scores, get_generation_config
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from inference import get_generation_config, get_scores, prepare_prompt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand All @@ -38,6 +37,7 @@ def main(args):

outputs = {key: [] for key in ["raw_predictions", "predictions"]}
updated_counts_mutability = collections.defaultdict(int)
rng = np.random.default_rng(42)
for ex_i, ex in enumerate(tqdm(ds)):
relation = ex["relation"]
subject = ex["query"]["label"]
Expand All @@ -53,8 +53,7 @@ def main(args):
print("templates", templates)
raise Exception("prompt not in templates")
templates.remove(prompt)
context = list(templates)[0]
# TODO: should we run over all?
context = list(templates)[rng.choice(len(templates), 1)[0]]
new_target = ex["updates"][0]
query = "Imagine that {} {}. Then, {}".format(
context.replace("[X]", subject), new_target, prompt.replace("[X]", subject)
Expand Down Expand Up @@ -107,8 +106,14 @@ def main(args):
print("query", query)
print("new_target", new_target)
print("answer", answer)
for k, v in updated_counts_mutability.items():
wandb.run.summary[k] = v
for mutability in list(
set([k.split("_")[0] for k in updated_counts_mutability.keys()])
):
total = updated_counts_mutability[f"{mutability}_total"]
succ = updated_counts_mutability[f"{mutability}_succ"]
wandb.run.summary[f"{mutability}_total"] = total
wandb.run.summary[f"{mutability}_succ"] = succ
wandb.run.summary[f"{mutability}_succ_rate"] = succ / total

print("Writing outputs")
for key in outputs:
Expand Down

0 comments on commit 69a6dcf

Please sign in to comment.