Skip to content

Commit

Permalink
remove repetition along batch dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgard committed May 3, 2023
1 parent b1b591a commit 4d17be3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions open_flamingo/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,7 @@ def evaluate_imagenet(
for i in range(effective_num_shots)]
context_text = ''.join(f"{prompt_text} {classname}<|endofchunk|>"
for classname in context_class_names)
context_encodings = tokenizer([context_text] * batch_size,
**tokenizer_kwargs)
context_encodings = tokenizer([context_text], **tokenizer_kwargs)
context_ids = context_encodings["input_ids"].to(device)
context_len = context_ids.shape[-1]
context_precomputed = model(
Expand All @@ -688,7 +687,8 @@ def evaluate_imagenet(
# Sanity check that the encoded inputs with context are the same
# as the encoded context alone, for every example in the batch
assert torch.all(
context_ids[0, :] == lang_x['input_ids'][:, :context_len]
context_ids[0, :] == \
lang_x['input_ids'][:, :context_len].to(device)
).item()

# Clone the nested structure of the past key values
Expand Down

0 comments on commit 4d17be3

Please sign in to comment.