Fix CIs for PyTorch 1.13 (huggingface#20686)
* fix 1

* fix 2

* fix 3

* fix 4

Co-authored-by: ydshieh <[email protected]>
ydshieh and ydshieh authored Dec 8, 2022
1 parent bcc069d commit e3cc448
Showing 15 changed files with 18 additions and 15 deletions.
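
Every hunk below applies the same pattern: an auxiliary tensor derived from input_ids (a mask, an index, or a sequence-length count) is explicitly moved onto the device of the tensor it is used to index (hidden_states or logits). This guards against cross-device indexing when parts of a model do not sit on the same device as the inputs, which appears to be what started failing in the PyTorch 1.13 CI runs. A minimal, self-contained sketch of the idea; the two-device setup here is hypothetical and only illustrates the mismatch being guarded against:

    import torch

    # Hypothetical split: input_ids stay on CPU while the model's outputs live
    # on an accelerator (e.g. when a model is sharded across devices).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_ids = torch.tensor([[5, 7, 2], [9, 2, 0]])      # CPU
    hidden_states = torch.randn(2, 3, 4, device=device)   # possibly CUDA

    # A mask derived from input_ids inherits input_ids' device; moving it to
    # hidden_states.device keeps the indexing on a single device instead of
    # relying on implicit cross-device behavior.
    mask = input_ids.eq(2).to(hidden_states.device)
    selected = hidden_states[mask]
    print(selected.shape)  # torch.Size([2, 4])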
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -1538,7 +1538,7 @@ def forward(
)
hidden_states = outputs[0] # last hidden state

- eos_mask = input_ids.eq(self.config.eos_token_id)
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
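
Further down in these classification heads, eos_mask is used to pull out the hidden state of the final <eos> token in each sequence as the sentence representation, which is why it has to sit on the same device as hidden_states. A rough standalone sketch of that pooling step, with invented shapes and token ids:

    import torch

    batch, seq_len, hidden = 2, 3, 8
    eos_token_id = 2
    input_ids = torch.tensor([[4, 5, 2], [6, 7, 2]])
    hidden_states = torch.randn(batch, seq_len, hidden)

    # Mask of <eos> positions, aligned with hidden_states' device as in the fix.
    eos_mask = input_ids.eq(eos_token_id).to(hidden_states.device)
    if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
        raise ValueError("All examples must have the same number of <eos> tokens.")

    # Keep the hidden state at the last <eos> position of every sequence.
    sentence_representation = hidden_states[eos_mask, :].view(batch, -1, hidden)[:, -1, :]
    print(sentence_representation.shape)  # torch.Size([2, 8])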
@@ -2738,7 +2738,7 @@ def forward(
)
hidden_states = outputs[0] # last hidden state

- eos_mask = input_ids.eq(self.config.eos_token_id)
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
2 changes: 1 addition & 1 deletion src/transformers/models/bloom/modeling_bloom.py
@@ -1057,7 +1057,7 @@ def forward(
sequence_lengths = -1
else:
if input_ids is not None:
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
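
sequence_lengths is the index of the last non-padding token per example; the sequence-classification heads of these decoder-only models then read the classification logits at that position, so the index tensor has to live where logits live. A small self-contained sketch of that lookup, with an invented pad_token_id:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[11, 12, 13, 0], [21, 22, 0, 0]])
    logits = torch.randn(2, 4, 5)  # (batch, seq_len, num_labels)

    # Index of the last non-pad token per row, moved to logits' device as in the fix.
    sequence_lengths = (torch.ne(input_ids, pad_token_id).sum(-1) - 1).to(logits.device)

    # Pick one (num_labels,) vector per example at that position.
    pooled_logits = logits[torch.arange(logits.shape[0], device=logits.device), sequence_lengths]
    print(sequence_lengths.tolist(), pooled_logits.shape)  # [2, 1] torch.Size([2, 5])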
3 changes: 2 additions & 1 deletion src/transformers/models/clip/modeling_clip.py
@@ -734,7 +734,8 @@ def forward(
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
- torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]

if not return_dict:
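
The same two-index gather shows up here: the pooled text embedding is the hidden state at the end-of-text token, located with argmax because, as the comment above notes, the eot token carries the highest id in each sequence; both index tensors are now placed on last_hidden_state's device. A small illustrative sketch with invented ids:

    import torch

    # Illustrative ids; assume the end-of-text token has the highest id per row.
    input_ids = torch.tensor([[101, 7, 999, 0], [101, 42, 57, 999]])
    last_hidden_state = torch.randn(2, 4, 16)

    # The int cast mirrors the ONNX note above (argmax and int64 under opset 14);
    # both index tensors live on last_hidden_state's device.
    pooled_output = last_hidden_state[
        torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
        input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
    ]
    print(pooled_output.shape)  # torch.Size([2, 16])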
3 changes: 2 additions & 1 deletion src/transformers/models/clipseg/modeling_clipseg.py
@@ -746,7 +746,8 @@ def forward(
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
- torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]

if not return_dict:
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
@@ -1401,7 +1401,7 @@ def forward(
sequence_lengths = -1
else:
if input_ids is not None:
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -883,7 +883,7 @@ def forward(
sequence_lengths = -1
else:
if input_ids is not None:
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
2 changes: 1 addition & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -969,7 +969,7 @@ def forward(
sequence_lengths = -1
else:
if input_ids is not None:
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
3 changes: 2 additions & 1 deletion src/transformers/models/groupvit/modeling_groupvit.py
@@ -1134,7 +1134,8 @@ def forward(
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
- torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]

if not return_dict:
2 changes: 1 addition & 1 deletion src/transformers/models/led/modeling_led.py
@@ -2608,7 +2608,7 @@ def forward(
)
hidden_states = outputs[0] # last hidden state

- eos_mask = input_ids.eq(self.config.eos_token_id)
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
2 changes: 1 addition & 1 deletion src/transformers/models/mbart/modeling_mbart.py
@@ -1525,7 +1525,7 @@ def forward(
)
hidden_states = outputs[0] # last hidden state

- eos_mask = input_ids.eq(self.config.eos_token_id)
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
2 changes: 1 addition & 1 deletion src/transformers/models/mvp/modeling_mvp.py
@@ -1674,7 +1674,7 @@ def forward(
)
hidden_states = outputs[0] # last hidden state

- eos_mask = input_ids.eq(self.config.eos_token_id)
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
2 changes: 1 addition & 1 deletion src/transformers/models/opt/modeling_opt.py
@@ -1069,7 +1069,7 @@ def forward(
sequence_lengths = -1
else:
if input_ids is not None:
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
2 changes: 1 addition & 1 deletion src/transformers/models/plbart/modeling_plbart.py
@@ -1496,7 +1496,7 @@ def forward(
)
hidden_states = outputs[0] # last hidden state

- eos_mask = input_ids.eq(self.config.eos_token_id)
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
@@ -2982,7 +2982,7 @@ def forward(
)
hidden_states = outputs[0] # last hidden state

- eos_mask = input_ids.eq(self.config.eos_token_id)
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
