Skip to content

Commit

Permalink
Fix transformers.utils.fx compatibility with torch<2.0 (huggingface#28774)
Browse files Browse the repository at this point in the history

guard sdpa on torch>=2.0
  • Loading branch information
fxmarty authored Jan 30, 2024
1 parent 5c8d941 commit 6f7d5db
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/transformers/utils/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from ..utils import (
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
Expand Down Expand Up @@ -608,13 +609,17 @@ def to_concrete(t):
torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
torch.unique_consecutive: torch_unique_consecutive,
torch.nn.functional.one_hot: torch_nn_functional_one_hot,
torch.nn.functional.scaled_dot_product_attention: torch_nn_functional_scaled_dot_product_attention,
torch.nn.MSELoss: torch_nn_mseloss,
torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
operator.getitem: operator_getitem,
}

# torch.nn.functional.scaled_dot_product_attention only exists on torch>=2.0,
# so its meta-override is registered conditionally to keep older torch importable.
if is_torch_greater_or_equal_than_2_0:
    _MANUAL_META_OVERRIDES[torch.nn.functional.scaled_dot_product_attention] = (
        torch_nn_functional_scaled_dot_product_attention
    )


class HFProxy(Proxy):
"""
Expand Down

0 comments on commit 6f7d5db

Please sign in to comment.