Skip to content

Commit

Permalink
check dynamic requirements by import rather than by pip package
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed May 8, 2024
1 parent 88ec967 commit 07cac23
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 25 deletions.
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
"""
from setuptools import setup

from sony_custom_layers import pinned_requirements
from sony_custom_layers import pinned_pip_requirements

extras_require = {
'torch': pinned_requirements['torch'] + pinned_requirements['torch_ort'],
'tf': pinned_requirements['tf'],
'torch': pinned_pip_requirements['torch'] + pinned_pip_requirements['torch_ort'],
'tf': pinned_pip_requirements['tf'],
}

setup(extras_require=extras_require)
6 changes: 4 additions & 2 deletions sony_custom_layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
# -----------------------------------------------------------------------------

# minimal requirements for dynamic validation in sony_custom_layers.{keras, pytorch}.__init__
requirements = {
# library names are the names that are used in import statement rather than pip package name, as a library can have
# multiple providing packages per arch, device etc.
required_libraries = {
'tf': ['tensorflow>=2.10'],
'torch': ['torch>=2.0', 'torchvision>=0.15'],
'torch_ort': ['onnx', 'onnxruntime', 'onnxruntime_extensions>=0.8.0'],
}

# pinned requirements of latest tested versions for extra_requires
pinned_requirements = {
pinned_pip_requirements = {
'tf': ['tensorflow==2.15.*'],
'torch': ['torch==2.2.*', 'torchvision==0.17.*'],
'torch_ort': ['onnx==1.15.*', 'onnxruntime==1.17.*', 'onnxruntime_extensions==0.10.*']
Expand Down
6 changes: 3 additions & 3 deletions sony_custom_layers/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.
# -----------------------------------------------------------------------------

from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements
from sony_custom_layers.util.import_util import validate_installed_libraries
from sony_custom_layers import required_libraries

validate_pip_requirements(requirements['tf'])
validate_installed_libraries(required_libraries['tf'])

from .object_detection import FasterRCNNBoxDecode, SSDPostProcess, ScoreConverter # noqa: E402
from .custom_objects import custom_layers_scope # noqa: E402
Expand Down
8 changes: 4 additions & 4 deletions sony_custom_layers/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# -----------------------------------------------------------------------------
from typing import Optional, TYPE_CHECKING

from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements
from sony_custom_layers.util.import_util import validate_installed_libraries
from sony_custom_layers import required_libraries

if TYPE_CHECKING:
import onnxruntime as ort

__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']

validate_pip_requirements(requirements['torch'])
validate_installed_libraries(required_libraries['torch'])

from .object_detection import multiclass_nms, NMSResults # noqa: E402

Expand Down Expand Up @@ -74,7 +74,7 @@ def load_custom_ops(load_ort: bool = False,
```
"""
if load_ort or ort_session_ops:
validate_pip_requirements(requirements['torch_ort'])
validate_installed_libraries(required_libraries['torch_ort'])

# trigger onnxruntime op registration
from .object_detection import nms_ort
Expand Down
38 changes: 25 additions & 13 deletions sony_custom_layers/util/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
import importlib
import warnings
from typing import List, Union

from packaging.requirements import Requirement
from packaging.version import parse
from importlib import metadata


class RequirementError(Exception):
pass


def validate_pip_requirements(requirements: List[str]):
def validate_installed_libraries(requirements: List[str]):
"""
Validate that all requirements are installed and meet the version specifications.
Validate that all required libraries are installed and meet the version specifications.
We import the required libraries and obtain the version from __version__, rather than looking at the installed
pip package, since a single library can be provided by different pip packages per arch, device, etc.
(for example 'import onnxruntime' is provided by both onnxruntime and onnxruntime-gpu packages).
Args:
requirements: a list of pip-style requirement strings
requirements (list): a list of pip-style-like requirement strings with the package name being the library name
that is used in the import statement.
Raises:
RequirementError if any required package is not installed or doesn't meet the version specification
RequirementError if any required library is not installed or doesn't meet the version specification
"""
error = ''
for req_str in requirements:
req = Requirement(req_str)
try:
installed_ver = metadata.version(req.name)
except metadata.PackageNotFoundError:
error += f'\nRequired package {req_str} is not installed'
mod = importlib.import_module(req.name)
except ImportError:
error += f"\nRequired library '{req.name}' is not installed."
continue

if parse(installed_ver) not in req.specifier:
error += f'\nRequired {req_str}, installed version {installed_ver}'
if req.specifier:
try:
installed_ver = mod.__version__
except AttributeError:
warnings.warn(f"Failed to retrieve '{req.name}' version. Continuing without compatability check.")
continue
if parse(installed_ver) not in req.specifier:
error += f"\nRequired '{req.name}' version {req.specifier}, installed version {installed_ver}."

if error:
raise RequirementError(error)

Expand All @@ -53,14 +64,15 @@ def is_compatible(requirements: Union[str, List]) -> bool:
"""
Non-raising requirement(s) check
Args:
requirements (str, List): requirement pip-style string or a list of requirement strings
requirements (str, List): a pip-style-like requirement string with the package name being the library name that
is used in the import statement, or a list of such requirement strings.
Returns:
(bool) whether requirement(s) are satisfied
"""
requirements = [requirements] if isinstance(requirements, str) else requirements
try:
validate_pip_requirements(requirements)
validate_installed_libraries(requirements)
except RequirementError:
return False
return True

0 comments on commit 07cac23

Please sign in to comment.