
Commit

update attention masking and token slicing
jpgard committed Mar 3, 2023
1 parent bc3333d commit a5d820b
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions open_flamingo/eval/evaluate.py
@@ -765,7 +765,7 @@ def get_imagenet_prompt(x: dict, is_context: bool = True) -> str:
     context_encodings = tokenizer([context_text] * batch_size,
                                   **tokenizer_kwargs)
     context_ids = context_encodings['input_ids'].to(device)
-    context_tokens_len = context_ids.shape[-1]
+    context_len = context_ids.shape[-1]
     context_precomputed = model(None, context_ids,
                                 use_cached_vision_x=True,
                                 clear_conditioned_layers=False,
@@ -784,7 +784,7 @@ def get_imagenet_prompt(x: dict, is_context: bool = True) -> str:
 
     # full_batch_input_ids has shape [batch_size, seq_len], but we
     # only need to run inference on the [batch_size,
-    # context_tokens_len:] inputs that have not been precomputed and
+    # context_len:] inputs that have not been precomputed and
     # vary per class.
     full_batch_input_ids = full_batch_encodings["input_ids"].to(device)
     full_batch_attention_mask = full_batch_encodings[
@@ -794,20 +794,19 @@ def get_imagenet_prompt(x: dict, is_context: bool = True) -> str:
     # Sanity check that the encoded inputs with context are the same
     # as the encoded context alone, for every example in the batch
     assert torch.all(context_ids[0, :] == full_batch_input_ids[:,
-                                          :context_tokens_len]).item()
+                                          :context_len]).item()
 
     # Autoregressively compute the outputs without recomputing the
     # context computations.
-    for i in range(context_tokens_len, seq_len):
-        # token_ids = full_batch_input_ids[:, i]
-        # attention_mask = full_batch_attention_mask[:, i]
-        outputs = model(vision_x=None,
-                        lang_x=full_batch_input_ids[:seq_len+i],
-                        attention_mask=full_batch_attention_mask[:seq_len+i],
-                        use_cached_vision_x=True,
-                        clear_conditioned_layers=False,
-                        past_key_values=past_key_values,
-                        use_cache=True)
+    for i in range(context_len, seq_len):
+        outputs = model(
+            vision_x=None,
+            lang_x=torch.unsqueeze(full_batch_input_ids[:, i], 1),
+            attention_mask=full_batch_attention_mask[:, :i],
+            use_cached_vision_x=True,
+            clear_conditioned_layers=False,
+            past_key_values=past_key_values,
+            use_cache=True)
         past_key_values = outputs.past_key_values
 
     # TODO(jpgard): check shape of output logits at this step to
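For readers skimming the diff: the new loop precomputes the shared context once, then decodes the remaining positions one token at a time, reusing the cached key/value states instead of re-running the whole prefix on every step. Below is a minimal sketch of that pattern using a stock HuggingFace GPT-2 as a stand-in (an assumption for illustration; the OpenFlamingo model above exposes the same past_key_values/use_cache interface, plus its vision-specific arguments, which are omitted here). The prompt strings are invented.

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

    # Shared context (analogous to context_ids) plus one per-class continuation.
    context_ids = tokenizer("A photo of a", return_tensors="pt")["input_ids"]
    class_ids = tokenizer(" golden retriever", return_tensors="pt")["input_ids"]
    full_ids = torch.cat([context_ids, class_ids], dim=1)
    context_len, seq_len = context_ids.shape[1], full_ids.shape[1]

    with torch.no_grad():
        # Run the context once and cache its key/value states.
        outputs = model(input_ids=context_ids, use_cache=True)

        # Feed one new token per step; the cache covers everything before it.
        for i in range(context_len, seq_len):
            outputs = model(
                input_ids=full_ids[:, i].unsqueeze(1),  # shape [1, 1]
                attention_mask=torch.ones(1, i + 1, dtype=torch.long),
                past_key_values=outputs.past_key_values,
                use_cache=True)
            # outputs.logits[:, -1] holds the next-token logits after position i.

One detail worth flagging: the sketch's mask has length i + 1 because HuggingFace causal LMs expect the attention mask to cover the cached tokens and the current token together, whereas the diff passes full_batch_attention_mask[:, :i] (length i); which convention the OpenFlamingo forward expects is not visible from this hunk.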
