[EM] Merge GPU partitioning with histogram building. (#10766)
- Stop concatenating pages if there's no subsampling.
- Use a single iteration for histogram build and partitioning.
trivialfis authored Aug 30, 2024
1 parent 98ac153 commit e1a2c1b
Showing 7 changed files with 118 additions and 159 deletions.
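
The two bullet points above summarize the change: instead of one pass over the ELLPACK pages to update row positions and a second pass to build the child histograms, the updater now does both in a single sweep per batch. The NumPy sketch below is an illustration of that idea only, not the XGBoost implementation; every name in it (split_and_build, batches, gradients) is hypothetical.

import numpy as np

rng = np.random.default_rng(0)
n_bins, n_features = 16, 4

def split_and_build(batches, gradients, split_feature, split_bin):
    # Partition the rows of each batch and accumulate the child histograms in
    # the same iteration over the external-memory batches.
    left_hist = np.zeros((n_features, n_bins))
    right_hist = np.zeros((n_features, n_bins))
    partitions = []
    for X_binned, g in zip(batches, gradients):
        go_left = X_binned[:, split_feature] <= split_bin  # position update
        partitions.append(go_left)
        for hist, mask in ((left_hist, go_left), (right_hist, ~go_left)):
            for f in range(n_features):                    # histogram build
                np.add.at(hist[f], X_binned[mask, f], g[mask])
    return partitions, left_hist, right_hist

# Two "pages" of pre-binned features plus per-row gradients.
batches = [rng.integers(0, n_bins, size=(64, n_features)) for _ in range(2)]
gradients = [rng.normal(size=64) for _ in range(2)]
parts, lh, rh = split_and_build(batches, gradients, split_feature=0, split_bin=7)
total = sum(g.sum() for g in gradients)
assert np.allclose(lh.sum(axis=1) + rh.sum(axis=1), total)  # children sum to parent
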
10 changes: 6 additions & 4 deletions python-package/xgboost/testing/updater.py
@@ -222,10 +222,12 @@ def check_extmem_qdm(
Xy = xgb.QuantileDMatrix(X, y, weight=w)
booster = xgb.train({"device": device}, Xy, num_boost_round=8)

cut_it = Xy_it.get_quantile_cut()
cut = Xy.get_quantile_cut()
np.testing.assert_allclose(cut_it[0], cut[0])
np.testing.assert_allclose(cut_it[1], cut[1])
if device == "cpu":
# Getting cuts from ellpack without CPU-GPU interpolation is not yet supported.
cut_it = Xy_it.get_quantile_cut()
cut = Xy.get_quantile_cut()
np.testing.assert_allclose(cut_it[0], cut[0])
np.testing.assert_allclose(cut_it[1], cut[1])

predt_it = booster_it.predict(Xy_it)
predt = booster.predict(Xy)
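
For context, booster_it and Xy_it in this helper come from an iterator-backed QuantileDMatrix built from multiple batches, while Xy is built from the same data in one piece. The snippet below is a generic, hypothetical illustration of how such an iterator-backed matrix is constructed with the public xgboost.DataIter API; it is not the actual setup code of check_extmem_qdm.

import numpy as np
import xgboost as xgb

class BatchIter(xgb.DataIter):
    # Streams pre-generated in-memory batches to XGBoost one at a time.
    def __init__(self, batches):
        self._batches = batches
        self._i = 0
        super().__init__()

    def next(self, input_data):
        if self._i == len(self._batches):
            return 0  # no batches left
        X, y = self._batches[self._i]
        input_data(data=X, label=y)
        self._i += 1
        return 1

    def reset(self):
        self._i = 0

rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 8)), rng.normal(size=256)) for _ in range(4)]
Xy_it = xgb.QuantileDMatrix(BatchIter(batches))  # quantized from streamed batches
booster_it = xgb.train({"device": "cpu"}, Xy_it, num_boost_round=8)
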
22 changes: 2 additions & 20 deletions src/tree/gpu_hist/gradient_based_sampler.cu
@@ -158,28 +158,10 @@ GradientBasedSample NoSampling::Sample(Context const*, common::Span<GradientPair
ExternalMemoryNoSampling::ExternalMemoryNoSampling(BatchParam batch_param)
: batch_param_{std::move(batch_param)} {}

GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
GradientBasedSample ExternalMemoryNoSampling::Sample(Context const*,
common::Span<GradientPair> gpair,
DMatrix* p_fmat) {
std::shared_ptr<EllpackPage> new_page;
if (!page_concatenated_) {
// Concatenate all the external memory ELLPACK pages into a single in-memory page.
bst_idx_t offset = 0;
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx, batch_param_)) {
auto page = batch.Impl();
if (!new_page) {
new_page = std::make_shared<EllpackPage>();
*new_page->Impl() = EllpackPageImpl(ctx, page->CutsShared(), page->is_dense,
page->row_stride, p_fmat->Info().num_row_);
}
bst_idx_t num_elements = new_page->Impl()->Copy(ctx, page, offset);
offset += num_elements;
}
page_concatenated_ = true;
this->p_fmat_new_ =
std::make_unique<data::IterativeDMatrix>(new_page, p_fmat->Info(), batch_param_);
}
return {this->p_fmat_new_.get(), gpair};
return {p_fmat, gpair};
}

UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
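
The deletion above is the first bullet of the commit message: when no subsampling is requested, ExternalMemoryNoSampling no longer copies every ELLPACK page into one in-memory page; it simply hands the original DMatrix back and lets the updater iterate over the pages itself. A plain-Python sketch of that before/after, with illustrative names only:

import numpy as np

def sample_no_subsampling(pages, gradients, concatenate=False):
    # Return the training data seen by the tree updater when subsample == 1.
    if concatenate:
        # Old behaviour: materialize one big in-memory copy of all pages.
        return np.concatenate(pages, axis=0), gradients
    # New behaviour: no copy; the updater streams over the original pages.
    return pages, gradients

pages = [np.ones((4, 3)), np.zeros((2, 3))]
grads = np.arange(6, dtype=float)
assert sample_no_subsampling(pages, grads)[0] is pages  # original pages, zero copies
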
2 changes: 0 additions & 2 deletions src/tree/gpu_hist/gradient_based_sampler.cuh
@@ -46,8 +46,6 @@ class ExternalMemoryNoSampling : public SamplingStrategy {

private:
BatchParam batch_param_;
std::unique_ptr<DMatrix> p_fmat_new_{nullptr};
bool page_concatenated_{false};
};

/*! \brief Uniform sampling in in-memory mode. */
4 changes: 4 additions & 0 deletions src/tree/gpu_hist/row_partitioner.cu
@@ -22,6 +22,10 @@ void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t ba
NodePositionInfo{Segment{0, static_cast<cuda_impl::RowIndexT>(n_samples)}});

thrust::sequence(ctx->CUDACtx()->CTP(), ridx_.data(), ridx_.data() + ridx_.size(), base_rowid);

// Pre-allocate some host memory
this->pinned_.GetSpan<std::int32_t>(1 << 11);
this->pinned2_.GetSpan<std::int32_t>(1 << 13);
}

RowPartitioner::~RowPartitioner() = default;
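
The two GetSpan calls added above reserve pinned host buffers up front, once per Reset, so that later partition updates can reuse that staging space instead of allocating inside the per-node loop. Below is a rough Python analogue of the allocate-once, slice-later pattern; it is illustrative only and not the pinned-memory implementation.

import numpy as np

class PinnedStaging:
    # Allocate-once, slice-later buffer; a stand-in for the pinned-memory pool.
    def __init__(self):
        self._buf = np.empty(0, dtype=np.int32)

    def get_span(self, n):
        if self._buf.size < n:                    # grow only when needed
            self._buf = np.empty(n, dtype=np.int32)
        return self._buf[:n]                      # cheap view, no new allocation

staging = PinnedStaging()
staging.get_span(1 << 13)          # warm up once, as in the Reset call above
for _ in range(100):
    span = staging.get_span(256)   # later requests reuse the same memory
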
197 changes: 93 additions & 104 deletions src/tree/updater_gpu_hist.cu
@@ -200,6 +200,7 @@ struct GPUHistMakerDevice {

// Reset values for each update iteration
[[nodiscard]] DMatrix* Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* p_fmat) {
this->monitor.Start(__func__);
auto const& info = p_fmat->Info();
this->column_sampler_->Init(ctx_, p_fmat->Info().num_col_, info.feature_weights.HostVector(),
param.colsample_bynode, param.colsample_bylevel,
@@ -252,7 +253,7 @@
this->histogram_.Reset(ctx_, this->hist_param_->MaxCachedHistNodes(ctx_->Device()),
feature_groups->DeviceAccessor(ctx_->Device()), cuts_->TotalBins(),
false);

this->monitor.Stop(__func__);
return p_fmat;
}

@@ -346,6 +347,38 @@
monitor.Stop(__func__);
}

void ReduceHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
std::vector<bst_node_t> const& build_nidx,
std::vector<bst_node_t> const& subtraction_nidx) {
if (candidates.empty()) {
return;
}
this->monitor.Start(__func__);

// Reduce all in one go
// This gives much better latency in a distributed setting when processing a large batch
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), build_nidx.at(0), build_nidx.size());
// Perform subtraction for sibling nodes
auto need_build = this->histogram_.SubtractHist(candidates, build_nidx, subtraction_nidx);
if (need_build.empty()) {
this->monitor.Stop(__func__);
return;
}

// Build the nodes that cannot obtain the histogram using subtraction. This is the slow path.
std::int32_t k = 0;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
for (auto nidx : need_build) {
this->BuildHist(page, k, nidx);
}
++k;
}
for (auto nidx : need_build) {
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), nidx, 1);
}
this->monitor.Stop(__func__);
}

void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
std::vector<NodeSplitData> const& split_data,
std::vector<bst_node_t> const& nidx,
@@ -434,56 +467,74 @@
}
};

void UpdatePosition(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
RegTree* p_tree) {
if (candidates.empty()) {
// Update position and build histogram.
void PartitionAndBuildHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& expand_set,
std::vector<GPUExpandEntry> const& candidates, RegTree const* p_tree) {
if (expand_set.empty()) {
return;
}

monitor.Start(__func__);
CHECK_LE(candidates.size(), expand_set.size());

auto [nidx, left_nidx, right_nidx, split_data] = this->CreatePartitionNodes(p_tree, candidates);
// Update all the nodes if working with external memory; this saves us from working
// with the finalize position call, which adds an additional iteration and requires
// special handling for row index.
bool const is_single_block = p_fmat->SingleColBlock();

for (size_t i = 0; i < candidates.size(); i++) {
auto const& e = candidates[i];
RegTree::Node const& split_node = (*p_tree)[e.nid];
auto split_type = p_tree->NodeSplitType(e.nid);
nidx[i] = e.nid;
left_nidx[i] = split_node.LeftChild();
right_nidx[i] = split_node.RightChild();
split_data[i] = NodeSplitData{split_node, split_type, evaluator_.GetDeviceNodeCats(e.nid)};
// Prepare for update partition
auto [nidx, left_nidx, right_nidx, split_data] =
this->CreatePartitionNodes(p_tree, is_single_block ? candidates : expand_set);

CHECK_EQ(split_type == FeatureType::kCategorical, e.split.is_cat);
}
// Prepare for build hist
std::vector<bst_node_t> build_nidx(candidates.size());
std::vector<bst_node_t> subtraction_nidx(candidates.size());
auto prefetch_copy =
AssignNodes(p_tree, this->quantiser.get(), candidates, build_nidx, subtraction_nidx);

CHECK_EQ(p_fmat->NumBatches(), 1);
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
this->histogram_.AllocateHistograms(ctx_, build_nidx, subtraction_nidx);

monitor.Start("Partition-BuildHist");

std::int32_t k{0};
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
auto d_matrix = page.Impl()->GetDeviceAccessor(ctx_->Device());
auto go_left = GoLeftOp{d_matrix};

// Partition histogram.
monitor.Start("UpdatePositionBatch");
if (p_fmat->Info().IsColumnSplit()) {
UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
monitor.Stop(__func__);
return;
} else {
partitioners_.at(k)->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data,
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
const NodeSplitData& data) { return go_left(ridx, data); });
}
auto go_left = GoLeftOp{d_matrix};
partitioners_.front()->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data,
[=] __device__(cuda_impl::RowIndexT ridx, int /*nidx_in_batch*/,
const NodeSplitData& data) { return go_left(ridx, data); });
monitor.Stop("UpdatePositionBatch");

for (auto nidx : build_nidx) {
this->BuildHist(page, k, nidx);
}

++k;
}

monitor.Stop("Partition-BuildHist");

this->ReduceHist(p_fmat, candidates, build_nidx, subtraction_nidx);

monitor.Stop(__func__);
}

// After tree update is finished, update the position of all training
// instances to their final leaf. This information is used later to update the
// prediction cache
void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task, bst_idx_t n_samples,
void FinalisePosition(DMatrix* p_fmat, RegTree const* p_tree, ObjInfo task,
HostDeviceVector<bst_node_t>* p_out_position) {
if (!p_fmat->SingleColBlock() && task.UpdateTreeLeaf()) {
LOG(FATAL) << "Current objective function can not be used with external memory.";
}
if (p_fmat->Info().num_row_ != n_samples) {
if (static_cast<std::size_t>(p_fmat->NumBatches() + 1) != this->batch_ptr_.size()) {
// External memory with concatenation. Not supported.
p_out_position->Resize(0);
positions_.clear();
@@ -577,60 +628,6 @@
return true;
}

/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& candidates,
const RegTree& tree) {
if (candidates.empty()) {
return;
}
this->monitor.Start(__func__);
// Some nodes we will manually compute histograms
// others we will do by subtraction
std::vector<bst_node_t> hist_nidx(candidates.size());
std::vector<bst_node_t> subtraction_nidx(candidates.size());
auto prefetch_copy =
AssignNodes(&tree, this->quantiser.get(), candidates, hist_nidx, subtraction_nidx);

std::vector<int> all_new = hist_nidx;
all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
// Allocate the histograms
// Guaranteed contiguous memory
histogram_.AllocateHistograms(ctx_, all_new);

std::int32_t k = 0;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
for (auto nidx : hist_nidx) {
this->BuildHist(page, k, nidx);
}
++k;
}

// Reduce all in one go
// This gives much better latency in a distributed setting
// when processing a large batch
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), hist_nidx.at(0), hist_nidx.size());

for (size_t i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = hist_nidx.at(i);
auto subtraction_trick_nidx = subtraction_nidx.at(i);
auto parent_nidx = candidates.at(i).nid;

if (!this->histogram_.SubtractionTrick(parent_nidx, build_hist_nidx,
subtraction_trick_nidx)) {
// Calculate other histogram manually
std::int32_t k = 0;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
this->BuildHist(page, k, subtraction_trick_nidx);
++k;
}
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), subtraction_trick_nidx, 1);
}
}
this->monitor.Stop(__func__);
}

void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree;

@@ -681,8 +678,9 @@
}

GPUExpandEntry InitRoot(DMatrix* p_fmat, RegTree* p_tree) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
this->monitor.Start(__func__);

constexpr bst_node_t kRootNIdx = RegTree::kRoot;
auto quantiser = *this->quantiser;
auto gpair_it = dh::MakeTransformIterator<GradientPairInt64>(
dh::tbegin(gpair),
@@ -697,6 +695,7 @@

histogram_.AllocateHistograms(ctx_, {kRootNIdx});
std::int32_t k = 0;
CHECK_EQ(p_fmat->NumBatches(), this->partitioners_.size());
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
this->BuildHist(page, k, kRootNIdx);
++k;
Expand All @@ -712,25 +711,18 @@ struct GPUHistMakerDevice {

// Generate first split
auto root_entry = this->EvaluateRootSplit(p_fmat, root_sum_quantised);

this->monitor.Stop(__func__);
return root_entry;
}

void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
bool const is_single_block = p_fmat->SingleColBlock();
bst_idx_t const n_samples = p_fmat->Info().num_row_;

auto& tree = *p_tree;
// Process maximum 32 nodes at a time
Driver<GPUExpandEntry> driver(param, 32);

monitor.Start("Reset");
p_fmat = this->Reset(gpair_all, p_fmat);
monitor.Stop("Reset");

monitor.Start("InitRoot");
driver.Push({this->InitRoot(p_fmat, p_tree)});
monitor.Stop("InitRoot");

// The set of leaves that can be expanded asynchronously
auto expand_set = driver.Pop();
@@ -740,20 +732,17 @@
}
// Get the candidates we are allowed to expand further
// e.g. We do not bother further processing nodes whose children are beyond max depth
std::vector<GPUExpandEntry> filtered_expand_set;
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
[&](const auto& e) { return driver.IsChildValid(e); });
std::vector<GPUExpandEntry> valid_candidates;
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(valid_candidates),
[&](auto const& e) { return driver.IsChildValid(e); });

// Allocate child nodes.
auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry{});
// Update all the nodes if working with external memory, this saves us from working
// with the finalize position call, which adds an additional iteration and requires
// special handling for row index.
this->UpdatePosition(p_fmat, is_single_block ? filtered_expand_set : expand_set, p_tree);
pinned.GetSpan<GPUExpandEntry>(valid_candidates.size() * 2, GPUExpandEntry());

this->BuildHistLeftRight(p_fmat, filtered_expand_set, tree);
this->PartitionAndBuildHist(p_fmat, expand_set, valid_candidates, p_tree);

this->EvaluateSplits(p_fmat, filtered_expand_set, *p_tree, new_candidates);
this->EvaluateSplits(p_fmat, valid_candidates, *p_tree, new_candidates);
dh::DefaultStream().Sync();

driver.Push(new_candidates.begin(), new_candidates.end());
@@ -764,10 +753,10 @@
// be splittable before evaluation but invalid after evaluation as we have more
// restrictions like min loss change after evaluation. Therefore, the check condition
// is greater than or equal to.
if (is_single_block) {
if (p_fmat->SingleColBlock()) {
CHECK_GE(p_tree->NumNodes(), this->partitioners_.front()->GetNumNodes());
}
this->FinalisePosition(p_fmat, p_tree, *task, n_samples, p_out_position);
this->FinalisePosition(p_fmat, p_tree, *task, p_out_position);
}
};

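
ReduceHist above relies on the usual histogram subtraction trick: build the histogram for one child of each split, all-reduce, then derive the sibling as parent minus built child, falling back to an explicit build only when subtraction is not possible. Below is a small NumPy check of the identity behind it; it is illustrative only, not the CUDA kernels.

import numpy as np

rng = np.random.default_rng(0)
n_rows, n_bins = 1000, 32

binned = rng.integers(0, n_bins, size=n_rows)   # one pre-binned feature
grad = rng.normal(size=n_rows)                  # gradient per row
go_left = rng.random(n_rows) < 0.5              # partition from some split

def build_hist(idx, values):
    hist = np.zeros(n_bins)
    np.add.at(hist, idx, values)
    return hist

parent = build_hist(binned, grad)
left = build_hist(binned[go_left], grad[go_left])           # built explicitly
right = parent - left                                       # derived by subtraction
assert np.allclose(right, build_hist(binned[~go_left], grad[~go_left]))
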