diff --git a/fairseq/models/huggingface/hf_gpt2.py b/fairseq/models/huggingface/hf_gpt2.py
index 797850bbd0..4107113e81 100644
--- a/fairseq/models/huggingface/hf_gpt2.py
+++ b/fairseq/models/huggingface/hf_gpt2.py
@@ -99,6 +99,7 @@ def forward(
         prev_output_tokens,
         src_lengths=None,
         incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
+        encoder_out=None,
     ):
         features = self.extract_features(prev_output_tokens, incremental_state)
         lm_logits = self.model.lm_head(features)
@@ -109,7 +110,7 @@ def extract_features(
         prev_output_tokens,
         incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
     ):
-        if incremental_state is not None:
+        if incremental_state:
             past = self.get_incremental_state("past")
         else:
             past = None
@@ -132,7 +133,7 @@ def extract_features(
         )
         last_hidden_states = outputs[0]
 
-        if incremental_state is not None:
+        if incremental_state:
             self.set_incremental_state(incremental_state, "past", outputs[1])
 
         return last_hidden_states