Skip to content

Commit

Permalink
[ROCm] Add rocm version information
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 committed Nov 25, 2024
1 parent 19a51de commit e8934b9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion jax_plugins/rocm/plugin_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
project_name = f"jax-rocm{rocm_version}-plugin"
package_name = f"jax_rocm{rocm_version}_plugin"

# Extract ROCm version from the `ROCM_PATH` environment variable.
default_rocm_path = "/opt/rocm"
rocm_path = os.getenv("ROCM_PATH", default_rocm_path)
rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown"

def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
'version', os.path.join(pkg_path, 'version.py'))
Expand All @@ -43,7 +48,7 @@ def has_ext_modules(self):
name=project_name,
version=__version__,
cmdclass=_cmdclass,
description="JAX Plugin for AMD GPUs",
description=f"JAX Plugin for AMD GPUs (ROCm:{rocm_detected_version})",
long_description="",
long_description_content_type="text/markdown",
author="Ruturaj4",
Expand Down
7 changes: 6 additions & 1 deletion jax_plugins/rocm/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
project_name = f"jax-rocm{rocm_version}-pjrt"
package_name = f"jax_plugins.xla_rocm{rocm_version}"

# Extract ROCm version from the `ROCM_PATH` environment variable.
default_rocm_path = "/opt/rocm"
rocm_path = os.getenv("ROCM_PATH", default_rocm_path)
rocm_detected_version = rocm_path.split('-')[-1] if '-' in rocm_path else "unknown"

def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
'version', os.path.join(pkg_path, 'version.py'))
Expand All @@ -41,7 +46,7 @@ def load_version_module(pkg_path):
setup(
name=project_name,
version=__version__,
description="JAX XLA PJRT Plugin for AMD GPUs",
description=f"JAX XLA PJRT Plugin for AMD GPUs (ROCm:{rocm_detected_version})",
long_description="",
long_description_content_type="text/markdown",
author="Ruturaj4",
Expand Down

0 comments on commit e8934b9

Please sign in to comment.