Commit fd6b110
use prompt instruction formatting for eval
constanzafierro committed Dec 12, 2023
1 parent c6dbeb1 commit fd6b110
Showing 4 changed files with 27 additions and 8 deletions.
classifier/classifier_eval.py (13 additions & 2 deletions)

@@ -16,14 +16,16 @@
     Trainer,
 )
 from glob import glob
+from inference import prepare_prompt, DEF_TEMPLATE_TO_USE, DEF_INSTRUCTION


-def replace_subject(prompt_format, tokenizer, example):
+def replace_subject(prompt_format, tokenizer, prepare_prompt_func, example):
     query = example["query"].replace("_X_ .", "_X_.")
     text = query.replace("_X_.", example["answer"][0]["name"]).strip()
     if prompt_format != "{}":
         text = text[0].lower() + text[1:]
     text = prompt_format.format(text)
+    text = prepare_prompt_func(text).strip()
     return {"text": text, **tokenizer(text)}


@@ -70,7 +72,16 @@ def main(args, device):
     assert len(model_path) == 1, model_path
     model_path = model_path[0]
     tokenizer = AutoTokenizer.from_pretrained(model_path)
-    tokenized_ds = ds.map(partial(replace_subject, args.prompt_format, tokenizer))
+    tokenized_ds = ds.map(
+        partial(
+            replace_subject,
+            args.prompt_format,
+            tokenizer,
+            lambda q: prepare_prompt(
+                q, model_path, DEF_INSTRUCTION, DEF_TEMPLATE_TO_USE
+            ),
+        )
+    )
     print("Example of training example:", tokenized_ds["train"][0])
     print("Loading model")
     id2label = {1: "MUTABLE", 0: "IMMUTABLE"}
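With this change, evaluation runs every example through the same instruction template used elsewhere in the repo. A minimal, self-contained sketch of the resulting data flow; toy_prepare_prompt is a hypothetical stand-in, since the real prepare_prompt and TEMPLATES strings live in inference.py and are not shown in this diff:

# Sketch of the eval-time preprocessing after this commit. The stand-in
# toy_prepare_prompt only illustrates the shape of the call; the real
# prepare_prompt selects a template by name and model.
DEF_INSTRUCTION = "Complete the fact in as few words as possible"

def toy_prepare_prompt(query: str) -> str:
    # Hypothetical rendering of the "query_in_response" template.
    return f"{DEF_INSTRUCTION}\n{query}"

def replace_subject(prompt_format, tokenizer, prepare_prompt_func, example):
    query = example["query"].replace("_X_ .", "_X_.")
    text = query.replace("_X_.", example["answer"][0]["name"]).strip()
    if prompt_format != "{}":
        text = text[0].lower() + text[1:]
    text = prompt_format.format(text)
    text = prepare_prompt_func(text).strip()
    return {"text": text, **tokenizer(text)}

example = {"query": "The capital of Denmark is _X_ .", "answer": [{"name": "Copenhagen"}]}
out = replace_subject("{}", lambda t: {}, toy_prepare_prompt, example)
print(out["text"])
# Complete the fact in as few words as possible
# The capital of Denmark is Copenhagen.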
classifier/compute_mdl.py (8 additions & 0 deletions)

@@ -93,13 +93,16 @@ def print_metrics(ds, dir_name):

 wu_llama2 = {"1-1": "0.2", "1-n": "0.1"}
 wu_alpaca = {"1-1": "0.2", "1-n": "0.0"}
+wu_falcon = {"1-1": "0.2", "1-n": "0.0"}
 for clf_type in ["1-1", "1-n"]:
     print(">>>>>>>>>>>>>>>>>>>>>>>>>", clf_type)
     model_to_results_dir = {
         "llama-7b": "/projects/nlp/data/constanzam/mdl_mutability",
         "alpaca-7b": f"/projects/nlp/data/constanzam/mdl_mutability/alpaca-7b/fm_dataset_{clf_type}/",
         "llama2-7b": f"/projects/nlp/data/constanzam/mdl_mutability/llama2-7b/fm_dataset_{clf_type}/",
         "llama2-chat-7b": f"/projects/nlp/data/constanzam/mdl_mutability/llama2-chat-7b/fm_dataset_{clf_type}/",
+        "falcon-7b": f"/projects/nlp/data/constanzam/mdl_mutability/falcon-7b/fm_dataset_{clf_type}/",
+        "falcon-instruct-7b": f"/projects/nlp/data/constanzam/mdl_mutability/falcon-instruct-7b/fm_dataset_{clf_type}/",
     }
     model_to_normal_subfolder = {
         "llama-7b": "no_overlap_fix_fm_dataset_1-1"
@@ -108,6 +111,8 @@ def print_metrics(ds, dir_name):
         "llama2-7b": f"lr5e-5_wu{wu_llama2[clf_type]}_",
         "alpaca-7b": f"lr5e-5_wu{wu_alpaca[clf_type]}_",
         "llama2-chat-7b": "lr5e-5_wu0.2_",
+        "falcon-7b": f"lr5e-5_wu{wu_falcon[clf_type]}_",
+        "falcon-instruct-7b": "lr5e-5_wu0.2_",
     }
     model_to_rand_subfolder = {
         "llama-7b": f"no_overlap_fix_rand_fm_dataset_{clf_type}"
@@ -116,10 +121,13 @@ def print_metrics(ds, dir_name):
         "llama2-7b": f"lr5e-5_wu{wu_llama2[clf_type]}_rand",
         "alpaca-7b": f"lr5e-5_wu{wu_alpaca[clf_type]}_rand",
         "llama2-chat-7b": "lr5e-5_wu0.2_rand",
+        "falcon-7b": f"lr5e-5_wu{wu_falcon[clf_type]}_rand",
+        "falcon-instruct-7b": "lr5e-5_wu0.2_rand",
     }

     ds = load_dataset(f"coastalcph/mutability_classifier-{clf_type}")
     for model in model_to_results_dir.keys():
+        print()
         print(model)
         print("--- normal ----")
         normal_dir = os.path.join(
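The warmup dictionaries above only feed into directory names. A short sketch of how one results path for the newly added Falcon entries is assembled; the values are copied from the dicts above, but the exact os.path.join call is truncated in the hunk, so the join itself is an assumption:

import os

clf_type = "1-1"
wu_falcon = {"1-1": "0.2", "1-n": "0.0"}
results_dir = f"/projects/nlp/data/constanzam/mdl_mutability/falcon-7b/fm_dataset_{clf_type}/"
normal_subfolder = f"lr5e-5_wu{wu_falcon[clf_type]}_"
print(os.path.join(results_dir, normal_subfolder))
# /projects/nlp/data/constanzam/mdl_mutability/falcon-7b/fm_dataset_1-1/lr5e-5_wu0.2_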
classifier/mdl_classifier.py (2 additions & 4 deletions)

@@ -36,11 +36,9 @@
     Trainer,
     TrainingArguments,
 )
-from inference import TEMPLATES, prepare_prompt
+from inference import DEF_TEMPLATE_TO_USE, DEF_INSTRUCTION, prepare_prompt

 logger = logging.getLogger(__name__)
-TEMPLATE_TO_USE = "query_in_response"
-INSTRUCTION = "Complete the fact in as few words as possible"


 @dataclass
@@ -262,7 +260,7 @@ def main(device):
         partial(
             replace_subject,
             prepare_prompt=lambda q: prepare_prompt(
-                q, model_args.model_name_or_path, INSTRUCTION, TEMPLATE_TO_USE
+                q, model_args.model_name_or_path, DEF_INSTRUCTION, DEF_TEMPLATE_TO_USE
             ),
             tokenizer=tokenizer,
         )
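The net effect of this file's change is that training and evaluation now build their prompts from one shared pair of defaults instead of two copies of the same literals. A sketch of the closure both scripts now construct; make_prompt_fn is a hypothetical helper that neither file actually defines, and it assumes the repository's inference.py is importable:

from inference import DEF_INSTRUCTION, DEF_TEMPLATE_TO_USE, prepare_prompt

def make_prompt_fn(model_name_or_path):
    # classifier_eval.py passes this closure positionally into
    # replace_subject; mdl_classifier.py passes it as the
    # prepare_prompt keyword argument. Both format prompts identically.
    return lambda q: prepare_prompt(
        q, model_name_or_path, DEF_INSTRUCTION, DEF_TEMPLATE_TO_USE
    )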
inference.py (4 additions & 2 deletions)

@@ -19,6 +19,8 @@
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 NUM_BEAMS = 1
 MAX_ANSWER_LENGTH = 10
+DEF_TEMPLATE_TO_USE = "query_in_response"
+DEF_INSTRUCTION = "Complete the fact in as few words as possible"

 TEMPLATES = {
     "query_in_instructions": (
@@ -241,13 +243,13 @@ def main(args):
     parser.add_argument(
         "--template",
         type=str,
-        default="query_in_response",
+        default=DEF_TEMPLATE_TO_USE,
         help="query_in_instructions, query_in_response or query_in_input",
     )
     parser.add_argument(
         "--instruction",
         type=str,
-        default="Complete the fact in as few words as possible",
+        default=DEF_INSTRUCTION,
     )
     parser.add_argument(
         "--output_dir",
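Promoting the literals to module-level constants keeps the CLI defaults and the library callers in sync. A minimal standalone sketch of the pattern, using only names that appear in the diff:

import argparse

DEF_TEMPLATE_TO_USE = "query_in_response"
DEF_INSTRUCTION = "Complete the fact in as few words as possible"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--template",
    type=str,
    default=DEF_TEMPLATE_TO_USE,
    help="query_in_instructions, query_in_response or query_in_input",
)
parser.add_argument("--instruction", type=str, default=DEF_INSTRUCTION)

args = parser.parse_args([])  # no flags given, so the shared defaults apply
print(args.template)     # query_in_response
print(args.instruction)  # Complete the fact in as few words as possible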
