Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/cpp/util/test_nvrtc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@

#include <gtest/gtest.h>

#include "util/cuda_runtime.h"
#include "util/rtc.h"

using namespace transformer_engine;

TEST(UtilTest, CUDAHeaders) {
if (!rtc::is_enabled()) {
GTEST_SKIP() << "NVRTC not enabled, skipping tests";
}
EXPECT_FALSE(cuda::include_directory().empty());
EXPECT_EQ(cuda::include_directory_version(), CUDART_VERSION);
}

TEST(UtilTest, NVRTC) {
if (!rtc::is_enabled()) {
GTEST_SKIP() << "NVRTC not enabled, skipping tests";
Expand Down
8 changes: 2 additions & 6 deletions transformer_engine/common/libtransformer_engine.version
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
global:
extern "C++" {
nvte_*;
transformer_engine::cuda::sm_count*;
transformer_engine::cuda::sm_arch*;
transformer_engine::cuda::supports_multicast*;
transformer_engine::cuda::stream_priority_range*;
transformer_engine::cuda::current_device*;
transformer_engine::cuda::*;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Wildcard export broadens ABI surface

Changing the explicit list of transformer_engine::cuda::sm_count*, sm_arch*, etc. to the blanket transformer_engine::cuda::* will export every symbol in the cuda namespace — including any future internal helpers — as part of the library's public ABI. This may be intentional to avoid updating the linker script whenever a new function is added, but it is worth confirming the trade-off, since once a symbol is exported it must be maintained for ABI compatibility.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, we shouldn't do that. Also, why do we need to modify this file in this PR at all? The newly added function is not part of the public API anyway.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right — the new include_directory_version() function is only used internally (in rtc.cpp for error reporting and in the test suite). Since the tests link against the same library, the symbol doesn't need to be exported in the version script.

The wildcard change appears to be unnecessary. You can revert the modifications to libtransformer_engine.version entirely — the explicit symbol list should work as-is without any changes for this PR.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to expose the utility functions if we want to test them individually in the C++ unit tests.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, this actually shows a separate issue (that we should probably tackle in a separate PR and merge this one without the changes to the build and without the unit test) that the C++ tests as they are currently do not have the visibility into the internal functions. On one hand that is good, since it makes it possible to test that the API surface is correct and that we can do everything we want by just calling the right exposed functions, but on the other it makes it impossible to have targeted tests for those internal implementations without exposing them to the world. I think we should rethink this relationship to have either one or the other as the target of the C++ unit tests (and I would argue that the framework level integration already tackles the API exposure testing so we could make the C++ tests more coupled with the internals of the library itself).

transformer_engine::cuda_driver::get_symbol*;
transformer_engine::cuda_driver::ensure_context_exists*;
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
*transformer_engine::rtc::*;
transformer_engine::nvte_cudnn_handle_init*;
transformer_engine::nvte_cublas_handle_init*;
transformer_engine::typeToSize*;
Expand Down
44 changes: 44 additions & 0 deletions transformer_engine/common/util/cuda_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cublasLt.h>

#include <filesystem>
#include <fstream>
#include <mutex>

#include "../common.h"
Expand Down Expand Up @@ -202,6 +203,49 @@ const std::string &include_directory(bool required) {
return path;
}

int include_directory_version(bool required) {
// Header path
const auto &include_dir = cuda::include_directory(false);
if (include_dir.empty()) {
if (required) {
NVTE_ERROR(
"Could not detect version of CUDA Toolkit headers "
"(CUDA Toolkit headers not found).");
}
return -1;
}

// Parse CUDART_VERSION from cuda_runtime_api.h.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like how we parse the header as a text file. However, when I tried compiling a test program with NVRTC it would override the header's CUDART_VERSION macro with the CUDA Runtime version.

const auto header_path = std::filesystem::path(include_dir) / "cuda_runtime_api.h";
std::ifstream header_file(header_path);
if (header_file.is_open()) {
const std::string define_prefix = "#define CUDART_VERSION ";
std::string line;
while (std::getline(header_file, line)) {
const auto pos = line.find(define_prefix);
if (pos == std::string::npos) {
continue;
}
try {
const int version = std::stoi(line.substr(pos + define_prefix.size()));
if (version > 0) {
return version;
}
} catch (...) {
continue;
}
}
}

if (required) {
NVTE_ERROR(
"Could not detect version of CUDA Toolkit headers "
"(Could not parse CUDART_VERSION from ",
header_path.string(), ").");
}
return -1;
}
Comment on lines +206 to +247
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Result not cached, unlike sibling functions

include_directory_version() re-opens and re-parses cuda_runtime_api.h on every call. The sibling functions cudart_version() and cublas_version() in the same file both cache their result with a static local variable. While the function is currently only exercised on the failure path in rtc.cpp, caching the result would be consistent with the established pattern and avoids repeated I/O if the function is ever called more broadly (e.g. from tests):

int include_directory_version(bool required) {
  static int cached = [&]() -> int {
    // ... existing parsing logic ...
  }();
  // handle `required` separately for the error message
  ...
}

Alternatively, add a static int cached_version = -1; if (cached_version >= 0) return cached_version; guard at the top.


int cudart_version() {
auto get_version = []() -> int {
int version;
Expand Down
15 changes: 15 additions & 0 deletions transformer_engine/common/util/cuda_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ bool supports_multicast(int device_id = -1);
*/
const std::string &include_directory(bool required = false);

/* \brief Version number of CUDA Toolkit headers
*
* The headers are accessed at run-time and its CUDA version may
* differ from compile-time and from the CUDA Runtime. The header path
* can be configured by setting NVTE_CUDA_INCLUDE_DIR in the
* environment (default is to search in common install paths).
*
* \param[in] required Whether to throw exception if headers are not
* found or if version cannot be determined.
*
* \return CUDA version encoded as major * 1000 + minor * 10, or -1 if
* it could not be determined.
*/
int include_directory_version(bool required = false);

/* \brief CUDA Runtime version number at run-time
*
* Versions may differ between compile-time and run-time.
Expand Down
37 changes: 35 additions & 2 deletions transformer_engine/common/util/rtc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
#include "../util/string.h"
#include "../util/system.h"

Expand Down Expand Up @@ -175,14 +176,46 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
const nvrtcResult compile_result =
nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data());
if (compile_result != NVRTC_SUCCESS) {
// Display log if compilation failed
std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n");
std::string log;

// Decode CUDA version number to "major.minor" string
auto version_string = [](int v) -> std::string {
if (v < 0) {
return "<not found>";
}
return concat_strings(v / 1000, ".", (v % 1000) / 10);
};

// Check CUDA versions
const int build_version = CUDA_VERSION;
int nvrtc_version = -1;
int nvrtc_version_major = 0, nvrtc_version_minor = 0;
if (nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor) == NVRTC_SUCCESS) {
nvrtc_version = nvrtc_version_major * 1000 + nvrtc_version_minor * 10;
}
const int header_version = cuda::include_directory_version();
log += concat_strings("Compile-time CUDA version: ", version_string(build_version), "\n",
"Run-time NVRTC version: ", version_string(nvrtc_version), "\n",
"Run-time CUDA headers version: ", version_string(header_version), "\n");
if (nvrtc_version != header_version) {
log += concat_strings(
"\nWarning: CUDA versions do not match between NVRTC and CUDA headers (",
cuda::include_directory(),
"). "
"Consider changing the CUDA header search path (by setting NVTE_CUDA_INCLUDE_DIR) "
"or the linked CUDA Runtime (by setting CUDA_HOME or LD_LIBRARY_PATH).\n\n");
}

// Get build log
log += concat_strings("NVRTC compilation log for ", filename, ":\n");
const size_t log_offset = log.size();
size_t log_size;
NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size));
log.resize(log_offset + log_size);
NVTE_CHECK_NVRTC(nvrtcGetProgramLog(program, &log[log_offset]));
log.back() = '\n';

// Display log and throw error
std::cerr << log;
NVTE_CHECK_NVRTC(compile_result);
}
Expand Down
Loading