diff --git a/tests/cpp/util/test_nvrtc.cpp b/tests/cpp/util/test_nvrtc.cpp index d41084449e..dab945ecf0 100644 --- a/tests/cpp/util/test_nvrtc.cpp +++ b/tests/cpp/util/test_nvrtc.cpp @@ -9,10 +9,19 @@ #include +#include "util/cuda_runtime.h" #include "util/rtc.h" using namespace transformer_engine; +TEST(UtilTest, CUDAHeaders) { + if (!rtc::is_enabled()) { + GTEST_SKIP() << "NVRTC not enabled, skipping tests"; + } + EXPECT_FALSE(cuda::include_directory().empty()); + EXPECT_EQ(cuda::include_directory_version(), CUDART_VERSION); +} + TEST(UtilTest, NVRTC) { if (!rtc::is_enabled()) { GTEST_SKIP() << "NVRTC not enabled, skipping tests"; diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 706c237ccc..4eb24ec62a 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -2,15 +2,11 @@ global: extern "C++" { nvte_*; - transformer_engine::cuda::sm_count*; - transformer_engine::cuda::sm_arch*; - transformer_engine::cuda::supports_multicast*; - transformer_engine::cuda::stream_priority_range*; - transformer_engine::cuda::current_device*; + transformer_engine::cuda::*; transformer_engine::cuda_driver::get_symbol*; transformer_engine::cuda_driver::ensure_context_exists*; transformer_engine::ubuf_built_with_mpi*; - *transformer_engine::rtc*; + *transformer_engine::rtc::*; transformer_engine::nvte_cudnn_handle_init*; transformer_engine::nvte_cublas_handle_init*; transformer_engine::typeToSize*; diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 4b43940a51..504d761bb1 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include "../common.h" @@ -202,6 +203,49 @@ const std::string &include_directory(bool required) { return path; } +int include_directory_version(bool required) { + // Header path + const auto &include_dir = cuda::include_directory(false); + if (include_dir.empty()) { + if (required) { + NVTE_ERROR( + "Could not detect version of CUDA Toolkit headers " + "(CUDA Toolkit headers not found)."); + } + return -1; + } + + // Parse CUDART_VERSION from cuda_runtime_api.h. + const auto header_path = std::filesystem::path(include_dir) / "cuda_runtime_api.h"; + std::ifstream header_file(header_path); + if (header_file.is_open()) { + const std::string define_prefix = "#define CUDART_VERSION "; + std::string line; + while (std::getline(header_file, line)) { + const auto pos = line.find(define_prefix); + if (pos == std::string::npos) { + continue; + } + try { + const int version = std::stoi(line.substr(pos + define_prefix.size())); + if (version > 0) { + return version; + } + } catch (...) { + continue; + } + } + } + + if (required) { + NVTE_ERROR( + "Could not detect version of CUDA Toolkit headers " + "(Could not parse CUDART_VERSION from ", + header_path.string(), ")."); + } + return -1; +} + int cudart_version() { auto get_version = []() -> int { int version; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index f0aa239622..0f35594001 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -67,6 +67,21 @@ bool supports_multicast(int device_id = -1); */ const std::string &include_directory(bool required = false); +/* \brief Version number of CUDA Toolkit headers + * + * The headers are accessed at run-time and its CUDA version may + * differ from compile-time and from the CUDA Runtime. The header path + * can be configured by setting NVTE_CUDA_INCLUDE_DIR in the + * environment (default is to search in common install paths). + * + * \param[in] required Whether to throw exception if headers are not + * found or if version cannot be determined. + * + * \return CUDA version encoded as major * 1000 + minor * 10, or -1 if + * it could not be determined. + */ +int include_directory_version(bool required = false); + /* \brief CUDA Runtime version number at run-time * * Versions may differ between compile-time and run-time. diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index 7925fdceea..70024a202c 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -12,6 +12,7 @@ #include "../common.h" #include "../util/cuda_driver.h" +#include "../util/cuda_runtime.h" #include "../util/string.h" #include "../util/system.h" @@ -175,14 +176,46 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& const nvrtcResult compile_result = nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data()); if (compile_result != NVRTC_SUCCESS) { - // Display log if compilation failed - std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n"); + std::string log; + + // Decode CUDA version number to "major.minor" string + auto version_string = [](int v) -> std::string { + if (v < 0) { + return ""; + } + return concat_strings(v / 1000, ".", (v % 1000) / 10); + }; + + // Check CUDA versions + const int build_version = CUDA_VERSION; + int nvrtc_version = -1; + int nvrtc_version_major = 0, nvrtc_version_minor = 0; + if (nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor) == NVRTC_SUCCESS) { + nvrtc_version = nvrtc_version_major * 1000 + nvrtc_version_minor * 10; + } + const int header_version = cuda::include_directory_version(); + log += concat_strings("Compile-time CUDA version: ", version_string(build_version), "\n", + "Run-time NVRTC version: ", version_string(nvrtc_version), "\n", + "Run-time CUDA headers version: ", version_string(header_version), "\n"); + if (nvrtc_version != header_version) { + log += concat_strings( + "\nWarning: CUDA versions do not match between NVRTC and CUDA headers (", + cuda::include_directory(), + "). " + "Consider changing the CUDA header search path (by setting NVTE_CUDA_INCLUDE_DIR) " + "or the linked CUDA Runtime (by setting CUDA_HOME or LD_LIBRARY_PATH).\n\n"); + } + + // Get build log + log += concat_strings("NVRTC compilation log for ", filename, ":\n"); const size_t log_offset = log.size(); size_t log_size; NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size)); log.resize(log_offset + log_size); NVTE_CHECK_NVRTC(nvrtcGetProgramLog(program, &log[log_offset])); log.back() = '\n'; + + // Display log and throw error std::cerr << log; NVTE_CHECK_NVRTC(compile_result); }