Skip to content

Commit

Permalink
[EM] Enable access to the number of batches. (#10691)
Browse files Browse the repository at this point in the history
- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.
  • Loading branch information
trivialfis authored Aug 16, 2024
1 parent 033a666 commit 8d7fe26
Show file tree
Hide file tree
Showing 26 changed files with 168 additions and 351 deletions.
7 changes: 5 additions & 2 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,12 @@ class DMatrix {
[[nodiscard]] bool PageExists() const;

/**
* @return Whether the data columns single column block.
* @return Whether the contains a single batch.
*
* The naming is legacy.
*/
[[nodiscard]] virtual bool SingleColBlock() const = 0;
[[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
[[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }

virtual ~DMatrix();

Expand Down
26 changes: 8 additions & 18 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -486,24 +486,20 @@ class TypedDiscard : public thrust::discard_iterator<T> {
} // namespace detail

template <typename T>
using TypedDiscard =
std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
detail::TypedDiscard<T>>;
using TypedDiscard = std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
detail::TypedDiscard<T>>;

template <typename VectorT, typename T = typename VectorT::value_type,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(
VectorT &vec,
IndexT offset = 0,
IndexT size = std::numeric_limits<size_t>::max()) {
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(VectorT &vec, IndexT offset = 0,
IndexT size = std::numeric_limits<size_t>::max()) {
size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
CHECK_LE(offset + size, vec.size());
return {vec.data().get() + offset, size};
return {thrust::raw_pointer_cast(vec.data()) + offset, size};
}

template <typename T>
xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
size_t offset, size_t size) {
xgboost::common::Span<T> ToSpan(thrust::device_vector<T> &vec, size_t offset, size_t size) {
return ToSpan(vec, offset, size);
}

Expand Down Expand Up @@ -874,13 +870,7 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT

// Changing this has effect on prediction return, where we need to pass the pointer to
// third-party libraries like cuPy
inline CUDAStreamView DefaultStream() {
#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
return CUDAStreamView{cudaStreamPerThread};
#else
return CUDAStreamView{cudaStreamLegacy};
#endif
}
inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }

class CUDAStream {
cudaStream_t stream_;
Expand Down
2 changes: 2 additions & 0 deletions src/data/extmem_quantile_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
ext_info.SetInfo(ctx, &this->info_);

this->n_batches_ = ext_info.n_batches;

/**
* Generate quantiles
*/
Expand Down
3 changes: 2 additions & 1 deletion src/data/extmem_quantile_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
std::string cache, bst_bin_t max_bin, bool on_host);
~ExtMemQuantileDMatrix() override;

[[nodiscard]] bool SingleColBlock() const override { return false; }
[[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }

private:
void InitFromCPU(
Expand Down Expand Up @@ -63,6 +63,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
std::string cache_prefix_;
bool on_host_;
BatchParam batch_;
bst_idx_t n_batches_{0};

using EllpackDiskPtr = std::shared_ptr<ExtEllpackPageSource>;
using EllpackHostPtr = std::shared_ptr<ExtEllpackPageHostSource>;
Expand Down
2 changes: 0 additions & 2 deletions src/data/iterative_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ class IterativeDMatrix : public QuantileDMatrix {

BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;

bool SingleColBlock() const override { return true; }
};
} // namespace data
} // namespace xgboost
Expand Down
1 change: 0 additions & 1 deletion src/data/proxy_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ class DMatrixProxy : public DMatrix {
MetaInfo const& Info() const override { return info_; }
Context const* Ctx() const override { return &ctx_; }

bool SingleColBlock() const override { return false; }
bool EllpackExists() const override { return false; }
bool GHistIndexExists() const override { return false; }
bool SparsePageExists() const override { return false; }
Expand Down
1 change: 0 additions & 1 deletion src/data/simple_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class SimpleDMatrix : public DMatrix {
const MetaInfo& Info() const override;
Context const* Ctx() const override { return &fmat_ctx_; }

bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
DMatrix* SliceCol(int num_slices, int slice_id) override;

Expand Down
3 changes: 1 addition & 2 deletions src/data/sparse_page_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ class SparsePageDMatrix : public DMatrix {
[[nodiscard]] MetaInfo &Info() override;
[[nodiscard]] const MetaInfo &Info() const override;
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
// The only DMatrix implementation that returns false.
[[nodiscard]] bool SingleColBlock() const override { return false; }
[[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
DMatrix *Slice(common::Span<std::int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;
Expand Down
8 changes: 6 additions & 2 deletions src/data/sparse_page_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
*/
#include "sparse_page_source.h"

#include <filesystem> // for exists
#include <string> // for string
#include <cstdio> // for remove
#include <filesystem> // for exists
#include <numeric> // for partial_sum
#include <string> // for string

namespace xgboost::data {
void Cache::Commit() {
Expand All @@ -27,4 +27,8 @@ void TryDeleteCacheFile(const std::string& file) {
<< "; you may want to remove it manually";
}
}

#if !defined(XGBOOST_USE_CUDA)
void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
#endif
} // namespace xgboost::data
10 changes: 10 additions & 0 deletions src/data/sparse_page_source.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,14 @@ void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
cuda_impl::Dispatch(proxy,
[&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
}

void InitNewThread::operator()() const {
*GlobalConfigThreadLocalStore::Get() = config;
// For CUDA 12.2, we need to force initialize the CUDA context by synchronizing the
// stream when creating a new thread in the thread pool. While for CUDA 11.8, this
// action might cause an insufficient driver version error for some reason. Lastly, it
// should work with CUDA 12.5 without any action being taken.

// dh::DefaultStream().Sync();
}
} // namespace xgboost::data
11 changes: 7 additions & 4 deletions src/data/sparse_page_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,12 @@ class DefaultFormatPolicy {
}
};

struct InitNewThread {
GlobalConfiguration config = *GlobalConfigThreadLocalStore::Get();

void operator()() const;
};

/**
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
*
Expand Down Expand Up @@ -330,10 +336,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
std::shared_ptr<Cache> cache)
: workers_{std::max(2, std::min(nthreads, 16)),
[config = *GlobalConfigThreadLocalStore::Get()] {
*GlobalConfigThreadLocalStore::Get() = config;
}},
: workers_{std::max(2, std::min(nthreads, 16)), InitNewThread{}},
missing_{missing},
nthreads_{nthreads},
n_features_{n_features},
Expand Down
15 changes: 8 additions & 7 deletions tests/cpp/data/test_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,27 @@ TEST(SparsePage, PushCSC) {
}

TEST(SparsePage, PushCSCAfterTranspose) {
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
bst_idx_t constexpr kRows = 1024, kCols = 21;

auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
const int ncols = dmat->Info().num_col_;
SparsePage page; // Consolidated sparse page
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
SparsePage page; // Consolidated sparse page
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
// Transpose each batch and push
SparsePage tmp = batch.GetTranspose(ncols, AllThreadsForTest());
page.PushCSC(tmp);
}

// Make sure that the final sparse page has the right number of entries
ASSERT_EQ(kEntries, page.data.Size());
ASSERT_EQ(kRows * kCols, page.data.Size());

page.SortRows(AllThreadsForTest());
auto v = page.GetView();
for (size_t i = 0; i < v.Size(); ++i) {
auto column = v[i];
for (size_t j = 1; j < column.size(); ++j) {
ASSERT_GE(column[j].fvalue, column[j-1].fvalue);
ASSERT_GE(column[j].fvalue, column[j - 1].fvalue);
}
}
}
Expand Down
16 changes: 6 additions & 10 deletions tests/cpp/data/test_ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,11 @@ struct ReadRowFunction {
TEST(EllpackPage, Copy) {
constexpr size_t kRows = 1024;
constexpr size_t kCols = 16;
constexpr size_t kPageSize = 1024;

// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
Context ctx{MakeCUDACtx(0)};
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();

Expand Down Expand Up @@ -187,14 +185,12 @@ TEST(EllpackPage, Copy) {
TEST(EllpackPage, Compact) {
constexpr size_t kRows = 16;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1;
constexpr size_t kCompactedRows = 8;

// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix> dmat(
CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
Context ctx{MakeCUDACtx(0)};
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();

Expand Down
18 changes: 8 additions & 10 deletions tests/cpp/data/test_sparse_page_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ TEST(SparsePageDMatrix, MetaInfo) {
}

TEST(SparsePageDMatrix, RowAccess) {
std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(24);
auto dmat = RandomDataGenerator{12, 6, 0.8f}.Batches(2).GenerateSparsePageDMatrix("temp", false);

// Test the data read into the first row
auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
auto page = batch.GetView();
auto first_row = page[0];
ASSERT_EQ(first_row.size(), 3ul);
EXPECT_EQ(first_row[2].index, 2u);
EXPECT_NEAR(first_row[2].fvalue, 0.986566, 1e-4);
ASSERT_EQ(first_row.size(), 1ul);
EXPECT_EQ(first_row[0].index, 5u);
EXPECT_NEAR(first_row[0].fvalue, 0.1805125, 1e-4);
}

TEST(SparsePageDMatrix, ColAccess) {
Expand Down Expand Up @@ -268,11 +268,10 @@ TEST(SparsePageDMatrix, ColAccess) {
}

TEST(SparsePageDMatrix, ThreadSafetyException) {
size_t constexpr kEntriesPerCol = 3;
size_t constexpr kEntries = 64 * kEntriesPerCol * 2;
Context ctx;

std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(kEntries);
auto dmat =
RandomDataGenerator{4096, 12, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);

int threads = 1000;

Expand Down Expand Up @@ -304,10 +303,9 @@ TEST(SparsePageDMatrix, ThreadSafetyException) {

// Multi-batches access
TEST(SparsePageDMatrix, ColAccessBatches) {
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
// Create multiple sparse pages
std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
auto dmat =
RandomDataGenerator{1024, 32, 0.4f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
Context ctx;
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {
Expand Down
39 changes: 16 additions & 23 deletions tests/cpp/data/test_sparse_page_dmatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,10 @@ TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
}

TEST(SparsePageDMatrix, MultipleEllpackPages) {
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, filename);
auto dmat = RandomDataGenerator{1024, 2, 0.5f}.Batches(2).GenerateSparsePageDMatrix("temp", true);

// Loop over the batches and count the records
std::int64_t batch_count = 0;
Expand All @@ -135,15 +132,13 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
EXPECT_EQ(row_count, dmat->Info().num_row_);

auto path =
data::MakeId(filename,
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
".ellpack.page";
data::MakeId("tmep", dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".ellpack.page";
}

TEST(SparsePageDMatrix, RetainEllpackPage) {
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
auto m = CreateSparsePageDMatrix(10000);
auto m = RandomDataGenerator{2048, 4, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);

auto batches = m->GetBatches<EllpackPage>(&ctx, param);
auto begin = batches.begin();
Expand Down Expand Up @@ -278,20 +273,19 @@ struct ReadRowFunction {
};

TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
constexpr size_t kRows = 6;
constexpr size_t kRows = 16;
constexpr size_t kCols = 2;
constexpr int kMaxBins = 256;
constexpr size_t kPageSize = 1;

// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);

// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
auto dmat_ext =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);

Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
EXPECT_EQ(impl->base_rowid, 0);
Expand Down Expand Up @@ -325,17 +319,16 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
constexpr size_t kRows = 1024;
constexpr size_t kCols = 16;
constexpr int kMaxBins = 256;
constexpr size_t kPageSize = 4096;

// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);

// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
auto dmat_ext =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);

Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};

size_t current_row = 0;
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/gbm/test_gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ TEST(GBTree, InplacePredictionError) {
p_fmat = rng.GenerateQuantileDMatrix(true);
} else {
#if defined(XGBOOST_USE_CUDA)
p_fmat = rng.GenerateDeviceDMatrix(true);
p_fmat = rng.Device(ctx->Device()).GenerateQuantileDMatrix(true);
#else
CHECK(p_fmat);
#endif // defined(XGBOOST_USE_CUDA)
Expand Down
Loading

0 comments on commit 8d7fe26

Please sign in to comment.