Skip to content

Commit

Permalink
Fix issue introduced in PR huggingface#23163 (huggingface#23363)
Browse files Browse the repository at this point in the history
* fix

* fix

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored May 15, 2023
1 parent 2958b55 commit 81a73fa
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/transformers/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def is_torch_support_available(self) -> bool:
if is_torch_available():
from transformers.utils import get_torch_version

return get_torch_version() >= self.torch_onnx_minimum_version
return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version
else:
return False

Expand Down
5 changes: 3 additions & 2 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest.mock import patch

import pytest
from packaging import version
from parameterized import parameterized

from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
Expand Down Expand Up @@ -321,7 +322,7 @@ def _onnx_export(
if is_torch_available():
from transformers.utils import get_torch_version

if get_torch_version() < onnx_config.torch_onnx_minimum_version:
if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
Expand Down Expand Up @@ -364,7 +365,7 @@ def _onnx_export_encoder_decoder_models(
if is_torch_available():
from transformers.utils import get_torch_version

if get_torch_version() < onnx_config.torch_onnx_minimum_version:
if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
Expand Down

0 comments on commit 81a73fa

Please sign in to comment.