diff --git a/src/common/cuda_dr_utils.cc b/src/common/cuda_dr_utils.cc index 59dd936eb685..13f2516d408f 100644 --- a/src/common/cuda_dr_utils.cc +++ b/src/common/cuda_dr_utils.cc @@ -4,12 +4,13 @@ #if defined(XGBOOST_USE_CUDA) #include "cuda_dr_utils.h" -#include // for int32_t -#include // for memset -#include // for make_unique -#include // for call_once -#include // for stringstream -#include // for string +#include // for max +#include // for int32_t +#include // for memset +#include // for make_unique +#include // for call_once +#include // for stringstream +#include // for string #include "common.h" // for safe_cuda #include "cuda_rt_utils.h" // for CurrentDevice @@ -78,7 +79,7 @@ void CuDriverApi::ThrowIfError(CUresult status, StringView fn, std::int32_t line return *cu; } -void GetCuLocation(CUmemLocationType type, CUmemLocation *loc) { +void MakeCuMemLocation(CUmemLocationType type, CUmemLocation *loc) { auto ordinal = curt::CurrentDevice(); loc->type = type; @@ -100,7 +101,7 @@ void GetCuLocation(CUmemLocationType type, CUmemLocation *loc) { CUmemAllocationProp prop; std::memset(&prop, '\0', sizeof(prop)); prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - GetCuLocation(type, &prop.location); + MakeCuMemLocation(type, &prop.location); return prop; } } // namespace xgboost::cudr diff --git a/src/common/cuda_dr_utils.h b/src/common/cuda_dr_utils.h index 7dcdd2ab46a1..9fc8ecc8d42c 100644 --- a/src/common/cuda_dr_utils.h +++ b/src/common/cuda_dr_utils.h @@ -3,8 +3,8 @@ * * @brief Utility for CUDA driver API. * - * We don't link with libcuda.so at build time. The utilities here load the shared object - * at runtime. + * XGBoost doesn't link libcuda.so at build time. The utilities here load the shared + * object at runtime. */ #pragma once @@ -22,6 +22,7 @@ namespace xgboost::cudr { struct CuDriverApi { using Flags = unsigned long long; // NOLINT + // Memory manipulation functions. 
using MemGetAllocationGranularityFn = CUresult(size_t *granularity, const CUmemAllocationProp *prop, CUmemAllocationGranularity_flags option); @@ -36,10 +37,10 @@ struct CuDriverApi { using MemUnmapFn = CUresult(CUdeviceptr ptr, size_t size); using MemReleaseFn = CUresult(CUmemGenericAllocationHandle handle); using MemAddressFreeFn = CUresult(CUdeviceptr ptr, size_t size); - + // Error handling using GetErrorString = CUresult(CUresult error, const char **pStr); using GetErrorName = CUresult(CUresult error, const char **pStr); - + // Device attributes using DeviceGetAttribute = CUresult(int *pi, CUdevice_attribute attrib, CUdevice dev); using DeviceGet = CUresult(CUdevice *device, int ordinal); @@ -70,6 +71,9 @@ struct CuDriverApi { [[nodiscard]] CuDriverApi &GetGlobalCuDriverApi(); +/** + * @brief Macro for guarding CUDA driver API calls. + */ #define safe_cu(call) \ do { \ auto __status = (call); \ @@ -86,8 +90,13 @@ inline auto GetAllocGranularity(CUmemAllocationProp const *prop) { return granularity; } -void GetCuLocation(CUmemLocationType type, CUmemLocation* loc); +/** + * @brief Obtain appropriate device ordinal for `CUmemLocation`. + */ +void MakeCuMemLocation(CUmemLocationType type, CUmemLocation* loc); -// Describe the allocation property +/** + * @brief Construct a `CUmemAllocationProp`. + */ [[nodiscard]] CUmemAllocationProp MakeAllocProp(CUmemLocationType type); } // namespace xgboost::cudr diff --git a/src/common/device_vector.cu b/src/common/device_vector.cu index aab364be7380..84f5437dda29 100644 --- a/src/common/device_vector.cu +++ b/src/common/device_vector.cu @@ -38,7 +38,7 @@ GrowOnlyVirtualMemVec::GrowOnlyVirtualMemVec(CUmemLocationType type) CUmemAccessDesc hacc; hacc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - xgboost::cudr::GetCuLocation(CU_MEM_LOCATION_TYPE_HOST_NUMA, &hacc.location); + xgboost::cudr::MakeCuMemLocation(type, &hacc.location); this->access_desc_.push_back(hacc); } }