[Fix] Move attention mask to the model device type (databrickslabs#180)
The attention mask must be on the same device as the model and the other inputs; otherwise PyTorch raises a device-mismatch error at generation time.
BaiqingL authored May 26, 2023
1 parent 3725600 commit 5021d94
1 changed file: training/generate.py (1 addition, 1 deletion)
@@ -136,7 +136,7 @@ def _forward(self, model_inputs, **generate_kwargs):

      generated_sequence = self.model.generate(
          input_ids=input_ids.to(self.model.device),
-         attention_mask=attention_mask,
+         attention_mask=attention_mask.to(self.model.device),
          pad_token_id=self.tokenizer.pad_token_id,
          **generate_kwargs,
      )
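The failure mode the commit fixes can be sketched without a GPU. The `Tensor` and `Model` classes below are hypothetical stand-ins for `torch.Tensor` and a Hugging Face model; the point is only that `generate()` rejects inputs whose device differs from the model's, which is why both `input_ids` and `attention_mask` need `.to(self.model.device)`.

```python
# Minimal sketch of the device-mismatch bug (mock classes, no torch needed).

class Tensor:
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # Like torch.Tensor.to: returns a copy placed on the target device.
        return Tensor(self.data, device)


class Model:
    def __init__(self, device="cuda:0"):
        self.device = device

    def generate(self, input_ids, attention_mask):
        # torch raises RuntimeError when tensors sit on different devices;
        # this mock mimics that check.
        if input_ids.device != self.device or attention_mask.device != self.device:
            raise RuntimeError("Expected all tensors to be on the same device")
        kept = [t for t, m in zip(input_ids.data, attention_mask.data) if m]
        return Tensor(kept, self.device)


model = Model(device="cuda:0")
input_ids = Tensor([1, 2, 3])       # produced on CPU by the tokenizer
attention_mask = Tensor([1, 1, 0])  # also on CPU

# Before the fix: only input_ids was moved, so generate() fails.
try:
    model.generate(input_ids.to(model.device), attention_mask)
except RuntimeError as e:
    print("mismatch:", e)

# After the fix: both tensors are moved to the model's device.
out = model.generate(input_ids.to(model.device), attention_mask.to(model.device))
print(out.device)
```

Calling `.to(device)` on a tensor already on that device is cheap (a no-op in torch), so moving every input unconditionally is the simplest safe pattern.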
