diff --git a/cuda_core/cuda/core/_linker.pxd b/cuda_core/cuda/core/_linker.pxd index e50ebb9770..395d70a93b 100644 --- a/cuda_core/cuda/core/_linker.pxd +++ b/cuda_core/cuda/core/_linker.pxd @@ -2,6 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 +from libcpp.vector cimport vector + +from cuda.bindings cimport cydriver + from ._resource_handles cimport NvJitLinkHandle, CuLinkHandle @@ -9,8 +13,16 @@ cdef class Linker: cdef: NvJitLinkHandle _nvjitlink_handle CuLinkHandle _culink_handle + # _drv_jit_keys/_drv_jit_values are the C arrays handed to cuLinkCreate. + # The driver retains a reference to the optionValues array for the life + # of the CUlinkState (it writes back log-size outputs into its slots), + # so these must live past cuLinkCreate and outlive cuLinkDestroy. + # Declared after _culink_handle so their C++ destructors run AFTER + # cuLinkDestroy executes during tp_dealloc. + vector[cydriver.CUjit_option] _drv_jit_keys + vector[void*] _drv_jit_values bint _use_nvjitlink - object _drv_log_bufs # formatted_options list (driver); None for nvjitlink; cleared in link() + object _drv_log_bufs # formatted_options list (driver); None for nvjitlink str _info_log # decoded log; None until link() or pre-link get_*_log() str _error_log # decoded log; None until link() or pre-link get_*_log() object _options # LinkerOptions diff --git a/cuda_core/cuda/core/_linker.pyx b/cuda_core/cuda/core/_linker.pyx index 09aa9863cd..23dd6142ef 100644 --- a/cuda_core/cuda/core/_linker.pyx +++ b/cuda_core/cuda/core/_linker.pyx @@ -39,9 +39,9 @@ from cuda.core._utils.cuda_utils import ( driver, is_sequence, ) +from cuda.core._utils.version import driver_version ctypedef const char* const_char_ptr -ctypedef void* void_ptr __all__ = ["Linker", "LinkerOptions"] @@ -181,9 +181,20 @@ cdef class Linker: class LinkerOptions: """Customizable options for configuring :class:`Linker`. 
- Since the linker may choose to use nvJitLink or the driver APIs as the linking backend, - not all options are applicable. When the system's installed nvJitLink is too old (<12.3), - or not installed, the driver APIs (cuLink) will be used instead. + Since the linker may choose either nvJitLink or the driver's ``cuLink*`` + APIs as the backend, not every option is applicable to both backends. The + backend is decided per-:class:`Linker` instance from the installed CUDA + driver major version, nvJitLink's availability and major version, the input + code types, and whether link-time optimization is requested: + + - nvJitLink is used when its major version matches the driver's, or when the driver version cannot be queried. + - The driver linker is used when nvJitLink is unavailable or too old + (<12.3), or when its major version differs from the driver's (and no LTO + step is required). + - Linking LTO IRs, or requesting ``link_time_optimization`` / ``ptx``, with + nvJitLink unavailable or with mismatched nvJitLink and driver majors is + unsupported and raises :class:`RuntimeError` at :class:`Linker` + construction time. 
Attributes ---------- @@ -348,39 +359,39 @@ class LinkerOptions: formatted_options.extend((bytearray(size), size, bytearray(size), size)) option_keys.extend( ( - _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER, - _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, - _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER, - _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER, + driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER, + driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, ) ) if self.arch is not None: arch = self.arch.split("_")[-1].upper() - formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}")) - option_keys.append(_driver.CUjit_option.CU_JIT_TARGET) + formatted_options.append(getattr(driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}")) + option_keys.append(driver.CUjit_option.CU_JIT_TARGET) if self.max_register_count is not None: formatted_options.append(self.max_register_count) - option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS) + option_keys.append(driver.CUjit_option.CU_JIT_MAX_REGISTERS) if self.time is not None: raise ValueError("time option is not supported by the driver API") if self.verbose: formatted_options.append(1) - option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE) + option_keys.append(driver.CUjit_option.CU_JIT_LOG_VERBOSE) if self.link_time_optimization: formatted_options.append(1) - option_keys.append(_driver.CUjit_option.CU_JIT_LTO) + option_keys.append(driver.CUjit_option.CU_JIT_LTO) if self.ptx: raise ValueError("ptx option is not supported by the driver API") if self.optimization_level is not None: formatted_options.append(self.optimization_level) - option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL) + option_keys.append(driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL) if self.debug: formatted_options.append(1) - option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO) 
+ option_keys.append(driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO) if self.lineinfo: formatted_options.append(1) - option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO) + option_keys.append(driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO) if self.ftz is not None: warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3) if self.prec_div is not None: @@ -402,8 +413,8 @@ class LinkerOptions: if self.split_compile_extended is not None: raise ValueError("split_compile_extended option is not supported by the driver API") if self.no_cache is True: - formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE) - option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE) + formatted_options.append(driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE) + option_keys.append(driver.CUjit_option.CU_JIT_CACHE_MODE) return formatted_options, option_keys @@ -430,7 +441,7 @@ class LinkerOptions: backend = backend.lower() if backend != "nvjitlink": raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'") - if not _use_nvjitlink_backend: + if _probe_nvjitlink() is None: raise RuntimeError("nvJitLink backend is not available") return self._prepare_nvjitlink_options(as_bytes=True) @@ -448,12 +459,30 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc cdef cydriver.CUlinkState c_raw_culink cdef Py_ssize_t c_num_opts, i cdef vector[const_char_ptr] c_str_opts - cdef vector[cydriver.CUjit_option] c_jit_keys - cdef vector[void_ptr] c_jit_values + cdef cydriver.CUjit_option* c_drv_jit_keys_ptr + cdef void** c_drv_jit_values_ptr self._options = options = check_or_create_options(LinkerOptions, options, "Linker options") - if _use_nvjitlink_backend: + # Decide the backend per-instance based on the current environment and this + # Linker's inputs. See _choose_backend() for the full decision matrix. 
+ inputs_have_ltoir = any( + getattr(code, "code_type", None) == "ltoir" for code in object_codes + ) + lto_requested = bool(options.link_time_optimization) or bool(options.ptx) + nvjitlink_version = _probe_nvjitlink() + # Query the driver version tolerantly (None on failure); it is only used to + # compare majors, so where nvJitLink is installed but the driver is absent + # (e.g., build containers) we can still select nvJitLink. + try: + driver_major = driver_version()[0] + except Exception: + driver_major = None + backend = _choose_backend( + driver_major, nvjitlink_version, inputs_have_ltoir, lto_requested + ) + + if backend == "nvjitlink": self._use_nvjitlink = True options_bytes = options._prepare_nvjitlink_options(as_bytes=True) c_num_opts = len(options_bytes) @@ -471,19 +500,32 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc # the driver writes into via raw pointers during linking operations. self._drv_log_bufs = formatted_options c_num_opts = len(option_keys) - c_jit_keys.resize(c_num_opts) - c_jit_values.resize(c_num_opts) + # Store the option key/value arrays as instance members so they outlive + # the cuLinkCreate call. CUDA driver docs require optionValues to + # remain valid for the life of the CUlinkState when output options are + # used (the driver writes log-fill sizes back into the array). The + # pxd declaration order ensures these vectors are destroyed AFTER + # _culink_handle -- i.e. after cuLinkDestroy has run. 
+ self._drv_jit_keys.resize(c_num_opts) + self._drv_jit_values.resize(c_num_opts) for i in range(c_num_opts): - c_jit_keys[i] = option_keys[i] + self._drv_jit_keys[i] = option_keys[i] val = formatted_options[i] if isinstance(val, bytearray): - c_jit_values[i] = PyByteArray_AS_STRING(val) + self._drv_jit_values[i] = PyByteArray_AS_STRING(val) else: - c_jit_values[i] = int(val) + self._drv_jit_values[i] = int(val) + # Capture the vector data() pointers before entering nogil to keep + # the nogil region free of any attribute access on self. + c_drv_jit_keys_ptr = self._drv_jit_keys.data() + c_drv_jit_values_ptr = self._drv_jit_values.data() try: with nogil: HANDLE_RETURN(cydriver.cuLinkCreate( - c_num_opts, c_jit_keys.data(), c_jit_values.data(), &c_raw_culink)) + c_num_opts, + c_drv_jit_keys_ptr, + c_drv_jit_values_ptr, + &c_raw_culink)) except CUDAError as e: Linker_annotate_error_log(self, e) raise @@ -597,11 +639,12 @@ cdef inline object Linker_link(Linker self, str target_type): raise code = (c_cubin_out)[:c_output_size] - # Linking is complete; cache the decoded log strings and release - # the driver's raw bytearray buffers (no longer written to). + # Linking is complete; cache the decoded log strings. The driver's raw + # bytearray buffers are retained for the lifetime of the CUlinkState + # because cuLinkDestroy may still dereference the log-buffer pointers + # registered via cuLinkCreate. 
self._info_log = self.get_info_log() self._error_log = self.get_error_log() - self._drv_log_bufs = None return ObjectCode._init(bytes(code), target_type, name=self._options.name) @@ -618,9 +661,10 @@ cdef inline void Linker_annotate_error_log(Linker self, object e): # ============================================================================= # TODO: revisit this treatment for py313t builds -_driver = None # populated if nvJitLink cannot be used _inited = False -_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver() +_nvjitlink_probed = False +_nvjitlink_version = None # (major, minor) if usable; None if unavailable/too old +_nvjitlink_missing_warned = False # Input type mappings populated by _lazy_init() with C-level enum ints. _nvjitlink_input_types = None @@ -632,12 +676,15 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool: return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion")) -# Note: this function is reused in the tests -def _decide_nvjitlink_or_driver() -> bool: - """Return True if falling back to the cuLink* driver APIs.""" - global _driver, _use_nvjitlink_backend - if _use_nvjitlink_backend is not None: - return not _use_nvjitlink_backend +def _probe_nvjitlink() -> tuple | None: + """Return ``(major, minor)`` if nvJitLink is available and >= 12.3, else ``None``. + + Emits a ``RuntimeWarning`` at most once when nvJitLink is unavailable or too + old. The result is cached for subsequent calls. 
+ """ + global _nvjitlink_probed, _nvjitlink_version, _nvjitlink_missing_warned + if _nvjitlink_probed: + return _nvjitlink_version warn_txt_common = ( "the driver APIs will be used instead, which do not support" @@ -649,46 +696,111 @@ def _decide_nvjitlink_or_driver() -> bool: "cuda.bindings.nvjitlink", probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load ) + warn_txt = None if nvjitlink_module is None: - warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings." - else: - from cuda.bindings._internal import nvjitlink - - if _nvjitlink_has_version_symbol(nvjitlink): - _use_nvjitlink_backend = True - return False # Use nvjitlink warn_txt = ( - f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)." - f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink." + f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings." ) + else: + from cuda.bindings._internal import nvjitlink as inner_nvjitlink + + if _nvjitlink_has_version_symbol(inner_nvjitlink): + _nvjitlink_version = tuple(nvjitlink_module.version()) + else: + warn_txt = ( + f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)." + f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink." + ) - warn(warn_txt, stacklevel=2, category=RuntimeWarning) - _use_nvjitlink_backend = False - _driver = driver - return True + if warn_txt is not None and not _nvjitlink_missing_warned: + warn(warn_txt, stacklevel=2, category=RuntimeWarning) + _nvjitlink_missing_warned = True + _nvjitlink_probed = True + return _nvjitlink_version + + +def _choose_backend( + driver_major: int | None, + nvjitlink_version: tuple | None, + inputs_have_ltoir: bool, + lto_requested: bool, +) -> str: + """Choose the linker backend for a specific Linker invocation. 
+ + Parameters + ---------- + driver_major : int or None + Major version of the installed CUDA driver (from ``cuDriverGetVersion``). + ``None`` when the driver cannot be queried (e.g., no driver installed). + nvjitlink_version : tuple[int, int] or None + ``(major, minor)`` if nvJitLink is available and >=12.3; ``None`` otherwise. + inputs_have_ltoir : bool + ``True`` if any input ``ObjectCode`` has ``code_type == "ltoir"``. + lto_requested : bool + ``True`` if ``LinkerOptions.link_time_optimization`` or ``ptx`` is set + (both force the use of nvJitLink; the driver linker cannot emit PTX and + cannot do link-time optimization on LTO IR). + + Returns + ------- + str + ``"nvjitlink"`` or ``"driver"``. + + Raises + ------ + RuntimeError + If the request cannot be satisfied by any backend, for example when + LTO IR inputs or ``link_time_optimization`` are requested but nvJitLink + is unavailable, or when driver and nvJitLink have mismatched major + versions for an LTO link. + """ + needs_nvjitlink = inputs_have_ltoir or lto_requested + + if nvjitlink_version is None: + if needs_nvjitlink: + raise RuntimeError( + "LTO IR input or link-time optimization was requested, but " + "nvJitLink is not available (driver linker cannot perform LTO). " + "Install cuda-bindings with a compatible nvJitLink (>=12.3)." + ) + return "driver" + + nvjitlink_major = nvjitlink_version[0] + # If driver version is unknown, optimistically use nvJitLink + # (common in build containers with nvJitLink but no driver). + if driver_major is None or nvjitlink_major == driver_major: + return "nvjitlink" + + if needs_nvjitlink: + raise RuntimeError( + f"Cannot link with nvJitLink {nvjitlink_major}.x against CUDA driver " + f"{driver_major}.x: LTO IR or link-time optimization requires matching " + f"major versions, and the driver linker cannot perform LTO. " + f"Install an nvJitLink matching the driver major version." + ) + # Driver and nvJitLink have different major versions. 
nvJitLink output may + # target an architecture or format that the driver cannot load, so fall back + # to the driver's own linker for non-LTO linking. + return "driver" def _lazy_init(): global _inited, _nvjitlink_input_types, _driver_input_types if _inited: return - - _decide_nvjitlink_or_driver() - if _use_nvjitlink_backend: - _nvjitlink_input_types = { - "ptx": cynvjitlink.NVJITLINK_INPUT_PTX, - "cubin": cynvjitlink.NVJITLINK_INPUT_CUBIN, - "fatbin": cynvjitlink.NVJITLINK_INPUT_FATBIN, - "ltoir": cynvjitlink.NVJITLINK_INPUT_LTOIR, - "object": cynvjitlink.NVJITLINK_INPUT_OBJECT, - "library": cynvjitlink.NVJITLINK_INPUT_LIBRARY, - } - else: - _driver_input_types = { - "ptx": cydriver.CU_JIT_INPUT_PTX, - "cubin": cydriver.CU_JIT_INPUT_CUBIN, - "fatbin": cydriver.CU_JIT_INPUT_FATBINARY, - "object": cydriver.CU_JIT_INPUT_OBJECT, - "library": cydriver.CU_JIT_INPUT_LIBRARY, - } + _nvjitlink_input_types = { + "ptx": cynvjitlink.NVJITLINK_INPUT_PTX, + "cubin": cynvjitlink.NVJITLINK_INPUT_CUBIN, + "fatbin": cynvjitlink.NVJITLINK_INPUT_FATBIN, + "ltoir": cynvjitlink.NVJITLINK_INPUT_LTOIR, + "object": cynvjitlink.NVJITLINK_INPUT_OBJECT, + "library": cynvjitlink.NVJITLINK_INPUT_LIBRARY, + } + _driver_input_types = { + "ptx": cydriver.CU_JIT_INPUT_PTX, + "cubin": cydriver.CU_JIT_INPUT_CUBIN, + "fatbin": cydriver.CU_JIT_INPUT_FATBINARY, + "object": cydriver.CU_JIT_INPUT_OBJECT, + "library": cydriver.CU_JIT_INPUT_LIBRARY, + } _inited = True diff --git a/cuda_core/tests/test_linker.py b/cuda_core/tests/test_linker.py index 0d4ff91dcd..c219d4dd94 100644 --- a/cuda_core/tests/test_linker.py +++ b/cuda_core/tests/test_linker.py @@ -7,6 +7,7 @@ from cuda.core import Device, Linker, LinkerOptions, Program, ProgramOptions, _linker from cuda.core._module import ObjectCode from cuda.core._utils.cuda_utils import CUDAError +from cuda.core._utils.version import driver_version ARCH = "sm_" + "".join(f"{i}" for i in Device().compute_capability) @@ -18,7 +19,22 @@ device_function_b = 
"__device__ int B() { return 0; }" device_function_c = "__device__ int C(int a, int b) { return a + b; }" -is_culink_backend = _linker._decide_nvjitlink_or_driver() + +def _current_env_backend() -> str: + """Return the backend a default (PTX input, no LTO) Linker picks on this machine.""" + try: + drv_major = driver_version()[0] + except Exception: + drv_major = None + return _linker._choose_backend( + drv_major, + _linker._probe_nvjitlink(), + inputs_have_ltoir=False, + lto_requested=False, + ) + + +is_culink_backend = _current_env_backend() == "driver" if not is_culink_backend: from cuda.bindings import nvjitlink @@ -96,7 +112,11 @@ def test_linker_init(compile_ptx_functions, options): def test_linker_init_invalid_arch(compile_ptx_functions): - err = AttributeError if is_culink_backend else nvjitlink.nvJitLinkError + # With the driver backend, ptx=True (which implies link-time optimization) + # cannot be satisfied at all, so dispatch raises RuntimeError before the + # arch string is ever parsed. With the nvJitLink backend, the arch string + # is validated by nvJitLink itself. 
+ err = RuntimeError if is_culink_backend else nvjitlink.nvJitLinkError with pytest.raises(err): options = LinkerOptions(arch="99", ptx=True) Linker(*compile_ptx_functions, options=options) @@ -205,7 +225,7 @@ def test_linker_options_as_bytes_invalid_backend(): def test_linker_options_as_bytes_driver_not_supported(): """Test that as_bytes() is not supported for driver backend""" options = LinkerOptions(arch="sm_80") - with pytest.raises(RuntimeError, match="as_bytes\\(\\) only supports 'nvjitlink' backend"): + with pytest.raises(ValueError, match="as_bytes\\(\\) only supports 'nvjitlink' backend"): options.as_bytes("driver") @@ -242,3 +262,65 @@ def test_linker_options_nvjitlink_options_as_str(): assert f"-arch={ARCH}" in options assert "-g" in options assert "-lineinfo" in options + + +# --------------------------------------------------------------------------- +# Per-instance dispatch tests +# +# The full _choose_backend() decision matrix lives in test_linker_dispatch.py as +# GPU-free unit tests. The tests below drive the same dispatch logic through the +# real Linker constructor (with patched version probes) to confirm that the +# dispatch is invoked before any backend handle is created. +# --------------------------------------------------------------------------- + + +class TestLinkerDispatch: + """Per-instance dispatch exercised by constructing a Linker with patched version probes. + + These tests intercept both :func:`driver_version` (via the name imported into + ``_linker``) and :func:`_probe_nvjitlink` so the decision is deterministic, + then assert that ``Linker.__init__`` raises before creating any backend handle + for the unsatisfiable cases. + """ + + @pytest.fixture + def ltoir_object(self): + # A minimal ObjectCode marked as ltoir is sufficient: _choose_backend runs + # before any backend handle is created, so the payload never reaches the + # linker libraries. 
+ return ObjectCode._init(b"\x00stub-ltoir-payload", "ltoir") + + def test_ltoir_without_nvjitlink_raises(self, monkeypatch, ltoir_object): + monkeypatch.setattr(_linker, "driver_version", lambda: (12, 9, 0)) + monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: None) + with pytest.raises(RuntimeError, match="nvJitLink is not available"): + Linker(ltoir_object, options=LinkerOptions(arch=ARCH)) + + def test_cross_major_with_ltoir_raises(self, monkeypatch, ltoir_object): + monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0)) + monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: (12, 9)) + with pytest.raises(RuntimeError, match="matching major versions"): + Linker(ltoir_object, options=LinkerOptions(arch=ARCH)) + + @pytest.fixture + def ptx_object(self): + # Stub PTX payload; dispatch raises before the bytes reach any backend. + return ObjectCode._init(b"// stub ptx\n", "ptx") + + def test_cross_major_with_lto_option_raises(self, monkeypatch, ptx_object): + monkeypatch.setattr(_linker, "driver_version", lambda: (12, 9, 0)) + monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: (13, 0)) + with pytest.raises(RuntimeError, match="matching major versions"): + Linker( + ptx_object, + options=LinkerOptions(arch=ARCH, link_time_optimization=True), + ) + + def test_lto_without_nvjitlink_raises(self, monkeypatch, ptx_object): + monkeypatch.setattr(_linker, "driver_version", lambda: (12, 9, 0)) + monkeypatch.setattr(_linker, "_probe_nvjitlink", lambda: None) + with pytest.raises(RuntimeError, match="nvJitLink is not available"): + Linker( + ptx_object, + options=LinkerOptions(arch=ARCH, link_time_optimization=True), + ) diff --git a/cuda_core/tests/test_linker_dispatch.py b/cuda_core/tests/test_linker_dispatch.py new file mode 100644 index 0000000000..d3c4f1b5e4 --- /dev/null +++ b/cuda_core/tests/test_linker_dispatch.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Linker backend dispatch logic. + +These cover the full decision matrix of :func:`cuda.core._linker._choose_backend` +with mocked version and availability inputs, so they run without a GPU, without +a specific nvJitLink version installed, and without a CUDA driver. +""" + +import pytest + +from cuda.core import _linker + + +class TestChooseBackend: + """Parametrized unit tests for :func:`cuda.core._linker._choose_backend`. + + The decision matrix axes are: + + * ``driver_major``: CUDA driver major version. + * ``nvjitlink_version``: ``(major, minor)`` tuple or ``None`` if nvJitLink is + unavailable / too old. + * ``inputs_have_ltoir``: any input ``ObjectCode`` has ``code_type="ltoir"``. + * ``lto_requested``: ``LinkerOptions.link_time_optimization`` or ``ptx``. + """ + + @pytest.mark.parametrize( + ("driver_major", "nvjitlink_version", "has_ltoir", "lto_requested", "expected"), + [ + # No nvJitLink available + no LTO needed -> driver. + (12, None, False, False, "driver"), + (13, None, False, False, "driver"), + # Matching driver/nvJitLink majors -> always nvJitLink. + (12, (12, 3), False, False, "nvjitlink"), + (12, (12, 9), True, True, "nvjitlink"), + (12, (12, 9), True, False, "nvjitlink"), + (12, (12, 9), False, True, "nvjitlink"), + (13, (13, 0), False, False, "nvjitlink"), + (13, (13, 0), True, True, "nvjitlink"), + # Cross-major, no LTO requirement -> driver fallback. + (13, (12, 9), False, False, "driver"), + (12, (13, 0), False, False, "driver"), + # Unknown driver (e.g., build containers) optimistically picks nvJitLink when available. + (None, (12, 9), False, False, "nvjitlink"), + (None, (12, 9), True, False, "nvjitlink"), + (None, (12, 9), False, True, "nvjitlink"), + (None, (13, 0), True, True, "nvjitlink"), + # Unknown driver + no nvJitLink + no LTO -> driver (will fail at use-time, not dispatch). 
+ (None, None, False, False, "driver"), + ], + ) + def test_returns_expected_backend(self, driver_major, nvjitlink_version, has_ltoir, lto_requested, expected): + assert _linker._choose_backend(driver_major, nvjitlink_version, has_ltoir, lto_requested) == expected + + @pytest.mark.parametrize( + ("driver_major", "nvjitlink_version", "has_ltoir", "lto_requested", "match"), + [ + # No nvJitLink + LTO IR input -> cannot satisfy. + (12, None, True, False, "nvJitLink is not available"), + (13, None, True, True, "nvJitLink is not available"), + # No nvJitLink + link_time_optimization requested. + (12, None, False, True, "nvJitLink is not available"), + # Cross-major + LTO IR input. + (13, (12, 9), True, False, "matching major versions"), + (12, (13, 0), True, True, "matching major versions"), + # Cross-major + link_time_optimization requested (no ltoir input). + (13, (12, 9), False, True, "matching major versions"), + (12, (13, 0), False, True, "matching major versions"), + # Unknown driver + no nvJitLink + LTO needs cannot be satisfied. 
+ (None, None, True, False, "nvJitLink is not available"), + (None, None, False, True, "nvJitLink is not available"), + ], + ) + def test_raises_when_unsatisfiable(self, driver_major, nvjitlink_version, has_ltoir, lto_requested, match): + with pytest.raises(RuntimeError, match=match): + _linker._choose_backend(driver_major, nvjitlink_version, has_ltoir, lto_requested) diff --git a/cuda_core/tests/test_optional_dependency_imports.py b/cuda_core/tests/test_optional_dependency_imports.py index 02edcc9839..59b1506e5b 100644 --- a/cuda_core/tests/test_optional_dependency_imports.py +++ b/cuda_core/tests/test_optional_dependency_imports.py @@ -11,23 +11,23 @@ def restore_optional_import_state(): saved_nvvm_module = _program._nvvm_module saved_nvvm_attempted = _program._nvvm_import_attempted - saved_driver = _linker._driver - saved_inited = _linker._inited - saved_use_nvjitlink = _linker._use_nvjitlink_backend + saved_probed = _linker._nvjitlink_probed + saved_version = _linker._nvjitlink_version + saved_warned = _linker._nvjitlink_missing_warned _program._nvvm_module = None _program._nvvm_import_attempted = False - _linker._driver = None - _linker._inited = False - _linker._use_nvjitlink_backend = None + _linker._nvjitlink_probed = False + _linker._nvjitlink_version = None + _linker._nvjitlink_missing_warned = False yield _program._nvvm_module = saved_nvvm_module _program._nvvm_import_attempted = saved_nvvm_attempted - _linker._driver = saved_driver - _linker._inited = saved_inited - _linker._use_nvjitlink_backend = saved_use_nvjitlink + _linker._nvjitlink_probed = saved_probed + _linker._nvjitlink_version = saved_version + _linker._nvjitlink_missing_warned = saved_warned def test_get_nvvm_module_reraises_nested_module_not_found(monkeypatch): @@ -75,7 +75,7 @@ def fake__optional_cuda_import(modname, probe_function=None): _program._get_nvvm_module() -def test_decide_nvjitlink_or_driver_reraises_nested_module_not_found(monkeypatch): +def 
test_probe_nvjitlink_reraises_nested_module_not_found(monkeypatch): def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvjitlink" assert probe_function is not None @@ -86,11 +86,11 @@ def fake__optional_cuda_import(modname, probe_function=None): monkeypatch.setattr(_linker, "_optional_cuda_import", fake__optional_cuda_import) with pytest.raises(ModuleNotFoundError, match="not_a_real_dependency") as excinfo: - _linker._decide_nvjitlink_or_driver() + _linker._probe_nvjitlink() assert excinfo.value.name == "not_a_real_dependency" -def test_decide_nvjitlink_or_driver_falls_back_when_module_missing(monkeypatch): +def test_probe_nvjitlink_warns_and_returns_none_when_module_missing(monkeypatch): def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvjitlink" assert probe_function is not None @@ -99,7 +99,8 @@ def fake__optional_cuda_import(modname, probe_function=None): monkeypatch.setattr(_linker, "_optional_cuda_import", fake__optional_cuda_import) with pytest.warns(RuntimeWarning, match="cuda.bindings.nvjitlink is not available"): - use_driver_backend = _linker._decide_nvjitlink_or_driver() + probe_result = _linker._probe_nvjitlink() - assert use_driver_backend is True - assert _linker._use_nvjitlink_backend is False + assert probe_result is None + assert _linker._nvjitlink_version is None + assert _linker._nvjitlink_probed is True diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 992ce33655..e8063ab815 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -13,10 +13,26 @@ from cuda.core._module import Kernel, ObjectCode from cuda.core._program import Program, ProgramOptions from cuda.core._utils.cuda_utils import CUDAError, handle_return +from cuda.core._utils.version import driver_version pytest_plugins = ("cuda_python_test_helpers.nvvm_bitcode",) -is_culink_backend = _linker._decide_nvjitlink_or_driver() + +def 
_default_linker_backend() -> str: + """Backend a default (PTX input, no LTO) Linker picks on this machine.""" + try: + drv_major = driver_version()[0] + except Exception: + drv_major = None + return _linker._choose_backend( + drv_major, + _linker._probe_nvjitlink(), + inputs_have_ltoir=False, + lto_requested=False, + ) + + +is_culink_backend = _default_linker_backend() == "driver" def _is_nvvm_available():