diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 10329f87b074..87d3be1fe34b 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -541,9 +541,12 @@ class DMatrix {
   [[nodiscard]] bool PageExists() const;
 
   /**
-   * @return Whether the data columns single column block.
+   * @return Whether the DMatrix contains a single batch.
+   *
+   * The naming is legacy.
    */
-  [[nodiscard]] virtual bool SingleColBlock() const = 0;
+  [[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
+  [[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }
 
   virtual ~DMatrix();
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 6cd0cd76a47a..3adc39e73777 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -486,24 +486,20 @@ class TypedDiscard : public thrust::discard_iterator<T> {
 }  // namespace detail
 
 template <typename T>
-using TypedDiscard =
-    std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
-                       detail::TypedDiscard<T>>;
+using TypedDiscard = std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
+                                        detail::TypedDiscard<T>>;
 
 template <typename VectorT, typename T = typename VectorT::value_type,
-          typename IndexT = typename xgboost::common::Span<T>::index_type>
-xgboost::common::Span<T> ToSpan(
-    VectorT &vec,
-    IndexT offset = 0,
-    IndexT size = std::numeric_limits<size_t>::max()) {
+          typename IndexT = typename xgboost::common::Span<T>::index_type>
+xgboost::common::Span<T> ToSpan(VectorT &vec, IndexT offset = 0,
+                                IndexT size = std::numeric_limits<size_t>::max()) {
   size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
   CHECK_LE(offset + size, vec.size());
-  return {vec.data().get() + offset, size};
+  return {thrust::raw_pointer_cast(vec.data()) + offset, size};
 }
 
 template <typename T>
-xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
-                                size_t offset, size_t size) {
+xgboost::common::Span<T> ToSpan(thrust::device_vector<T> &vec, size_t offset, size_t size) {
   return ToSpan(vec, offset, size);
 }
 
@@ -874,13 +870,7 @@ inline void CUDAEvent::Record(CUDAStreamView stream) {  // NOLINT
 
 // Changing this has effect on prediction return, where we need to pass the pointer to
 // third-party libraries like cuPy
-inline CUDAStreamView DefaultStream() {
-#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
-  return CUDAStreamView{cudaStreamPerThread};
-#else
-  return CUDAStreamView{cudaStreamLegacy};
-#endif
-}
+inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }
 
 class CUDAStream {
   cudaStream_t stream_;
diff --git a/src/data/extmem_quantile_dmatrix.cc b/src/data/extmem_quantile_dmatrix.cc
index 0d17fcf55ae8..96e88a55a0e3 100644
--- a/src/data/extmem_quantile_dmatrix.cc
+++ b/src/data/extmem_quantile_dmatrix.cc
@@ -74,6 +74,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
   cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
   ext_info.SetInfo(ctx, &this->info_);
 
+  this->n_batches_ = ext_info.n_batches;
+
   /**
    * Generate quantiles
    */
diff --git a/src/data/extmem_quantile_dmatrix.h b/src/data/extmem_quantile_dmatrix.h
index d3b9f5a7820a..33a80f5cda92 100644
--- a/src/data/extmem_quantile_dmatrix.h
+++ b/src/data/extmem_quantile_dmatrix.h
@@ -33,7 +33,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
                         std::string cache, bst_bin_t max_bin, bool on_host);
   ~ExtMemQuantileDMatrix() override;
 
-  [[nodiscard]] bool SingleColBlock() const override { return false; }
+  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
 
  private:
   void InitFromCPU(
@@ -63,6 +63,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
   std::string cache_prefix_;
   bool on_host_;
   BatchParam batch_;
+  bst_idx_t n_batches_{0};
 
   using EllpackDiskPtr = std::shared_ptr<ExtEllpackPageSource>;
   using EllpackHostPtr = std::shared_ptr<ExtEllpackPageHostSource>;
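Note on the data.h change above — a sketch for review, not part of the patch: SingleColBlock() is now a plain wrapper over the new NumBatches() hook, so an implementation only reports its batch count and the legacy query falls out of it. FourBatchDMatrix below is a hypothetical subclass used purely for illustration:

    #include <cstdint>

    class DMatrix {
     public:
      virtual ~DMatrix() = default;
      // Legacy query; answered via NumBatches() instead of per-subclass overrides.
      [[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
      // In-memory implementations keep the default single batch.
      [[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }
    };

    // A hypothetical external-memory implementation: overriding NumBatches()
    // alone makes SingleColBlock() report false automatically.
    class FourBatchDMatrix : public DMatrix {
     public:
      [[nodiscard]] std::int32_t NumBatches() const override { return 4; }
    };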
diff --git a/src/data/iterative_dmatrix.h b/src/data/iterative_dmatrix.h
index 33350d372ac2..acec4708e634 100644
--- a/src/data/iterative_dmatrix.h
+++ b/src/data/iterative_dmatrix.h
@@ -57,8 +57,6 @@ class IterativeDMatrix : public QuantileDMatrix {
   BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
   BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;
-
-  bool SingleColBlock() const override { return true; }
 };
 }  // namespace data
 }  // namespace xgboost
diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h
index 8e62802c38aa..221e13fb32fc 100644
--- a/src/data/proxy_dmatrix.h
+++ b/src/data/proxy_dmatrix.h
@@ -94,7 +94,6 @@ class DMatrixProxy : public DMatrix {
   MetaInfo const& Info() const override { return info_; }
   Context const* Ctx() const override { return &ctx_; }
 
-  bool SingleColBlock() const override { return false; }
   bool EllpackExists() const override { return false; }
   bool GHistIndexExists() const override { return false; }
   bool SparsePageExists() const override { return false; }
diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h
index 5b5bb2bfb2ba..ac757591cdb2 100644
--- a/src/data/simple_dmatrix.h
+++ b/src/data/simple_dmatrix.h
@@ -33,7 +33,6 @@ class SimpleDMatrix : public DMatrix {
   const MetaInfo& Info() const override;
   Context const* Ctx() const override { return &fmat_ctx_; }
 
-  bool SingleColBlock() const override { return true; }
   DMatrix* Slice(common::Span<std::int32_t const> ridxs) override;
   DMatrix* SliceCol(int num_slices, int slice_id) override;
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index 245ec0e4b5dc..f40c16f72488 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -90,8 +90,7 @@ class SparsePageDMatrix : public DMatrix {
   [[nodiscard]] MetaInfo &Info() override;
   [[nodiscard]] const MetaInfo &Info() const override;
   [[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
-  // The only DMatrix implementation that returns false.
-  [[nodiscard]] bool SingleColBlock() const override { return false; }
+  [[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
   DMatrix *Slice(common::Span<std::int32_t const>) override {
     LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
     return nullptr;
diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc
index 363c46f2d413..6247d66b37fc 100644
--- a/src/data/sparse_page_source.cc
+++ b/src/data/sparse_page_source.cc
@@ -3,10 +3,10 @@
  */
 #include "sparse_page_source.h"
 
-#include <filesystem>  // for exists
-#include <string>      // for string
 #include <cstdio>      // for remove
+#include <filesystem>  // for exists
 #include <numeric>     // for partial_sum
+#include <string>      // for string
 
 namespace xgboost::data {
 void Cache::Commit() {
@@ -27,4 +27,8 @@ void TryDeleteCacheFile(const std::string& file) {
                  << "; you may want to remove it manually";
   }
 }
+
+#if !defined(XGBOOST_USE_CUDA)
+void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
+#endif
 }  // namespace xgboost::data
diff --git a/src/data/sparse_page_source.cu b/src/data/sparse_page_source.cu
index 125b7f261616..84d6197e689c 100644
--- a/src/data/sparse_page_source.cu
+++ b/src/data/sparse_page_source.cu
@@ -18,4 +18,14 @@ void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
   cuda_impl::Dispatch(proxy,
                       [&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
 }
+
+void InitNewThread::operator()() const {
+  *GlobalConfigThreadLocalStore::Get() = config;
+  // For CUDA 12.2, we need to force initialize the CUDA context by synchronizing the
+  // stream when creating a new thread in the thread pool. While for CUDA 11.8, this
+  // action might cause an insufficient driver version error for some reason. Lastly, it
+  // should work with CUDA 12.5 without any action being taken.
+
+  // dh::DefaultStream().Sync();
+}
 }  // namespace xgboost::data
diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h
index e750f00fccdd..ca04e969fddf 100644
--- a/src/data/sparse_page_source.h
+++ b/src/data/sparse_page_source.h
@@ -210,6 +210,12 @@ class DefaultFormatPolicy {
   }
 };
 
+struct InitNewThread {
+  GlobalConfiguration config = *GlobalConfigThreadLocalStore::Get();
+
+  void operator()() const;
+};
+
 /**
  * @brief Base class for all page sources. Handles fetching, writing, and iteration.
  *
@@ -330,10 +336,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
  public:
   SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
                        std::shared_ptr<Cache> cache)
-      : workers_{std::max(2, std::min(nthreads, 16)),
-                 [config = *GlobalConfigThreadLocalStore::Get()] {
-                   *GlobalConfigThreadLocalStore::Get() = config;
-                 }},
+      : workers_{std::max(2, std::min(nthreads, 16)), InitNewThread{}},
         missing_{missing},
         nthreads_{nthreads},
         n_features_{n_features},
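Note on the InitNewThread change above — a rough sketch of the pattern, not part of the patch: the thread-pool initializer moves from an inline lambda into a named functor so the CPU translation unit (sparse_page_source.cc) and the CUDA one (sparse_page_source.cu) can define operator()() differently while sharing one declaration. The sketch below uses a hypothetical WorkerPool and Config in place of xgboost's real thread pool and GlobalConfigThreadLocalStore:

    #include <thread>
    #include <vector>

    struct Config { int verbosity{1}; };
    thread_local Config g_config;  // stand-in for the thread-local global config

    struct InitNewThread {
      Config config = g_config;                       // captured on the constructing thread
      void operator()() const { g_config = config; }  // replayed on each new worker
    };

    class WorkerPool {  // hypothetical pool; illustration only
     public:
      WorkerPool(int n, InitNewThread init) {
        for (int i = 0; i < n; ++i) {
          workers_.emplace_back([init] {
            init();  // propagate the captured config into this worker thread
            // ... run queued tasks ...
          });
        }
      }
      ~WorkerPool() {
        for (auto &t : workers_) {
          t.join();
        }
      }

     private:
      std::vector<std::thread> workers_;
    };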
diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc
index f9e34790d4a3..49e43d8340e0 100644
--- a/tests/cpp/data/test_data.cc
+++ b/tests/cpp/data/test_data.cc
@@ -63,26 +63,27 @@ TEST(SparsePage, PushCSC) {
 }
 
 TEST(SparsePage, PushCSCAfterTranspose) {
-  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
+  bst_idx_t constexpr kRows = 1024, kCols = 21;
+
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
   const int ncols = dmat->Info().num_col_;
-  SparsePage page; // Consolidated sparse page
-  for (const auto &batch : dmat->GetBatches<SparsePage>()) {
+  SparsePage page;  // Consolidated sparse page
+  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
     // Transpose each batch and push
     SparsePage tmp = batch.GetTranspose(ncols, AllThreadsForTest());
     page.PushCSC(tmp);
   }
 
   // Make sure that the final sparse page has the right number of entries
-  ASSERT_EQ(kEntries, page.data.Size());
+  ASSERT_EQ(kRows * kCols, page.data.Size());
 
   page.SortRows(AllThreadsForTest());
   auto v = page.GetView();
   for (size_t i = 0; i < v.Size(); ++i) {
     auto column = v[i];
     for (size_t j = 1; j < column.size(); ++j) {
-      ASSERT_GE(column[j].fvalue, column[j-1].fvalue);
+      ASSERT_GE(column[j].fvalue, column[j - 1].fvalue);
     }
   }
 }
diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index 8aab51b7202e..f3957a002279 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -140,13 +140,11 @@ struct ReadRowFunction {
 TEST(EllpackPage, Copy) {
   constexpr size_t kRows = 1024;
   constexpr size_t kCols = 16;
-  constexpr size_t kPageSize = 1024;
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix>
-      dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
-  Context ctx{MakeCUDACtx(0)};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
   auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
@@ -187,14 +185,12 @@ TEST(EllpackPage, Copy) {
 TEST(EllpackPage, Compact) {
   constexpr size_t kRows = 16;
   constexpr size_t kCols = 2;
-  constexpr size_t kPageSize = 1;
   constexpr size_t kCompactedRows = 8;
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix> dmat(
-      CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
-  Context ctx{MakeCUDACtx(0)};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
   auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc
index 3aeb42abce2b..b52d49176ef6 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cc
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cc
@@ -214,15 +214,15 @@ TEST(SparsePageDMatrix, MetaInfo) {
 }
 
 TEST(SparsePageDMatrix, RowAccess) {
-  std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(24);
+  auto dmat = RandomDataGenerator{12, 6, 0.8f}.Batches(2).GenerateSparsePageDMatrix("temp", false);
 
   // Test the data read into the first row
   auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
   auto page = batch.GetView();
   auto first_row = page[0];
-  ASSERT_EQ(first_row.size(), 3ul);
-  EXPECT_EQ(first_row[2].index, 2u);
-  EXPECT_NEAR(first_row[2].fvalue, 0.986566, 1e-4);
+  ASSERT_EQ(first_row.size(), 1ul);
+  EXPECT_EQ(first_row[0].index, 5u);
+  EXPECT_NEAR(first_row[0].fvalue, 0.1805125, 1e-4);
 }
 
 TEST(SparsePageDMatrix, ColAccess) {
@@ -268,11 +268,10 @@ TEST(SparsePageDMatrix, ColAccess) {
 }
 
 TEST(SparsePageDMatrix, ThreadSafetyException) {
-  size_t constexpr kEntriesPerCol = 3;
-  size_t constexpr kEntries = 64 * kEntriesPerCol * 2;
   Context ctx;
-  std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(kEntries);
+  auto dmat =
+      RandomDataGenerator{4096, 12, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
 
   int threads = 1000;
@@ -304,10 +303,9 @@ TEST(SparsePageDMatrix, ThreadSafetyException) {
 
 // Multi-batches access
 TEST(SparsePageDMatrix, ColAccessBatches) {
-  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
   // Create multiple sparse pages
-  std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
+  auto dmat =
+      RandomDataGenerator{1024, 32, 0.4f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
   ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
   Context ctx;
   for (auto const &page : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu
index 046c4eed4d80..f74ca28eb85e 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cu
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cu
@@ -115,13 +115,10 @@ TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
 }
 
 TEST(SparsePageDMatrix, MultipleEllpackPages) {
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
   dmlc::TemporaryDirectory tmpdir;
-  std::string filename = tmpdir.path + "/big.libsvm";
-  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, filename);
+  auto dmat = RandomDataGenerator{1024, 2, 0.5f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
 
   // Loop over the batches and count the records
   std::int64_t batch_count = 0;
@@ -135,15 +132,13 @@
   EXPECT_EQ(row_count, dmat->Info().num_row_);
 
   auto path =
-      data::MakeId(filename,
-                   dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
-      ".ellpack.page";
+      data::MakeId("temp", dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) + ".ellpack.page";
 }
 
 TEST(SparsePageDMatrix, RetainEllpackPage) {
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
-  auto m = CreateSparsePageDMatrix(10000);
+  auto m = RandomDataGenerator{2048, 4, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
 
   auto batches = m->GetBatches<EllpackPage>(&ctx, param);
   auto begin = batches.begin();
@@ -278,20 +273,19 @@ struct ReadRowFunction {
 };
 
 TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
-  constexpr size_t kRows = 6;
+  constexpr size_t kRows = 16;
   constexpr size_t kCols = 2;
   constexpr int kMaxBins = 256;
-  constexpr size_t kPageSize = 1;
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix>
-      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+  auto dmat_ext =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
 
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
   auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
   EXPECT_EQ(impl->base_rowid, 0);
@@ -325,17 +319,16 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
   constexpr size_t kRows = 1024;
   constexpr size_t kCols = 16;
   constexpr int kMaxBins = 256;
-  constexpr size_t kPageSize = 4096;
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix>
-      dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
+  auto dmat_ext =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
 
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
 
   size_t current_row = 0;
diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc
index dcb89b97189c..8a5383ad4d34 100644
--- a/tests/cpp/gbm/test_gbtree.cc
+++ b/tests/cpp/gbm/test_gbtree.cc
@@ -715,7 +715,7 @@ TEST(GBTree, InplacePredictionError) {
       p_fmat = rng.GenerateQuantileDMatrix(true);
     } else {
 #if defined(XGBOOST_USE_CUDA)
-      p_fmat = rng.GenerateDeviceDMatrix(true);
+      p_fmat = rng.Device(ctx->Device()).GenerateQuantileDMatrix(true);
 #else
       CHECK(p_fmat);
 #endif  // defined(XGBOOST_USE_CUDA)
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 05f84316467c..ae5698d2cc6e 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -13,7 +13,6 @@
 #include
 #include <limits>  // for numeric_limits
-#include <random>
 
 #include "../../src/collective/communicator-inl.h"  // for GetRank
 #include "../../src/data/adapter.h"
@@ -21,8 +20,6 @@
 #include "../../src/data/simple_dmatrix.h"
 #include "../../src/data/sparse_page_dmatrix.h"
 #include "../../src/gbm/gbtree_model.h"
-#include "../../src/tree/param.h"  // for TrainParam
-#include "filesystem.h"            // dmlc::TemporaryDirectory
 #include "xgboost/c_api.h"
 #include "xgboost/predictor.h"
@@ -456,6 +453,7 @@ void RandomDataGenerator::GenerateCSR(
   }
 
   EXPECT_EQ(batch_count, n_batches_);
+  EXPECT_EQ(dmat->NumBatches(), n_batches_);
   EXPECT_EQ(row_count, dmat->Info().num_row_);
 
   if (with_label) {
@@ -503,13 +501,24 @@ void RandomDataGenerator::GenerateCSR(
 }
 
 std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) {
-  NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
-  auto m = std::make_shared<data::IterativeDMatrix>(
-      &iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
+  std::shared_ptr<DMatrix> p_fmat;
+
+  if (this->device_.IsCPU()) {
+    NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
+    p_fmat =
+        std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
+                                                 std::numeric_limits<float>::quiet_NaN(), 0, bins_);
+  } else {
+    CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
+    p_fmat =
+        std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
+                                                 std::numeric_limits<float>::quiet_NaN(), 0, bins_);
+  }
+
   if (with_label) {
-    this->GenerateLabels(m);
+    this->GenerateLabels(p_fmat);
  }
-  return m;
+  return p_fmat;
 }
 
 #if !defined(XGBOOST_USE_CUDA)
@@ -551,125 +560,6 @@ std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::si
   return p_fmat;
 }
 
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features,
-                                                 size_t n_batches, std::string prefix) {
-  CHECK_GE(n_samples, n_batches);
-  NumpyArrayIterForTest iter(0, n_samples, n_features, n_batches);
-
-  std::unique_ptr<DMatrix> dmat{DMatrix::Create(
-      static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
-      std::numeric_limits<float>::quiet_NaN(), omp_get_max_threads(), prefix, false)};
-
-  auto row_page_path =
-      data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) + ".row.page";
-  EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
-
-  // Loop over the batches and count the number of pages
-  int64_t batch_count = 0;
-  int64_t row_count = 0;
-  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
-    batch_count++;
-    row_count += batch.Size();
-  }
-
-  EXPECT_GE(batch_count, n_batches);
-  EXPECT_EQ(row_count, dmat->Info().num_row_);
-  return dmat;
-}
-
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries,
-                                                 std::string prefix) {
-  size_t n_columns = 3;
-  size_t n_rows = n_entries / n_columns;
-  NumpyArrayIterForTest iter(0, n_rows, n_columns, 2);
-
-  std::unique_ptr<DMatrix> dmat{
-      DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
-                      std::numeric_limits<float>::quiet_NaN(), 0, prefix, false)};
-  auto row_page_path =
-      data::MakeId(prefix,
-                   dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
-      ".row.page";
-  EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
-
-  // Loop over the batches and count the records
-  int64_t batch_count = 0;
-  int64_t row_count = 0;
-  for (const auto &batch : dmat->GetBatches<SparsePage>()) {
-    batch_count++;
-    row_count += batch.Size();
-  }
-  EXPECT_GE(batch_count, 2);
-  EXPECT_EQ(row_count, dmat->Info().num_row_);
-  return dmat;
-}
-
-std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
-                                                       size_t page_size, bool deterministic,
-                                                       const dmlc::TemporaryDirectory& tempdir) {
-  if (!n_rows || !n_cols) {
-    return nullptr;
-  }
-
-  // Create the svm file in a temp dir
-  const std::string tmp_file = tempdir.path + "/big.libsvm";
-
-  std::ofstream fo(tmp_file.c_str());
-  size_t cols_per_row = ((std::max(n_rows, n_cols) - 1) / std::min(n_rows, n_cols)) + 1;
-  int64_t rem_cols = n_cols;
-  size_t col_idx = 0;
-
-  // Random feature id generator
-  std::random_device rdev;
-  std::unique_ptr<std::mt19937> gen;
-  if (deterministic) {
-    // Seed it with a constant value for this configuration - without getting too fancy
-    // like ordered pairing functions and its likes to make it truely unique
-    gen.reset(new std::mt19937(n_rows * n_cols));
-  } else {
-    gen.reset(new std::mt19937(rdev()));
-  }
-  std::uniform_int_distribution<size_t> label(0, 1);
-  std::uniform_int_distribution<size_t> dis(1, n_cols);
-
-  for (size_t i = 0; i < n_rows; ++i) {
-    // Make sure that all cols are slotted in the first few rows; randomly distribute the
-    // rest
-    std::stringstream row_data;
-    size_t j = 0;
-    if (rem_cols > 0) {
-      for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
-        row_data << label(*gen) << " " << (col_idx + j) << ":"
-                 << (col_idx + j + 1) * 10 * i;
-      }
-      rem_cols -= cols_per_row;
-    } else {
-      // Take some random number of colums in [1, n_cols] and slot them here
-      std::vector<size_t> random_columns;
-      size_t ncols = dis(*gen);
-      for (; j < ncols; ++j) {
-        size_t fid = (col_idx + j) % n_cols;
-        random_columns.push_back(fid);
-      }
-      std::sort(random_columns.begin(), random_columns.end());
-      for (auto fid : random_columns) {
-        row_data << label(*gen) << " " << fid << ":" << (fid + 1) * 10 * i;
-      }
-    }
-    col_idx += j;
-
-    fo << row_data.str() << "\n";
-  }
-  fo.close();
-
-  std::string uri = tmp_file + "?format=libsvm";
-  if (page_size > 0) {
-    uri += "#" + tmp_file + ".cache";
-  }
-  std::unique_ptr<DMatrix> dmat(DMatrix::Load(uri));
-  return dmat;
-}
-
 std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs, size_t kRows,
                                                   size_t kCols,
                                                   LearnerModelParam const* learner_model_param,
diff --git a/tests/cpp/helpers.cu b/tests/cpp/helpers.cu
index f756289538ab..ef6beb33687b 100644
--- a/tests/cpp/helpers.cu
+++ b/tests/cpp/helpers.cu
@@ -3,12 +3,9 @@
  */
 #include
 
-#include "../../src/data/device_adapter.cuh"
-#include "../../src/data/iterative_dmatrix.h"
 #include "helpers.h"
 
 namespace xgboost {
-
 CudaArrayIterForTest::CudaArrayIterForTest(float sparsity, size_t rows, size_t cols,
                                            size_t batches)
     : ArrayIterForTest{sparsity, rows, cols, batches} {
@@ -26,14 +23,4 @@ int CudaArrayIterForTest::Next() {
   iter_++;
   return 1;
 }
-
-std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix(bool with_label) {
-  CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
-  auto m = std::make_shared<data::IterativeDMatrix>(
-      &iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
-  if (with_label) {
-    this->GenerateLabels(m);
-  }
-  return m;
-}
 }  // namespace xgboost
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 50ae8bce076e..a8d5f370f3a2 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -324,9 +324,6 @@ class RandomDataGenerator {
   [[nodiscard]] std::shared_ptr<DMatrix> GenerateExtMemQuantileDMatrix(std::string prefix,
                                                                        bool with_label) const;
 
-#if defined(XGBOOST_USE_CUDA)
-  std::shared_ptr<DMatrix> GenerateDeviceDMatrix(bool with_label);
-#endif
   std::shared_ptr<DMatrix> GenerateQuantileDMatrix(bool with_label);
 };
@@ -350,45 +347,6 @@ inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n, size_t nu
 std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
                                             bst_feature_t num_columns);
 
-/**
- * \brief Create Sparse Page using data iterator.
- *
- * \param n_samples Total number of rows for all batches combined.
- * \param n_features Number of features
- * \param n_batches Number of batches
- * \param prefix Cache prefix, can be used for specifying file path.
- *
- * \return A Sparse DMatrix with n_batches.
- */
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features,
-                                                 size_t n_batches, std::string prefix = "cache");
-
-/**
- * Deprecated, stop using it
- */
-std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, std::string prefix = "cache");
-
-/**
- * Deprecated, stop using it
- *
- * \brief Creates dmatrix with some records, each record containing random number of
- *        features in [1, n_cols]
- *
- * \param n_rows      Number of records to create.
- * \param n_cols      Max number of features within that record.
- * \param page_size   Sparse page size for the pages within the dmatrix. If page size is 0
- *                    then the entire dmatrix is resident in memory; else, multiple sparse pages
- *                    of page size are created and backed to disk, which would have to be
- *                    streamed in at point of use.
- * \param deterministic The content inside the dmatrix is constant for this configuration, if true;
- *                      else, the content changes every time this method is invoked
- *
- * \return The new dmatrix.
- */
-std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
-    size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
-    const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());
-
 std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs, size_t kRows,
                                                   size_t kCols,
                                                   LearnerModelParam const* learner_model_param,
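Note on the helper removal above — migration sketch, not part of the patch: call sites move from the removed CreateSparsePageDMatrix/CreateSparsePageDMatrixWithRC helpers to the RandomDataGenerator builder declared in tests/cpp/helpers.h, which makes the shape, sparsity, and batch count explicit. Typical replacement inside a test body, mirroring the updates below:

    // Before: std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
    // After: rows/cols/sparsity are spelled out and the page count is requested
    // directly ("temp" is the cache prefix, true generates labels).
    auto dmat = RandomDataGenerator{/*rows=*/1024, /*cols=*/16, /*sparsity=*/0.5f}
                    .Batches(4)
                    .GenerateSparsePageDMatrix("temp", /*with_label=*/true);
    ASSERT_FALSE(dmat->SingleColBlock());  // four row pages were requested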
diff --git a/tests/cpp/plugin/test_sycl_predictor.cc b/tests/cpp/plugin/test_sycl_predictor.cc
index 7bd788a3b071..a7ec51594e08 100755
--- a/tests/cpp/plugin/test_sycl_predictor.cc
+++ b/tests/cpp/plugin/test_sycl_predictor.cc
@@ -36,9 +36,10 @@ TEST(SyclPredictor, ExternalMemory) {
   Context ctx;
   ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
 
-  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
+  bst_idx_t constexpr kRows{64};
+  bst_feature_t constexpr kCols{12};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.5f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
   TestBasic(dmat.get(), &ctx);
 }
diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc
index c0d2c8e285af..ee28adb155c9 100644
--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc
@@ -10,12 +10,10 @@
 #include "../../../src/gbm/gbtree.h"
 #include "../../../src/gbm/gbtree_model.h"
 #include "../collective/test_worker.h"  // for TestDistributedGlobal
-#include "../filesystem.h"              // dmlc::TemporaryDirectory
 #include "../helpers.h"
 #include "test_predictor.h"
 
 namespace xgboost {
-
 TEST(CpuPredictor, Basic) {
   Context ctx;
   size_t constexpr kRows = 5;
@@ -56,9 +54,10 @@ TEST(CpuPredictor, IterationRangeColmnSplit) {
 
 TEST(CpuPredictor, ExternalMemory) {
   Context ctx;
-  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
-  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
+  bst_idx_t constexpr kRows{64};
+  bst_feature_t constexpr kCols{12};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.5f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
   TestBasic(dmat.get(), &ctx);
 }
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index 01de15fe8bc8..5e3021fd71e1 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -123,8 +123,8 @@ TEST(GPUPredictor, EllpackBasic) {
     size_t rows = bins * 16;
     auto p_m = RandomDataGenerator{rows, kCols, 0.0}
                    .Bins(bins)
-                   .Device(DeviceOrd::CUDA(0))
-                   .GenerateDeviceDMatrix(false);
+                   .Device(ctx.Device())
+                   .GenerateQuantileDMatrix(false);
     ASSERT_FALSE(p_m->PageExists<SparsePage>());
     TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
     TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
@@ -137,7 +137,7 @@ TEST(GPUPredictor, EllpackTraining) {
   auto p_ellpack = RandomDataGenerator{kRows, kCols, 0.0}
                        .Bins(kBins)
                        .Device(ctx.Device())
-                       .GenerateDeviceDMatrix(false);
+                       .GenerateQuantileDMatrix(false);
   HostDeviceVector<float> storage(kRows * kCols);
   auto columnar =
       RandomDataGenerator{kRows, kCols, 0.0}.Device(ctx.Device()).GenerateArrayInterface(&storage);
diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc
index be11a2a765b7..a6f3eacecbc5 100644
--- a/tests/cpp/test_learner.cc
+++ b/tests/cpp/test_learner.cc
@@ -117,24 +117,15 @@ TEST(Learner, CheckGroup) {
   EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));
 }
 
-TEST(Learner, SLOW_CheckMultiBatch) {  // NOLINT
-  // Create sufficiently large data to make two row pages
-  dmlc::TemporaryDirectory tempdir;
-  const std::string tmp_file = tempdir.path + "/big.libsvm";
-  CreateBigTestData(tmp_file, 50000);
-  std::shared_ptr<DMatrix> dmat(
-      xgboost::DMatrix::Load(tmp_file + "?format=libsvm" + "#" + tmp_file + ".cache"));
-  EXPECT_FALSE(dmat->SingleColBlock());
-  size_t num_row = dmat->Info().num_row_;
-  std::vector<float> labels(num_row);
-  for (size_t i = 0; i < num_row; ++i) {
-    labels[i] = i % 2;
-  }
-  dmat->SetInfo("label", Make1dInterfaceTest(labels.data(), num_row));
-  std::vector<std::shared_ptr<DMatrix>> mat{dmat};
+TEST(Learner, CheckMultiBatch) {
+  auto p_fmat =
+      RandomDataGenerator{512, 128, 0.8}.Batches(4).GenerateSparsePageDMatrix("temp", true);
+  ASSERT_FALSE(p_fmat->SingleColBlock());
+
+  std::vector<std::shared_ptr<DMatrix>> mat{p_fmat};
   auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
   learner->SetParams(Args{{"objective", "binary:logistic"}});
-  learner->UpdateOneIter(0, dmat);
+  learner->UpdateOneIter(0, p_fmat);
 }
 
 TEST(Learner, Configuration) {
diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
index dcb09ff32315..b1e86e2ebbc2 100644
--- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
+++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu
@@ -7,22 +7,18 @@
 #include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
 #include "../../../../src/tree/param.h"
 #include "../../../../src/tree/param.h"  // TrainParam
-#include "../../filesystem.h"            // dmlc::TemporaryDirectory
 #include "../../helpers.h"
 
 namespace xgboost::tree {
-void VerifySampling(size_t page_size,
-                    float subsample,
-                    int sampling_method,
-                    bool fixed_size_sampling = true,
-                    bool check_sum = true) {
+void VerifySampling(size_t page_size, float subsample, int sampling_method,
+                    bool fixed_size_sampling = true, bool check_sum = true) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 1;
-  size_t sample_rows = kRows * subsample;
+  bst_idx_t sample_rows = kRows * subsample;
+  bst_idx_t n_batches = fixed_size_sampling ? 1 : 4;
 
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(
-      kRows, kCols, kRows / (page_size == 0 ? kRows : page_size), tmpdir.path + "/cache"));
+  auto dmat = RandomDataGenerator{kRows, kCols, 0.0f}.Batches(n_batches).GenerateSparsePageDMatrix(
+      "temp", true);
   auto gpair = GenerateRandomGradients(kRows);
   GradientPair sum_gpair{};
   for (const auto& gp : gpair.ConstHostVector()) {
@@ -78,14 +74,12 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
   constexpr size_t kRows = 2048;
   constexpr size_t kCols = 1;
   constexpr float kSubsample = 1.0f;
-  constexpr size_t kPageSize = 1024;
 
   // Create a DMatrix with multiple batches.
-  dmlc::TemporaryDirectory tmpdir;
-  std::unique_ptr<DMatrix> dmat(
-      CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
   auto gpair = GenerateRandomGradients(kRows);
 
-  Context ctx{MakeCUDACtx(0)};
+  auto ctx = MakeCUDACtx(0);
   gpair.SetDevice(ctx.Device());
 
   auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc
index 11bdbd859b1c..5ab0c599ea6a 100644
--- a/tests/cpp/tree/hist/test_histogram.cc
+++ b/tests/cpp/tree/hist/test_histogram.cc
@@ -406,7 +406,8 @@ namespace {
 void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, bool is_approx,
                                  bool force_read_by_column) {
   size_t constexpr kEntries = 1 << 16;
-  auto m = CreateSparsePageDMatrix(kEntries, "cache");
+  auto m =
+      RandomDataGenerator{kEntries / 8, 8, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
 
   std::vector<float> hess(m->Info().num_row_, 1.0);
   if (is_approx) {
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 570ebe76c3da..61f7647579cf 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -17,12 +17,11 @@
 #include "../../../src/common/random.h"  // for GlobalRandom
 #include "../../../src/tree/param.h"     // for TrainParam
 #include "../collective/test_worker.h"   // for BaseMGPUTest
-#include "../filesystem.h"               // dmlc::TemporaryDirectory
 #include "../helpers.h"
 
 namespace xgboost::tree {
 namespace {
-void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat, bool is_ext,
+void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
                 RegTree* tree, HostDeviceVector<float>* preds, float subsample,
                 const std::string& sampling_method, bst_bin_t max_bin) {
   Args args{
@@ -45,7 +44,7 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix
   hist_maker->Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
                      {tree});
   auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
-  if (subsample < 1.0 && is_ext) {
+  if (subsample < 1.0 && !dmat->SingleColBlock()) {
     ASSERT_FALSE(hist_maker->UpdatePredictionCache(dmat, cache));
   } else {
     ASSERT_TRUE(hist_maker->UpdatePredictionCache(dmat, cache));
@@ -58,22 +57,23 @@ TEST(GpuHist, UniformSampling) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
   constexpr float kSubsample = 0.9999;
   common::GlobalRandom().seed(1994);
+  auto ctx = MakeCUDACtx(0);
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}.GenerateDMatrix(true);
+  ASSERT_TRUE(p_fmat->SingleColBlock());
 
-  linalg::Matrix<GradientPair> gpair({kRows}, Context{}.MakeCUDA().Device());
+  linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
 
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
-  HostDeviceVector<float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  Context ctx(MakeCUDACtx(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
+  HostDeviceVector<float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
 
   // Build another tree using sampling.
   RegTree tree_sampling;
-  HostDeviceVector<float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample, "uniform",
+  HostDeviceVector<float> preds_sampling(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample, "uniform",
              kRows);
 
   // Make sure the predictions are the same.
@@ -89,23 +89,23 @@ TEST(GpuHist, GradientBasedSampling) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
   constexpr float kSubsample = 0.9999;
   common::GlobalRandom().seed(1994);
+  auto ctx = MakeCUDACtx(0);
 
   // Create an in-memory DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
+  auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}.GenerateDMatrix(true);
 
-  linalg::Matrix<GradientPair> gpair({kRows}, MakeCUDACtx(0).Device());
+  linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
 
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
-  HostDeviceVector<float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  Context ctx(MakeCUDACtx(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
+  HostDeviceVector<float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
 
   // Build another tree using sampling.
   RegTree tree_sampling;
-  HostDeviceVector<float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample,
+  HostDeviceVector<float> preds_sampling(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample,
              "gradient_based", kRows);
 
   // Make sure the predictions are the same.
@@ -119,29 +119,29 @@ TEST(GpuHist, ExternalMemory) {
   constexpr size_t kRows = 4096;
   constexpr size_t kCols = 2;
-  constexpr size_t kPageSize = 1024;
-
-  dmlc::TemporaryDirectory tmpdir;
 
   // Create a DMatrix with multiple batches.
-  std::unique_ptr<DMatrix> dmat_ext(
-      CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
+  auto p_fmat_ext =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
+  ASSERT_FALSE(p_fmat_ext->SingleColBlock());
 
   // Create a single batch DMatrix.
-  std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache"));
+  auto p_fmat =
+      RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
+  ASSERT_TRUE(p_fmat->SingleColBlock());
 
-  Context ctx(MakeCUDACtx(0));
+  auto ctx = MakeCUDACtx(0);
   linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
 
   // Build a tree using the in-memory DMatrix.
   RegTree tree;
-  HostDeviceVector<float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
+  HostDeviceVector<float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
+
   // Build another tree using multiple ELLPACK pages.
   RegTree tree_ext;
-  HostDeviceVector<float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, dmat_ext.get(), true, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
+  HostDeviceVector<float> preds_ext(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, 1.0, "uniform", kRows);
 
   // Make sure the predictions are the same.
   auto preds_h = preds.ConstHostVector();
@@ -157,20 +157,21 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   const std::string kSamplingMethod = "gradient_based";
   common::GlobalRandom().seed(0);
 
-  dmlc::TemporaryDirectory tmpdir;
-  Context ctx(MakeCUDACtx(0));
+  auto ctx = MakeCUDACtx(0);
 
   // Create a single batch DMatrix.
   auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}
                     .Device(ctx.Device())
                     .Batches(1)
                     .GenerateSparsePageDMatrix("temp", true);
+  ASSERT_TRUE(p_fmat->SingleColBlock());
 
   // Create a DMatrix with multiple batches.
   auto p_fmat_ext = RandomDataGenerator{kRows, kCols, 0.0f}
                         .Device(ctx.Device())
                         .Batches(4)
                         .GenerateSparsePageDMatrix("temp", true);
+  ASSERT_FALSE(p_fmat_ext->SingleColBlock());
 
   linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
   gpair.Data()->Copy(GenerateRandomGradients(kRows));
@@ -179,26 +180,25 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
   auto rng = common::GlobalRandom();
 
   RegTree tree;
-  HostDeviceVector<float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, p_fmat.get(), true, &tree, &preds, kSubsample, kSamplingMethod, kRows);
+  HostDeviceVector<float> preds(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, kSubsample, kSamplingMethod, kRows);
 
   // Build another tree using multiple ELLPACK pages.
   common::GlobalRandom() = rng;
   RegTree tree_ext;
-  HostDeviceVector<float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
-  UpdateTree(&ctx, &gpair, p_fmat_ext.get(), true, &tree_ext, &preds_ext, kSubsample,
-             kSamplingMethod, kRows);
+  HostDeviceVector<float> preds_ext(kRows, 0.0, ctx.Device());
+  UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, kSubsample, kSamplingMethod,
+             kRows);
 
-  // Make sure the predictions are the same.
-  auto preds_h = preds.ConstHostVector();
-  auto preds_ext_h = preds_ext.ConstHostVector();
-  for (size_t i = 0; i < kRows; i++) {
-    ASSERT_NEAR(preds_h[i], preds_ext_h[i], 1e-3);
-  }
+  Json jtree{Object{}};
+  Json jtree_ext{Object{}};
+  tree.SaveModel(&jtree);
+  tree_ext.SaveModel(&jtree_ext);
+  ASSERT_EQ(jtree, jtree_ext);
 }
 
 TEST(GpuHist, ConfigIO) {
-  Context ctx(MakeCUDACtx(0));
+  auto ctx = MakeCUDACtx(0);
   ObjInfo task{ObjInfo::kRegression};
   std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
   updater->Configure(Args{});