Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
size_t gran;
CUmulticastObjectProp mcProp = {};
mcProp.numDevices = (*comm)->ar2_nvsize;
mcProp.size = (*comm)->mc_maxsize;
mcProp.size = mc_maxsize;
mcProp.handleTypes =
mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

Expand Down Expand Up @@ -323,8 +323,9 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
IPCCHECK(ipcSocketClose(&ipcSock));
close(fd);
}
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle,
(CUdeviceptr)(*comm)->mydev);
CUdevice cudev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, (*comm)->mydev);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle, cudev);

CUdeviceptr mc_va;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &mc_va, mc_maxsize, (size_t)0, (CUdeviceptr)0U,
Expand Down Expand Up @@ -692,18 +693,19 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
cudaDeviceProp deviceProp;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device));
bool peer_access_available = false;
bool all_peers_accessible = true;
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
int can_access_peer;
cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i);
if (peer_result == cudaSuccess && can_access_peer) {
peer_access_available = true;
if (peer_result != cudaSuccess || !can_access_peer) {
all_peers_accessible = false;
break;
}
}
}
if (!peer_access_available) {

if (!all_peers_accessible) {
NVTE_ERROR(
"No peer-to-peer access available between GPUs. This platform does not support the "
"GPU-to-GPU "
Expand Down
26 changes: 25 additions & 1 deletion transformer_engine/common/util/cuda_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ bool supports_multicast(int device_id) {
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&]() {
// Multicast requires Hopper (SM 9.0) or newer
if (sm_arch(device_id) < 90) {
cache[device_id] = false;
return;
}
CUdevice cudev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
// Multicast support requires both CUDA12.1 UMD + KMD
Expand All @@ -128,7 +133,26 @@ bool supports_multicast(int device_id) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
}
cache[device_id] = static_cast<bool>(result);
if (!result) {
cache[device_id] = false;
return;
}

// Verify NVLink/NVSwitch topology by testing multicast granularity query
// This will fail if NVLink is not properly configured or devices are not in the same domain
CUmulticastObjectProp testProp = {};
testProp.numDevices = 1;
testProp.size = 4096; // 4KB test size
testProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
size_t gran;
CUresult gran_result = cuda_driver::call("cuMulticastGetGranularity", &gran, &testProp,
CU_MULTICAST_GRANULARITY_RECOMMENDED);
if (gran_result != CUDA_SUCCESS) {
cache[device_id] = false;
return;
}

cache[device_id] = true;
};
std::call_once(flags[device_id], init);
return cache[device_id];
Expand Down
Loading