Revert "use simplified rindex function to find target token probs"
This reverts commit 5b7ecea.
jpgard committed May 3, 2023
1 parent 4b200ba commit d05fe43
Showing 1 changed file with 70 additions and 77 deletions.
147 changes: 70 additions & 77 deletions open_flamingo/eval/evaluate.py
@@ -1,6 +1,7 @@
import argparse
import importlib
import json
+ from math import ceil
import os
import random
import uuid
@@ -10,17 +11,21 @@
import more_itertools
import numpy as np
import torch
- from tqdm import tqdm

from coco_metric import compute_cider, postprocess_captioning_generation
from eval_datasets import COCOFlickrDataset, VQADataset, ImageNetDataset
+ from tqdm import tqdm

- from open_flamingo.eval.ok_vqa_utils import postprocess_ok_vqa_generation
- from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
+ from open_flamingo.eval.classification import (
+ compute_per_sample_probs,
+ compute_per_sample_loss,
+ )
from open_flamingo.eval.imagenet_utils import (
openai_imagenet_classnames,
IMAGENET_1K_CLASS_ID_TO_LABEL,
)
+ from open_flamingo.eval.ok_vqa_utils import postprocess_ok_vqa_generation
from open_flamingo.src.flamingo import Flamingo
+ from vqa_metric import compute_vqa_accuracy, postprocess_vqa_generation
from open_flamingo.eval import eval_model

@@ -42,8 +47,7 @@
parser = argparse.ArgumentParser()
parser.add_argument(
help="Seeds to use for each trial for picking demonstrations and eval sets",
)
parser.add_argument(
"--num_samples", type=int, default=5000,
help="Number of samples to evaluate on"
"--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
)

parser.add_argument("--batch_size", type=int, default=8)
@@ -225,8 +229,7 @@ def main():
annotations_json_path=args.ok_vqa_annotations_json_path,
vqa_dataset="ok_vqa",
)
- print(
- f"Shots {shot} Trial {trial} OK-VQA score: {ok_vqa_score}")
+ print(f"Shots {shot} Trial {trial} OK-VQA score: {ok_vqa_score}")
scores.append(ok_vqa_score)
print(f"Shots {shot} Mean OK-VQA score: {np.mean(scores)}")
results["ok_vqa"].append(
@@ -297,22 +300,20 @@ def get_random_indices(num_samples, query_set_size, full_dataset, seed):
return random_indices


- def prepare_eval_samples_and_dataset(full_dataset, random_indices,
- query_set_size):
+ def prepare_eval_samples_and_dataset(full_dataset, random_indices, query_set_size):
# get in context samples
- in_context_samples = [full_dataset[i] for i in
- random_indices[:query_set_size]]
+ in_context_samples = [full_dataset[i] for i in random_indices[:query_set_size]]
eval_dataset = torch.utils.data.Subset(
full_dataset, random_indices[query_set_size:]
)
return in_context_samples, eval_dataset
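
The helper above splits one shuffled index pool into a demonstration pool and a disjoint eval subset. A toy illustration using the helper as defined above (hypothetical six-item dataset, not part of the commit):

    # Any indexable dataset works; strings stand in for real samples here.
    full_dataset = ["a", "b", "c", "d", "e", "f"]
    random_indices = [3, 0, 5, 1]  # as produced by get_random_indices
    in_context, eval_subset = prepare_eval_samples_and_dataset(
        full_dataset, random_indices, query_set_size=2
    )
    print(in_context)         # ['d', 'a'] -- demonstration pool
    print(list(eval_subset))  # ['f', 'b'] -- held-out eval items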


def get_context_text(
- get_prompt: Callable[[dict], str],
- in_context_samples,
- effective_num_shots,
- num_shots,
+ get_prompt: Callable[[dict], str],
+ in_context_samples,
+ effective_num_shots,
+ num_shots,
) -> str:
context_text = (
"".join([get_prompt(s) for s in in_context_samples])
@@ -330,18 +331,18 @@ def sample_batch_demos_from_query_set(query_set, num_samples, batch_size):


def evaluate_coco_flickr(
- eval_model,
- batch_size,
- image_dir_path,
- annotations_json_path,
- seed=42,
- max_generation_length=20,
- num_beams=3,
- length_penalty=-2.0,
- num_samples=5000,
- query_set_size=2048,
- num_shots=8,
- is_flickr=False,
+ eval_model,
+ batch_size,
+ image_dir_path,
+ annotations_json_path,
+ seed=42,
+ max_generation_length=20,
+ num_beams=3,
+ length_penalty=-2.0,
+ num_samples=5000,
+ query_set_size=2048,
+ num_shots=8,
+ is_flickr=False,
):
"""Evaluate a model on COCO dataset.
@@ -371,8 +372,7 @@ def evaluate_coco_flickr(
is_flickr=is_flickr,
)
effective_num_shots = num_shots if num_shots > 0 else 2
- random_indices = get_random_indices(num_samples, query_set_size,
- full_dataset, seed)
+ random_indices = get_random_indices(num_samples, query_set_size, full_dataset, seed)

in_context_samples, eval_dataset = prepare_eval_samples_and_dataset(
full_dataset=full_dataset,
@@ -387,8 +387,7 @@ def get_prompt(sample):

desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"

- for batch in more_itertools.chunked(tqdm(eval_dataset, desc=desc),
- batch_size):
+ for batch in more_itertools.chunked(tqdm(eval_dataset, desc=desc), batch_size):
batch_demo_samples = sample_batch_demos_from_query_set(
in_context_samples, effective_num_shots, len(batch)
)
@@ -419,8 +418,7 @@ def get_prompt(sample):
)

new_predictions = [
- postprocess_captioning_generation(out).replace('"', "") for out in
- outputs
+ postprocess_captioning_generation(out).replace('"', "") for out in outputs
]

for i, sample in enumerate(batch):
@@ -458,19 +456,19 @@ def get_prompt(sample):


def evaluate_vqa(
- eval_model,
- batch_size,
- image_dir_path,
- questions_json_path,
- annotations_json_path,
- seed=42,
- max_generation_length=5,
- num_beams=3,
- length_penalty=-2.0,
- num_samples=5000,
- query_set_size=2048,
- num_shots=8,
- vqa_dataset="vqa",
+ eval_model,
+ batch_size,
+ image_dir_path,
+ questions_json_path,
+ annotations_json_path,
+ seed=42,
+ max_generation_length=5,
+ num_beams=3,
+ length_penalty=-2.0,
+ num_samples=5000,
+ query_set_size=2048,
+ num_shots=8,
+ vqa_dataset="vqa",
):
"""
Evaluate a model on VQA datasets. Currently supports VQA v2.0.
@@ -508,8 +506,7 @@ def evaluate_vqa(
f"num_samples + num_shots must be less than or equal to {len(full_dataset)}"
)

- random_indices = get_random_indices(num_samples, query_set_size,
- full_dataset, seed)
+ random_indices = get_random_indices(num_samples, query_set_size, full_dataset, seed)

def get_prompt(sample, train=True):
return f"<image>Question:{sample['question'].strip()} Short Answer:{sample['answers'][0].strip() if train else ''}{'<|endofchunk|>' if train else ''}"
@@ -523,7 +520,7 @@ def get_prompt(sample, train=True):
predictions = []

for batch in more_itertools.chunked(
- tqdm(eval_dataset, desc="Running inference"), batch_size
+ tqdm(eval_dataset, desc="Running inference"), batch_size
):
batch_demo_samples = sample_batch_demos_from_query_set(
in_context_samples, effective_num_shots, len(batch)
@@ -585,22 +582,22 @@ def get_prompt(sample, train=True):
return acc


- def rindex(lst, sublist):
- """Find the starting index *from right* of sublist in lst."""
- sublist_len = len(sublist)
- for i in range(len(lst) - sublist_len):
- if lst[i:i+sublist_len] == sublist:
- return i
- raise ValueError
+ def find_sub_list(sl,l):
+ results=[]
+ sll=len(sl)
+ for ind in (i for i,e in enumerate(l) if e==sl[0]):
+ if l[ind:ind+sll]==sl:
+ results.append(ind+sll-1)
+ return results
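
The two helpers are not interchangeable on the few-shot prompts built in this file: the prompt tokens occur once per in-context demonstration plus once for the final query, and the caller needs the last occurrence. Despite its name and docstring, the deleted rindex scans left to right and returns the first match, and its range also stops one position short, so it never sees a match that ends flush with the sequence. A sketch with toy token ids (not part of the commit):

    tokens = [5, 1, 2, 8, 1, 2]    # toy ids; [1, 2] stands in for prompt_tokens
    find_sub_list([1, 2], tokens)  # [2, 5] -- end index of every match; [-1] is the last
    rindex(tokens, [1, 2])         # 1 -- start of the FIRST match only
    rindex([1, 2], [1, 2])         # ValueError -- range(0) skips the flush match

Scoring from the first occurrence slices probabilities inside the demonstrations rather than after the final class-name prompt, which is a plausible reason the simplification was reverted.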


def evaluate_imagenet(
- eval_model,
- batch_size: int,
- imagenet_root: str,
- seed: int = 42,
- num_samples: int = 5000,
- num_shots: int = 8,
+ eval_model,
+ batch_size: int,
+ imagenet_root: str,
+ seed: int = 42,
+ num_samples: int = 5000,
+ num_shots: int = 8,
):
"""
Evaluate a model on ImageNet dataset.
@@ -621,7 +618,6 @@ def evaluate_imagenet(
"evaluate_imagenet is currently only supported for OpenFlamingo " "models"
)
model, tokenizer = eval_model.model, eval_model.tokenizer
- assert isinstance(model, Flamingo)

train_dataset = ImageNetDataset(os.path.join(imagenet_root, 'train'))
val_dataset = ImageNetDataset(os.path.join(imagenet_root, 'val'))
@@ -644,7 +640,6 @@
+ [eval_model.image_processor(batch['image']).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
- # model._encode_vision_x(vision_x)

overall_probs = []
for imagenet_class_name in tqdm(openai_imagenet_classnames):
@@ -654,11 +649,10 @@
context_class_names = [in_context_samples[i]['class_name']
for i in range(effective_num_shots)]
text = ''.join(f"{prompt_text} {classname}<|endofchunk|>"
- for classname in context_class_names)
+ for classname in context_class_names)
text += f'{prompt_text} {imagenet_class_name}'
prompt_tokens = tokenizer(prompt_text, add_special_tokens=False,
- return_tensors='np')[
- 'input_ids'].ravel().tolist()
+ return_tensors='np')['input_ids'].ravel().tolist()

lang_x = tokenizer([text], return_tensors="pt")

@@ -678,11 +672,10 @@

probs = []
for input_sentence, input_probs in zip(input_ids, gen_probs):
- prompt_start = rindex(
- input_sentence.detach().cpu().numpy().tolist(),
- prompt_tokens)
- prompt_end = prompt_start + len(prompt_tokens)
- input_probs = input_probs[prompt_end + 1:]
+ idxes = find_sub_list(prompt_tokens,
+ input_sentence.detach().cpu().numpy().tolist())
+ # input_sentence = input_sentence[idxes[-1] + 1:]
+ input_probs = input_probs[idxes[-1] + 1:]
probs.append(torch.prod(input_probs).item())
overall_probs.append(probs)
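
The restored lines score a candidate class by the product of the probabilities its tokens receive after the last occurrence of the prompt. A sketch of that slice-and-product rule with toy numbers and find_sub_list as defined above (in the commit, gen_probs comes from the model's logits):

    import torch

    # Toy ids: prompt [7, 8], a demo answer (40, 9), prompt again, class tokens (30, 31).
    input_sentence = [7, 8, 40, 9, 7, 8, 30, 31]
    gen_probs = torch.tensor([0.9, 0.9, 0.7, 0.9, 0.9, 0.9, 0.5, 0.4])

    idxes = find_sub_list([7, 8], input_sentence)  # [1, 5]
    class_probs = gen_probs[idxes[-1] + 1:]        # probs of tokens 30, 31 only
    print(torch.prod(class_probs).item())          # 0.5 * 0.4 = 0.2

The deleted path instead sliced from prompt_end + 1, i.e. from the first prompt occurrence and one token later, so the two versions generally scored different spans.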

@@ -693,9 +686,9 @@
if batch['class_name'] in top5:
acc5 += 1
print('eval {}/{}: acc@1 ({}), acc@5 ({})'.format(i, num_samples,
- acc1 / (i + 1),
- acc5 / (i + 1)))
- if i >= num_samples - 1:
+ acc1 / (i+1),
+ acc5 / (i+1)))
+ if i >= num_samples:
break

return acc1
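
For reference, the acc@1/acc@5 bookkeeping above is a plain top-k membership test over the per-class scores gathered in overall_probs. A minimal single-image sketch (toy scores over a three-class universe instead of ImageNet's 1,000 classes):

    import numpy as np

    class_names = ["tench", "goldfish", "hammerhead"]
    scores = np.array([0.02, 0.91, 0.07])  # one product-of-probs per class
    ranked = [class_names[j] for j in np.argsort(scores)[::-1]]
    acc1 = int(ranked[0] == "goldfish")    # top-1 hit -> 1
    acc5 = int("goldfish" in ranked[:5])   # top-5 hit, k=5 as in the loop -> 1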
