From cb62f9e73bbdf0aafa414938e5b692e70525f0ee Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 21 Jul 2024 05:08:29 +0800 Subject: [PATCH] [EM] Prevent init with CUDA malloc resource. (#10606) --- src/common/cuda_rt_utils.h | 37 +++++++++++++ src/common/ref_resource_view.cuh | 11 +++- src/common/resource.cuh | 6 +-- src/common/timer.cc | 29 +++++----- src/data/ellpack_page.cu | 2 +- src/data/ellpack_page_raw_format.cu | 60 ++++++++++++++------- tests/ci_build/conda_env/macos_cpu_test.yml | 1 + 7 files changed, 105 insertions(+), 41 deletions(-) diff --git a/src/common/cuda_rt_utils.h b/src/common/cuda_rt_utils.h index fa14f8434970..210f1e07d7f8 100644 --- a/src/common/cuda_rt_utils.h +++ b/src/common/cuda_rt_utils.h @@ -3,6 +3,11 @@ */ #pragma once #include // for int32_t + +#if defined(XGBOOST_USE_NVTX) +#include +#endif // defined(XGBOOST_USE_NVTX) + namespace xgboost::common { std::int32_t AllVisibleGPUs(); @@ -18,4 +23,36 @@ bool SupportsAts(); void CheckComputeCapability(); void SetDevice(std::int32_t device); + +struct NvtxDomain { + static constexpr char const *name{"libxgboost"}; // NOLINT +}; + +#if defined(XGBOOST_USE_NVTX) +using NvtxScopedRange = ::nvtx3::scoped_range_in; +using NvtxEventAttr = ::nvtx3::event_attributes; +using NvtxRgb = ::nvtx3::rgb; +#else +class NvtxScopedRange { + public: + template + explicit NvtxScopedRange(Args &&...) {} +}; +class NvtxEventAttr { + public: + template + explicit NvtxEventAttr(Args &&...) {} +}; +class NvtxRgb { + public: + template + explicit NvtxRgb(Args &&...) {} +}; +#endif // defined(XGBOOST_USE_NVTX) } // namespace xgboost::common + +#if defined(XGBOOST_USE_NVTX) +#define xgboost_NVTX_FN_RANGE() NVTX3_FUNC_RANGE_IN(::xgboost::common::NvtxDomain) +#else +#define xgboost_NVTX_FN_RANGE() +#endif // defined(XGBOOST_USE_NVTX) diff --git a/src/common/ref_resource_view.cuh b/src/common/ref_resource_view.cuh index ff311c1409a7..d48b221a305d 100644 --- a/src/common/ref_resource_view.cuh +++ b/src/common/ref_resource_view.cuh @@ -16,10 +16,17 @@ namespace xgboost::common { * @brief Make a fixed size `RefResourceView` with cudaMalloc resource. */ template -[[nodiscard]] RefResourceView MakeFixedVecWithCudaMalloc(Context const* ctx, - std::size_t n_elements, T const& init) { +[[nodiscard]] RefResourceView MakeFixedVecWithCudaMalloc(Context const*, + std::size_t n_elements) { auto resource = std::make_shared(n_elements * sizeof(T)); auto ref = RefResourceView{resource->DataAs(), n_elements, resource}; + return ref; +} + +template +[[nodiscard]] RefResourceView MakeFixedVecWithCudaMalloc(Context const* ctx, + std::size_t n_elements, T const& init) { + auto ref = MakeFixedVecWithCudaMalloc(ctx, n_elements); thrust::fill_n(ctx->CUDACtx()->CTP(), ref.data(), ref.size(), init); return ref; } diff --git a/src/common/resource.cuh b/src/common/resource.cuh index 90b9756a9fc2..e950a8d90695 100644 --- a/src/common/resource.cuh +++ b/src/common/resource.cuh @@ -24,11 +24,9 @@ class CudaMallocResource : public ResourceHandler { } ~CudaMallocResource() noexcept(true) override { this->Clear(); } - void* Data() override { return storage_.data(); } + [[nodiscard]] void* Data() override { return storage_.data(); } [[nodiscard]] std::size_t Size() const override { return storage_.size(); } - void Resize(std::size_t n_bytes, std::byte init = std::byte{0}) { - this->storage_.resize(n_bytes, init); - } + void Resize(std::size_t n_bytes) { this->storage_.resize(n_bytes); } }; class CudaMmapResource : public ResourceHandler { diff --git a/src/common/timer.cc b/src/common/timer.cc index 9b1f49fbd5c8..0b55d1623dbc 100644 --- a/src/common/timer.cc +++ b/src/common/timer.cc @@ -6,9 +6,10 @@ #include #include "../collective/communicator-inl.h" +#include "cuda_rt_utils.h" #if defined(XGBOOST_USE_NVTX) -#include +#include #endif // defined(XGBOOST_USE_NVTX) namespace xgboost::common { @@ -17,8 +18,8 @@ void Monitor::Start(std::string const &name) { auto &stats = statistics_map_[name]; stats.timer.Start(); #if defined(XGBOOST_USE_NVTX) - std::string nvtx_name = "xgboost::" + label_ + "::" + name; - stats.nvtx_id = nvtxRangeStartA(nvtx_name.c_str()); + auto range_handle = nvtx3::start_range_in(label_ + "::" + name); + stats.nvtx_id = range_handle.get_value(); #endif // defined(XGBOOST_USE_NVTX) } } @@ -29,34 +30,32 @@ void Monitor::Stop(const std::string &name) { stats.timer.Stop(); stats.count++; #if defined(XGBOOST_USE_NVTX) - nvtxRangeEnd(stats.nvtx_id); + nvtx3::end_range_in(nvtx3::range_handle{stats.nvtx_id}); #endif // defined(XGBOOST_USE_NVTX) } } -void Monitor::PrintStatistics(StatMap const& statistics) const { +void Monitor::PrintStatistics(StatMap const &statistics) const { for (auto &kv : statistics) { if (kv.second.first == 0) { - LOG(WARNING) << - "Timer for " << kv.first << " did not get stopped properly."; + LOG(WARNING) << "Timer for " << kv.first << " did not get stopped properly."; continue; } - LOG(CONSOLE) << kv.first << ": " << static_cast(kv.second.second) / 1e+6 - << "s, " << kv.second.first << " calls @ " - << kv.second.second - << "us" << std::endl; + LOG(CONSOLE) << kv.first << ": " << static_cast(kv.second.second) / 1e+6 << "s, " + << kv.second.first << " calls @ " << kv.second.second << "us" << std::endl; } } void Monitor::Print() const { - if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { return; } + if (!ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) { + return; + } auto rank = collective::GetRank(); StatMap stat_map; for (auto const &kv : statistics_map_) { stat_map[kv.first] = std::make_pair( - kv.second.count, std::chrono::duration_cast( - kv.second.timer.elapsed) - .count()); + kv.second.count, + std::chrono::duration_cast(kv.second.timer.elapsed).count()); } if (stat_map.empty()) { return; diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 7d3f4c820a22..fc28b7c56f12 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -404,7 +404,7 @@ size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bs LOG(FATAL) << "Concatenating the same Ellpack."; return this->n_rows * this->row_stride; } - dh::LaunchN(num_elements, CopyPage{this, page, offset}); + dh::LaunchN(num_elements, ctx->CUDACtx()->Stream(), CopyPage{this, page, offset}); monitor_.Stop(__func__); return num_elements; } diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 3f23c5d8d3d6..86d1ac6da7eb 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -6,6 +6,7 @@ #include // for size_t #include // for vector +#include "../common/cuda_rt_utils.h" #include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream #include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc #include "../common/ref_resource_view.h" // for ReadVec, WriteVec @@ -21,6 +22,8 @@ namespace { template [[nodiscard]] bool ReadDeviceVec(common::AlignedResourceReadStream* fi, common::RefResourceView* vec) { + xgboost_NVTX_FN_RANGE(); + std::uint64_t n{0}; if (!fi->Read(&n)) { return false; @@ -37,7 +40,7 @@ template } auto ctx = Context{}.MakeCUDA(common::CurrentDevice()); - *vec = common::MakeFixedVecWithCudaMalloc(&ctx, n, static_cast(0)); + *vec = common::MakeFixedVecWithCudaMalloc(&ctx, n); dh::safe_cuda(cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, dh::DefaultStream())); return true; } @@ -50,6 +53,7 @@ template [[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, common::AlignedResourceReadStream* fi) { + xgboost_NVTX_FN_RANGE(); auto* impl = page->Impl(); impl->SetCuts(this->cuts_); @@ -69,6 +73,8 @@ template [[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page, common::AlignedFileWriteStream* fo) { + xgboost_NVTX_FN_RANGE(); + std::size_t bytes{0}; auto* impl = page.Impl(); bytes += fo->Write(impl->n_rows); @@ -84,22 +90,30 @@ template } [[nodiscard]] bool EllpackPageRawFormat::Read(EllpackPage* page, EllpackHostCacheStream* fi) const { + xgboost_NVTX_FN_RANGE(); + auto* impl = page->Impl(); CHECK(this->cuts_->cut_values_.DeviceCanRead()); impl->SetCuts(this->cuts_); - RET_IF_NOT(fi->Read(&impl->n_rows)); - RET_IF_NOT(fi->Read(&impl->is_dense)); - RET_IF_NOT(fi->Read(&impl->row_stride)); - // Read vec + // Read vector Context ctx = Context{}.MakeCUDA(common::CurrentDevice()); - bst_idx_t n{0}; - RET_IF_NOT(fi->Read(&n)); - if (n != 0) { - impl->gidx_buffer = - common::MakeFixedVecWithCudaMalloc(&ctx, n, static_cast(0)); + auto read_vec = [&] { + common::NvtxScopedRange range{common::NvtxEventAttr{"read-vec", common::NvtxRgb{127, 255, 0}}}; + bst_idx_t n{0}; + RET_IF_NOT(fi->Read(&n)); + if (n == 0) { + return true; + } + impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc(&ctx, n); RET_IF_NOT(fi->Read(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes())); - } + return true; + }; + RET_IF_NOT(read_vec()); + + RET_IF_NOT(fi->Read(&impl->n_rows)); + RET_IF_NOT(fi->Read(&impl->is_dense)); + RET_IF_NOT(fi->Read(&impl->row_stride)); RET_IF_NOT(fi->Read(&impl->base_rowid)); dh::DefaultStream().Sync(); @@ -108,19 +122,27 @@ template [[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page, EllpackHostCacheStream* fo) const { + xgboost_NVTX_FN_RANGE(); + bst_idx_t bytes{0}; auto* impl = page.Impl(); - bytes += fo->Write(impl->n_rows); - bytes += fo->Write(impl->is_dense); - bytes += fo->Write(impl->row_stride); // Write vector - bst_idx_t n = impl->gidx_buffer.size(); - bytes += fo->Write(n); + auto write_vec = [&] { + common::NvtxScopedRange range{common::NvtxEventAttr{"write-vec", common::NvtxRgb{127, 255, 0}}}; + bst_idx_t n = impl->gidx_buffer.size(); + bytes += fo->Write(n); - if (!impl->gidx_buffer.empty()) { - bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()); - } + if (!impl->gidx_buffer.empty()) { + bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()); + } + }; + + write_vec(); + + bytes += fo->Write(impl->n_rows); + bytes += fo->Write(impl->is_dense); + bytes += fo->Write(impl->row_stride); bytes += fo->Write(impl->base_rowid); dh::DefaultStream().Sync(); diff --git a/tests/ci_build/conda_env/macos_cpu_test.yml b/tests/ci_build/conda_env/macos_cpu_test.yml index e2e377e2145d..5bca323af5f4 100644 --- a/tests/ci_build/conda_env/macos_cpu_test.yml +++ b/tests/ci_build/conda_env/macos_cpu_test.yml @@ -37,4 +37,5 @@ dependencies: - pyspark>=3.4.0 - cloudpickle - pip: + - setuptools - sphinx_rtd_theme