From cc3b56fc37c8a01517ad05a1cee55f9dd768a977 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 6 Aug 2024 11:50:44 +0800 Subject: [PATCH] Cleanup GPU Hist tests. (#10677) * Cleanup GPU Hist tests. - Remove GPU Hist gradient sampling test. The same properties are tested in the gradient sampler test suite. - Move basic histogram tests into the histogram test suite. - Remove the header inclusion of the `updater_gpu_hist.cu` in tests. --- include/xgboost/task.h | 8 +- src/tree/updater_gpu_hist.cu | 6 - .../gpu_hist/test_gradient_based_sampler.cu | 8 +- tests/cpp/tree/gpu_hist/test_histogram.cu | 78 +++++++ tests/cpp/tree/test_gpu_hist.cu | 191 ++---------------- 5 files changed, 106 insertions(+), 185 deletions(-) diff --git a/include/xgboost/task.h b/include/xgboost/task.h index 8f57383ddf32..b51dc6de3cd8 100644 --- a/include/xgboost/task.h +++ b/include/xgboost/task.h @@ -1,12 +1,12 @@ -/*! - * Copyright 2021-2022 by XGBoost Contributors +/** + * Copyright 2021-2024, XGBoost Contributors */ #ifndef XGBOOST_TASK_H_ #define XGBOOST_TASK_H_ #include -#include +#include // for uint8_t namespace xgboost { /*! @@ -23,7 +23,7 @@ namespace xgboost { */ struct ObjInfo { // What kind of problem are we trying to solve - enum Task : uint8_t { + enum Task : std::uint8_t { kRegression = 0, kBinary = 1, kClassification = 2, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 477a7b08a05f..8ff0d61ab25f 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -45,9 +45,7 @@ #include "xgboost/tree_model.h" namespace xgboost::tree { -#if !defined(GTEST_TEST) DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); -#endif // !defined(GTEST_TEST) // Manage memory for a single GPU struct GPUHistMakerDevice { @@ -831,13 +829,11 @@ class GPUHistMaker : public TreeUpdater { std::shared_ptr column_sampler_; }; -#if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") .set_body([](Context const* ctx, ObjInfo const* task) { return new GPUHistMaker(ctx, task); }); -#endif // !defined(GTEST_TEST) class GPUGlobalApproxMaker : public TreeUpdater { public: @@ -960,11 +956,9 @@ class GPUGlobalApproxMaker : public TreeUpdater { common::Monitor monitor_; }; -#if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUApproxMaker, "grow_gpu_approx") .describe("Grow tree with GPU.") .set_body([](Context const* ctx, ObjInfo const* task) { return new GPUGlobalApproxMaker(ctx, task); }); -#endif // !defined(GTEST_TEST) } // namespace xgboost::tree diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 85bea39c5f5c..dcb09ff32315 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -10,9 +10,7 @@ #include "../../filesystem.h" // dmlc::TemporaryDirectory #include "../../helpers.h" -namespace xgboost { -namespace tree { - +namespace xgboost::tree { void VerifySampling(size_t page_size, float subsample, int sampling_method, @@ -151,6 +149,4 @@ TEST(GradientBasedSampler, GradientBasedSamplingExternalMemory) { constexpr bool kFixedSizeSampling = false; VerifySampling(kPageSize, kSubsample, kSamplingMethod, kFixedSizeSampling); } - -}; // namespace tree -}; // namespace xgboost +}; // namespace xgboost::tree diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index fe254421590a..15c8f7def299 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -12,6 +12,7 @@ #include "../../../../src/tree/param.h" // for TrainParam #include "../../categorical_helpers.h" // for OneHotEncodeFeature #include "../../helpers.h" +#include "../../histogram_helpers.h" // for BuildEllpackPage namespace xgboost::tree { TEST(Histogram, DeviceHistogramStorage) { @@ -54,6 +55,83 @@ TEST(Histogram, DeviceHistogramStorage) { EXPECT_ANY_THROW(histogram.AllocateHistograms(&ctx, {kNNodes + 1});); } +std::vector GetHostHistGpair() { + // 24 bins, 3 bins for each feature (column). + std::vector hist_gpair = { + {0.8314f, 0.7147f}, {1.7989f, 3.7312f}, {3.3846f, 3.4598f}, + {2.9277f, 3.5886f}, {1.8429f, 2.4152f}, {1.2443f, 1.9019f}, + {1.6380f, 2.9174f}, {1.5657f, 2.5107f}, {2.8111f, 2.4776f}, + {2.1322f, 3.0651f}, {3.2927f, 3.8540f}, {0.5899f, 0.9866f}, + {1.5185f, 1.6263f}, {2.0686f, 3.1844f}, {2.4278f, 3.0950f}, + {1.5105f, 2.1403f}, {2.6922f, 4.2217f}, {1.8122f, 1.5437f}, + {0.0000f, 0.0000f}, {4.3245f, 5.7955f}, {1.6903f, 2.1103f}, + {2.4012f, 4.4754f}, {3.6136f, 3.4303f}, {0.0000f, 0.0000f} + }; + return hist_gpair; +} + +void TestBuildHist(bool use_shared_memory_histograms) { + int const kNRows = 16, kNCols = 8; + Context ctx{MakeCUDACtx(0)}; + + TrainParam param; + Args args{ + {"max_depth", "6"}, + {"max_leaves", "0"}, + }; + param.Init(args); + + auto page = BuildEllpackPage(&ctx, kNRows, kNCols); + BatchParam batch_param{}; + + xgboost::SimpleLCG gen; + xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); + HostDeviceVector gpair(kNRows); + for (auto& gp : gpair.HostVector()) { + float grad = dist(&gen); + float hess = dist(&gen); + gp = GradientPair{grad, hess}; + } + gpair.SetDevice(ctx.Device()); + + auto row_partitioner = std::make_unique(); + row_partitioner->Reset(&ctx, kNRows, 0); + + auto quantiser = std::make_unique(&ctx, gpair.ConstDeviceSpan(), MetaInfo()); + auto shm_size = use_shared_memory_histograms ? dh::MaxSharedMemoryOptin(ctx.Ordinal()) : 0; + FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size, sizeof(GradientPairInt64)); + + DeviceHistogramStorage hist; + hist.Init(ctx.Device(), page->Cuts().TotalBins()); + hist.AllocateHistograms(&ctx, {0}); + + DeviceHistogramBuilder builder; + builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), !use_shared_memory_histograms); + builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), + feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), + row_partitioner->GetRows(0), hist.GetNodeHistogram(0), *quantiser); + + auto node_histogram = hist.GetNodeHistogram(0); + + std::vector h_result(node_histogram.size()); + dh::CopyDeviceSpanToVector(&h_result, node_histogram); + + std::vector solution = GetHostHistGpair(); + for (size_t i = 0; i < h_result.size(); ++i) { + auto result = quantiser->ToFloatingPoint(h_result[i]); + ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f); + ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f); + } +} + +TEST(Histogram, BuildHistGlobalMem) { + TestBuildHist(false); +} + +TEST(Histogram, BuildHistSharedMem) { + TestBuildHist(true); +} + void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global) { Context ctx = MakeCUDACtx(0); size_t constexpr kBins = 256, kCols = 120, kRows = 16384, kRounds = 16; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 63114926161f..5d1f435de533 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -2,173 +2,26 @@ * Copyright 2017-2024, XGBoost contributors */ #include -#include -#include -#include - -#include -#include - -#include "../../../src/common/common.h" -#include "../../../src/data/ellpack_page.cuh" // for EllpackPageImpl -#include "../../../src/data/ellpack_page.h" // for EllpackPage -#include "../../../src/tree/param.h" // for TrainParam -#include "../../../src/tree/updater_gpu_hist.cu" -#include "../collective/test_worker.h" // for BaseMGPUTest -#include "../filesystem.h" // dmlc::TemporaryDirectory +#include // for Args +#include // for Context +#include // for HostDeviceVector +#include // for Jons +#include // for ObjInfo +#include // for RegTree +#include // for TreeUpdater + +#include // for unique_ptr +#include // for string +#include // for vector + +#include "../../../src/common/random.h" // for GlobalRandom +#include "../../../src/data/ellpack_page.h" // for EllpackPage +#include "../../../src/tree/param.h" // for TrainParam +#include "../collective/test_worker.h" // for BaseMGPUTest +#include "../filesystem.h" // dmlc::TemporaryDirectory #include "../helpers.h" -#include "../histogram_helpers.h" -#include "xgboost/context.h" -#include "xgboost/json.h" namespace xgboost::tree { -std::vector GetHostHistGpair() { - // 24 bins, 3 bins for each feature (column). - std::vector hist_gpair = { - {0.8314f, 0.7147f}, {1.7989f, 3.7312f}, {3.3846f, 3.4598f}, - {2.9277f, 3.5886f}, {1.8429f, 2.4152f}, {1.2443f, 1.9019f}, - {1.6380f, 2.9174f}, {1.5657f, 2.5107f}, {2.8111f, 2.4776f}, - {2.1322f, 3.0651f}, {3.2927f, 3.8540f}, {0.5899f, 0.9866f}, - {1.5185f, 1.6263f}, {2.0686f, 3.1844f}, {2.4278f, 3.0950f}, - {1.5105f, 2.1403f}, {2.6922f, 4.2217f}, {1.8122f, 1.5437f}, - {0.0000f, 0.0000f}, {4.3245f, 5.7955f}, {1.6903f, 2.1103f}, - {2.4012f, 4.4754f}, {3.6136f, 3.4303f}, {0.0000f, 0.0000f} - }; - return hist_gpair; -} - -template -void TestBuildHist(bool use_shared_memory_histograms) { - int const kNRows = 16, kNCols = 8; - Context ctx{MakeCUDACtx(0)}; - - TrainParam param; - Args args{ - {"max_depth", "6"}, - {"max_leaves", "0"}, - }; - param.Init(args); - - auto page = BuildEllpackPage(&ctx, kNRows, kNCols); - BatchParam batch_param{}; - auto cs = std::make_shared(0); - GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols, - batch_param, MetaInfo()); - xgboost::SimpleLCG gen; - xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); - HostDeviceVector gpair(kNRows); - for (auto& gp : gpair.HostVector()) { - float grad = dist(&gen); - float hess = dist(&gen); - gp = GradientPair{grad, hess}; - } - gpair.SetDevice(ctx.Device()); - - maker.row_partitioner = std::make_unique(); - maker.row_partitioner->Reset(&ctx, kNRows, 0); - - maker.hist.Init(ctx.Device(), page->Cuts().TotalBins()); - maker.hist.AllocateHistograms(&ctx, {0}); - - maker.gpair = gpair.DeviceSpan(); - maker.quantiser = std::make_unique(&ctx, maker.gpair, MetaInfo()); - maker.page = page.get(); - - maker.InitFeatureGroupsOnce(); - - DeviceHistogramBuilder builder; - builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()), - !use_shared_memory_histograms); - builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()), - maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), - maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0), - *maker.quantiser); - - DeviceHistogramStorage<>& d_hist = maker.hist; - - auto node_histogram = d_hist.GetNodeHistogram(0); - // d_hist.data stored in float, not gradient pair - thrust::host_vector h_result (node_histogram.size()); - dh::safe_cuda(cudaMemcpy(h_result.data(), node_histogram.data(), node_histogram.size_bytes(), - cudaMemcpyDeviceToHost)); - - std::vector solution = GetHostHistGpair(); - for (size_t i = 0; i < h_result.size(); ++i) { - auto result = maker.quantiser->ToFloatingPoint(h_result[i]); - ASSERT_NEAR(result.GetGrad(), solution[i].GetGrad(), 0.01f); - ASSERT_NEAR(result.GetHess(), solution[i].GetHess(), 0.01f); - } -} - -TEST(GpuHist, BuildHistGlobalMem) { - TestBuildHist(false); -} - -TEST(GpuHist, BuildHistSharedMem) { - TestBuildHist(true); -} - -std::shared_ptr GetHostCutMatrix () { - auto cmat = std::make_shared(); - cmat->SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); - cmat->SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); - // 24 cut fields, 3 cut fields for each feature (column). - // Each row of the cut represents the cuts for a data column. - cmat->SetValues({0.30f, 0.67f, 1.64f, - 0.32f, 0.77f, 1.95f, - 0.29f, 0.70f, 1.80f, - 0.32f, 0.75f, 1.85f, - 0.18f, 0.59f, 1.69f, - 0.25f, 0.74f, 2.00f, - 0.26f, 0.74f, 1.98f, - 0.26f, 0.71f, 1.83f}); - return cmat; -} - -void TestHistogramIndexImpl() { - // Test if the compressed histogram index matches when using a sparse - // dmatrix with and without using external memory - - int constexpr kNRows = 1000, kNCols = 10; - - // Build 2 matrices and build a histogram maker with that - Context ctx(MakeCUDACtx(0)); - ObjInfo task{ObjInfo::kRegression}; - tree::GPUHistMaker hist_maker{&ctx, &task}, hist_maker_ext{&ctx, &task}; - std::unique_ptr hist_maker_dmat( - CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true)); - - dmlc::TemporaryDirectory tempdir; - std::unique_ptr hist_maker_ext_dmat( - CreateSparsePageDMatrixWithRC(kNRows, kNCols, 128UL, true, tempdir)); - - Args training_params = {{"max_depth", "10"}, {"max_leaves", "0"}}; - TrainParam param; - param.UpdateAllowUnknown(training_params); - - hist_maker.Configure(training_params); - hist_maker.InitDataOnce(¶m, hist_maker_dmat.get()); - hist_maker_ext.Configure(training_params); - hist_maker_ext.InitDataOnce(¶m, hist_maker_ext_dmat.get()); - - // Extract the device maker from the histogram makers and from that its compressed - // histogram index - const auto &maker = hist_maker.maker; - auto grad = GenerateRandomGradients(kNRows); - grad.SetDevice(DeviceOrd::CUDA(0)); - maker->Reset(&grad, hist_maker_dmat.get(), kNCols); - - const auto &maker_ext = hist_maker_ext.maker; - maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols); - - ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins()); - ASSERT_EQ(maker->page->gidx_buffer.size(), maker_ext->page->gidx_buffer.size()); -} - -TEST(GpuHist, TestHistogramIndex) { - TestHistogramIndexImpl(); -} - void UpdateTree(Context const* ctx, linalg::Matrix* gpair, DMatrix* dmat, size_t gpu_page_size, RegTree* tree, HostDeviceVector* preds, float subsample = 1.0f, const std::string& sampling_method = "uniform", @@ -200,14 +53,14 @@ void UpdateTree(Context const* ctx, linalg::Matrix* gpair, DMatrix param.UpdateAllowUnknown(args); ObjInfo task{ObjInfo::kRegression}; - tree::GPUHistMaker hist_maker{ctx, &task}; - hist_maker.Configure(Args{}); + std::unique_ptr hist_maker{TreeUpdater::Create("grow_gpu_hist", ctx, &task)}; + hist_maker->Configure(Args{}); std::vector> position(1); - hist_maker.Update(¶m, gpair, dmat, common::Span>{position}, - {tree}); + hist_maker->Update(¶m, gpair, dmat, common::Span>{position}, + {tree}); auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1); - hist_maker.UpdatePredictionCache(dmat, cache); + hist_maker->UpdatePredictionCache(dmat, cache); } TEST(GpuHist, UniformSampling) {