diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py
index 66236e98645e6f..02bf2421f4d2f6 100644
--- a/src/transformers/onnx/config.py
+++ b/src/transformers/onnx/config.py
@@ -234,7 +234,7 @@ def is_torch_support_available(self) -> bool:
         if is_torch_available():
             from transformers.utils import get_torch_version
 
-            return get_torch_version() >= self.torch_onnx_minimum_version
+            return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version
         else:
             return False
 
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 796fa1b3ea6a2f..e160cd77f9a323 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -6,6 +6,7 @@
 from unittest.mock import patch
 
 import pytest
+from packaging import version
 from parameterized import parameterized
 
 from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
@@ -321,7 +322,7 @@ def _onnx_export(
         if is_torch_available():
             from transformers.utils import get_torch_version
-            if get_torch_version() < onnx_config.torch_onnx_minimum_version:
+            if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
                 pytest.skip(
                     "Skipping due to incompatible PyTorch version. Minimum required is"
                     f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
                 )
@@ -364,7 +365,7 @@ def _onnx_export_encoder_decoder_models(
         if is_torch_available():
             from transformers.utils import get_torch_version
-            if get_torch_version() < onnx_config.torch_onnx_minimum_version:
+            if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
                 pytest.skip(
                     "Skipping due to incompatible PyTorch version. Minimum required is"
                     f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
                 )
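
For context, here is a minimal sketch (not part of the diff) of why wrapping `get_torch_version()` in `version.parse` matters, assuming `get_torch_version()` returns a plain version string such as `"1.12.1"` while `torch_onnx_minimum_version` is a `packaging.version.Version`; the variable names below are hypothetical stand-ins.

```python
# Sketch only: `installed` stands in for the string assumed to be returned by
# get_torch_version(), and `minimum` mirrors torch_onnx_minimum_version
# (a packaging.version.Version).
from packaging import version

minimum = version.parse("1.8")
installed = "1.12.1"  # assumed return value of get_torch_version()

# Comparing a str against a Version is not defined and raises TypeError,
# which is what the pre-fix comparison would hit.
try:
    _ = installed >= minimum
except TypeError as err:
    print(f"str >= Version raises: {err}")

# Parsing first gives a semantically correct comparison.
print(version.parse(installed) >= minimum)            # True

# Even string-vs-string comparison would be wrong, since it is lexicographic.
print("1.10" >= "1.9")                                # False (string order)
print(version.parse("1.10") >= version.parse("1.9"))  # True (version order)
```

The same reasoning applies to the test changes: without parsing, the `<` comparison against `onnx_config.torch_onnx_minimum_version` fails before `pytest.skip` can run.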