Skip to content

Commit

Permalink
reinitialize past key values at each iteration/class
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgard committed Mar 3, 2023
1 parent 9a32fff commit ac231ee
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion open_flamingo/eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,11 @@ def get_imagenet_prompt(x: dict, is_context: bool = True) -> str:
:context_len]).item()

logits = context_precomputed.logits.clone()
past_key_values = context_precomputed.past_key_values.clone()

# Clone the nested structure of the past key values
past_key_values = tuple(
[tuple([x.clone() for x in inner]) for inner in
context_precomputed.past_key_values])

# Autoregressively compute the outputs without recomputing the
# context computations.
Expand Down

0 comments on commit ac231ee

Please sign in to comment.