Skip to content

Commit

Permalink
handle ptx version too new
Browse files Browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Dec 9, 2024
1 parent 9fba2b7 commit a55f322
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 18 deletions.
14 changes: 8 additions & 6 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,12 @@ class ObjectCode:
__slots__ = ("_handle", "_backend_version", "_jit_options", "_code_type", "_module", "_loader", "_sym_map")
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")


def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
if code_type not in self._supported_code_type:
raise ValueError
_lazy_init()

# handle is assigned during _lazy_load
# handle is assigned during _lazy_load
self._handle = None
self._jit_options = jit_options

Expand All @@ -127,7 +126,7 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
self._sym_map = {} if symbol_mapping is None else symbol_mapping

# TODO: do we want to unload in a finalizer? Probably not..

def _lazy_load_module(self, *args, **kwargs):
if self._handle is not None:
return
Expand All @@ -136,7 +135,6 @@ def _lazy_load_module(self, *args, **kwargs):
# a bug that we can't easily support it just yet (NVIDIA/cuda-python#73).
if self._jit_options is not None:
raise ValueError
module = self._module.encode()
self._handle = handle_return(self._loader["file"](self._module))
else:
assert isinstance(self._module, bytes)
Expand All @@ -154,7 +152,12 @@ def _lazy_load_module(self, *args, **kwargs):
0,
)
else: # "old" backend
args = (self._module, len(self._jit_options), list(self._jit_options.keys()), list(self._jit_options.values()))
args = (
self._module,
len(self._jit_options),
list(self._jit_options.keys()),
list(self._jit_options.values()),
)
self._handle = handle_return(self._loader["data"](*args))

@precondition(_lazy_load_module)
Expand All @@ -179,6 +182,5 @@ def get_kernel(self, name):

data = handle_return(self._loader["kernel"](self._handle, name))
return Kernel._from_obj(data, self)


# TODO: implement from_handle()
10 changes: 8 additions & 2 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import sys

try:
from cuda.bindings import driver
from cuda.bindings import driver, nvrtc
except ImportError:
from cuda import cuda as driver

from cuda import nvrtc
import pytest

from cuda.core.experimental import Device, _device
Expand Down Expand Up @@ -65,3 +65,9 @@ def clean_up_cffi_files():
os.remove(f)
except FileNotFoundError:
pass # noqa: SIM105


def can_load_generated_ptx():
_, driver_ver = driver.cuDriverGetVersion()
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
return not nvrtc_major * 1000 + nvrtc_minor * 10 > driver_ver
3 changes: 2 additions & 1 deletion cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
# this software and related documentation outside the terms of the EULA
# is strictly prohibited.

import importlib

import pytest
from conftest import can_load_generated_ptx

from cuda.core.experimental import Program


@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
def test_get_kernel():
kernel = """
extern __device__ int B();
Expand Down
10 changes: 1 addition & 9 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,12 @@
# is strictly prohibited.

import pytest
from conftest import can_load_generated_ptx

from cuda import cuda, nvrtc
from cuda.core.experimental import Device, Program
from cuda.core.experimental._module import Kernel, ObjectCode


def can_load_generated_ptx():
_, driver_ver = cuda.cuDriverGetVersion()
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
if nvrtc_major * 1000 + nvrtc_minor * 10 > driver_ver:
return False
return True


def test_program_init_valid_code_type():
code = 'extern "C" __global__ void my_kernel() {}'
program = Program(code, "c++")
Expand Down

0 comments on commit a55f322

Please sign in to comment.