Skip to content

Commit

Permalink
[EM] Avoid synchronous calls and unnecessary ATS access. (#10811)
Browse files Browse the repository at this point in the history
- Pass context into various functions.
- Factor out some CUDA algorithms.
- Use ATS only for update position.
  • Loading branch information
trivialfis authored Sep 10, 2024
1 parent ed5f33d commit d94f667
Show file tree
Hide file tree
Showing 16 changed files with 161 additions and 201 deletions.
38 changes: 36 additions & 2 deletions src/common/algorithm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, V
}

template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
xgboost::common::Span<IdxT> sorted_idx) {
void ArgSort(Context const *ctx, Span<U> keys, Span<IdxT> sorted_idx) {
std::size_t bytes = 0;
auto cuctx = ctx->CUDACtx();
dh::Iota(sorted_idx, cuctx->Stream());
Expand Down Expand Up @@ -272,5 +271,40 @@ void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_f
out_first = thrust::copy_if(cuctx->CTP(), begin_input, end_input, out_first, pred);
}
}

/**
 * @brief Inclusive scan with a caller-supplied binary operator.
 *
 * Goes one level down into the cub::DeviceScan API to set OffsetT as 64 bit so we
 * don't crash on n > 2^31.
 *
 * @param ctx       XGBoost context; supplies the CUDA stream the scan runs on.
 * @param d_in      Device input iterator.
 * @param d_out     Device output iterator (may alias d_in for an in-place scan).
 * @param scan_op   Binary associative scan operator.
 * @param num_items Number of items; keep this a 64-bit type for large n.
 */
template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename OffsetT>
void InclusiveScan(xgboost::Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
                   ScanOpT scan_op, OffsetT num_items) {
  auto cuctx = ctx->CUDACtx();
  std::size_t bytes = 0;
  // First dispatch with a null temporary buffer only queries the required storage size
  // (`bytes`). Run on the context's stream rather than the default (null) stream so the
  // scan does not implicitly synchronize with other streams.
#if THRUST_MAJOR_VERSION >= 2
  dh::safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
          nullptr, bytes, d_in, d_out, scan_op, cub::NullType(), num_items, cuctx->Stream())));
#else
  // Older cub takes an extra trailing `debug_synchronous` flag.
  dh::safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
          nullptr, bytes, d_in, d_out, scan_op, cub::NullType(), num_items, cuctx->Stream(),
          false)));
#endif
  dh::TemporaryArray<char> storage(bytes);
  // Second dispatch performs the actual scan using the allocated temporary storage.
#if THRUST_MAJOR_VERSION >= 2
  dh::safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
          storage.data().get(), bytes, d_in, d_out, scan_op, cub::NullType(), num_items,
          cuctx->Stream())));
#else
  dh::safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType, OffsetT>::Dispatch(
          storage.data().get(), bytes, d_in, d_out, scan_op, cub::NullType(), num_items,
          cuctx->Stream(), false)));
#endif
}

/**
 * @brief Inclusive prefix sum, dispatched through InclusiveScan so the offset type
 *        stays 64 bit for large inputs.
 */
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
                  OffsetT num_items) {
  // Plain addition is just a scan with cub's sum functor.
  InclusiveScan(ctx, d_in, d_out, cub::Sum(), num_items);
}
} // namespace xgboost::common
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
69 changes: 15 additions & 54 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -372,21 +372,6 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
cudaMemcpyDeviceToHost));
}

// Copy the contents of `src` into `dst`, resizing `dst` to match. `cudaMemcpyDefault`
// lets the driver infer the direction, so any mix of host/device containers works.
// NOTE(review): no stream argument — the async copy is issued on the default stream.
template <class Src, class Dst>
void CopyTo(Src const &src, Dst *dst) {
  // Empty source: nothing to copy, just clear the destination.
  if (src.empty()) {
    dst->clear();
    return;
  }
  dst->resize(src.size());
  using SVT = std::remove_cv_t<typename Src::value_type>;
  using DVT = std::remove_cv_t<typename Dst::value_type>;
  // Element types must match exactly; a mismatch would silently corrupt the copy.
  static_assert(std::is_same_v<SVT, DVT>,
                "Host and device containers must have same value type.");
  dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(dst->data()), src.data(),
                                src.size() * sizeof(SVT), cudaMemcpyDefault));
}

// Keep track of pinned memory allocation
struct PinnedMemory {
void *temp_storage{nullptr};
Expand Down Expand Up @@ -748,45 +733,6 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
return aggregate;
}

// wrapper to avoid integer `num_items`.
// Drops into cub::DispatchScan directly so OffsetT (the index type) can be 64 bit,
// avoiding overflow when num_items > 2^31. The stream argument is `nullptr`, i.e. the
// scan runs on the default stream.
template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
          typename OffsetT>
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
                   OffsetT num_items) {
  // First dispatch with a null buffer only computes `bytes`, the temp storage size.
  size_t bytes = 0;
#if THRUST_MAJOR_VERSION >= 2
  safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
                        OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
                                           cub::NullType(), num_items, nullptr)));
#else
  // Older cub takes an extra trailing `debug_synchronous` flag.
  safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
                        OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
                                           cub::NullType(), num_items, nullptr,
                                           false)));
#endif
  TemporaryArray<char> storage(bytes);
  // Second dispatch performs the actual scan with the allocated temporary storage.
#if THRUST_MAJOR_VERSION >= 2
  safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
                        OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
                                           d_out, scan_op, cub::NullType(),
                                           num_items, nullptr)));
#else
  safe_cuda((
      cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
                        OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
                                           d_out, scan_op, cub::NullType(),
                                           num_items, nullptr, false)));
#endif
}

// Inclusive prefix sum; see InclusiveScan above for why OffsetT is a template parameter.
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
  InclusiveScan(d_in, d_out, cub::Sum(), num_items);
}

class CUDAStreamView;

class CUDAEvent {
Expand Down Expand Up @@ -857,8 +803,23 @@ class CUDAStream {
[[nodiscard]] cudaStream_t Handle() const { return stream_; }

void Sync() { this->View().Sync(); }
void Wait(CUDAEvent const &e) { this->View().Wait(e); }
};

/**
 * @brief Copy the contents of `src` into `dst`, resizing `dst` to match.
 *
 * `cudaMemcpyDefault` lets the driver infer the copy direction, so any mix of
 * host/device containers works as long as the element types agree.
 *
 * @param stream Stream the asynchronous copy is issued on; defaults to DefaultStream().
 */
template <class Src, class Dst>
void CopyTo(Src const &src, Dst *dst, CUDAStreamView stream = DefaultStream()) {
  // Reject mismatched element types at compile time before touching `dst`.
  using SrcValue = std::remove_cv_t<typename Src::value_type>;
  using DstValue = std::remove_cv_t<typename Dst::value_type>;
  static_assert(std::is_same_v<SrcValue, DstValue>,
                "Host and device containers must have same value type.");
  if (src.empty()) {
    // Nothing to copy; mirror the emptiness in the destination.
    dst->clear();
    return;
  }
  dst->resize(src.size());
  auto n_bytes = src.size() * sizeof(SrcValue);
  dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(dst->data()), src.data(), n_bytes,
                                cudaMemcpyDefault, stream));
}

inline auto CachingThrustPolicy() {
XGBCachingDeviceAllocator<char> alloc;
#if THRUST_MAJOR_VERSION >= 2 || defined(XGBOOST_USE_RMM)
Expand Down
4 changes: 2 additions & 2 deletions src/common/ranking_utils.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <thrust/functional.h> // for maximum
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
Expand Down Expand Up @@ -158,7 +158,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan();
if (param_.HasTruncation()) {
n_cuda_threads_ =
common::SegmentedTrapezoidThreads(d_group_ptr, d_threads_group_ptr, Param().NumPair());
common::SegmentedTrapezoidThreads(ctx, d_group_ptr, d_threads_group_ptr, Param().NumPair());
} else {
auto n_pairs = Param().NumPair();
dh::LaunchN(n_groups, cuctx->Stream(),
Expand Down
24 changes: 11 additions & 13 deletions src/common/threading_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
/**
* Copyright 2021-2023 by XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_THREADING_UTILS_CUH_
#define XGBOOST_COMMON_THREADING_UTILS_CUH_

#include <algorithm> // std::min
#include <cstddef> // std::size_t
#include <algorithm> // std::min
#include <cstddef> // std::size_t

#include "./math.h" // Sqr
#include "common.h"
#include "algorithm.cuh" // for InclusiveSum
#include "common.h" // for safe_cuda
#include "device_helpers.cuh" // LaunchN
#include "xgboost/base.h" // XGBOOST_DEVICE
#include "xgboost/span.h" // Span

namespace xgboost {
namespace common {
namespace xgboost::common {
/**
* \param n Number of items (length of the base)
 * \param h height
Expand Down Expand Up @@ -43,9 +43,8 @@ XGBOOST_DEVICE inline std::size_t DiscreteTrapezoidArea(std::size_t n, std::size
* with h <= n
*/
template <typename U>
std::size_t SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
xgboost::common::Span<std::size_t> out_group_threads_ptr,
std::size_t h) {
std::size_t SegmentedTrapezoidThreads(Context const *ctx, Span<U> group_ptr,
Span<std::size_t> out_group_threads_ptr, std::size_t h) {
CHECK_GE(group_ptr.size(), 1);
CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size());
dh::LaunchN(group_ptr.size(), [=] XGBOOST_DEVICE(std::size_t idx) {
Expand All @@ -57,8 +56,8 @@ std::size_t SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
std::size_t cnt = static_cast<std::size_t>(group_ptr[idx] - group_ptr[idx - 1]);
out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h);
});
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
out_group_threads_ptr.size());
InclusiveSum(ctx, out_group_threads_ptr.data(), out_group_threads_ptr.data(),
out_group_threads_ptr.size());
std::size_t total = 0;
dh::safe_cuda(cudaMemcpy(&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
sizeof(total), cudaMemcpyDeviceToHost));
Expand All @@ -82,6 +81,5 @@ XGBOOST_DEVICE inline void UnravelTrapeziodIdx(std::size_t i_idx, std::size_t n,

j = idx - n_elems + i + 1;
}
} // namespace common
} // namespace xgboost
} // namespace xgboost::common
#endif // XGBOOST_COMMON_THREADING_UTILS_CUH_
25 changes: 1 addition & 24 deletions src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,30 +254,7 @@ void CopyDataToEllpack(Context const* ctx, const AdapterBatchT& batch,
d_compressed_buffer, writer, batch, device_accessor, feature_types, is_valid};
thrust::transform_output_iterator<decltype(functor), decltype(discard)> out(discard, functor);

// Go one level down into cub::DeviceScan API to set OffsetT as 64 bit
// So we don't crash on n > 2^31
size_t temp_storage_bytes = 0;
using DispatchScan = cub::DispatchScan<decltype(key_value_index_iter), decltype(out),
TupleScanOp<Tuple>, cub::NullType, std::int64_t>;
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda(DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
ctx->CUDACtx()->Stream()));
#else
DispatchScan::Dispatch(nullptr, temp_storage_bytes, key_value_index_iter, out,
TupleScanOp<Tuple>(), cub::NullType(), batch.Size(),
nullptr, false);
#endif
dh::TemporaryArray<char> temp_storage(temp_storage_bytes);
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda(DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), ctx->CUDACtx()->Stream()));
#else
DispatchScan::Dispatch(temp_storage.data().get(), temp_storage_bytes,
key_value_index_iter, out, TupleScanOp<Tuple>(),
cub::NullType(), batch.Size(), nullptr, false);
#endif
common::InclusiveScan(ctx, key_value_index_iter, out, TupleScanOp<Tuple>{}, batch.Size());
}

void WriteNullValues(Context const* ctx, EllpackPageImpl* dst, common::Span<size_t> row_counts) {
Expand Down
24 changes: 12 additions & 12 deletions src/metric/auc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <utility>

#include "../collective/allreduce.h"
#include "../common/algorithm.cuh" // SegmentedArgSort
#include "../common/algorithm.cuh" // SegmentedArgSort, InclusiveScan
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
#include "auc.h"
Expand Down Expand Up @@ -128,8 +128,8 @@ std::tuple<double, double, double> GPUBinaryAUC(Context const *ctx,
dh::tbegin(d_unique_idx));
d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));

dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp),
PairPlus<double, double>{}, d_fptp.size());
common::InclusiveScan(ctx, dh::tbegin(d_fptp), dh::tbegin(d_fptp), PairPlus<double, double>{},
d_fptp.size());

auto d_neg_pos = dh::ToSpan(cache->neg_pos);
  // scatter unique negative/positive values
Expand Down Expand Up @@ -239,7 +239,7 @@ double ScaleClasses(Context const *ctx, bool is_column_split, common::Span<doubl
* getting class id or group id given scan index.
*/
template <typename Fn>
void SegmentedFPTP(common::Span<Pair> d_fptp, Fn segment_id) {
void SegmentedFPTP(Context const *ctx, common::Span<Pair> d_fptp, Fn segment_id) {
using Triple = thrust::tuple<uint32_t, double, double>;
// expand to tuple to include idx
auto fptp_it_in = dh::MakeTransformIterator<Triple>(
Expand All @@ -253,8 +253,8 @@ void SegmentedFPTP(common::Span<Pair> d_fptp, Fn segment_id) {
thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
return t;
});
dh::InclusiveScan(
fptp_it_in, fptp_it_out,
common::InclusiveScan(
ctx, fptp_it_in, fptp_it_out,
[=] XGBOOST_DEVICE(Triple const &l, Triple const &r) {
uint32_t l_gid = segment_id(thrust::get<0>(l));
uint32_t r_gid = segment_id(thrust::get<0>(r));
Expand Down Expand Up @@ -391,7 +391,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
d_unique_idx = d_unique_idx.subspan(0, n_uniques);

auto get_class_id = [=] XGBOOST_DEVICE(size_t idx) { return idx / n_samples; };
SegmentedFPTP(d_fptp, get_class_id);
SegmentedFPTP(ctx, d_fptp, get_class_id);

// scatter unique FP_PREV/TP_PREV values
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
Expand Down Expand Up @@ -528,8 +528,8 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
dh::caching_device_vector<size_t> threads_group_ptr(group_ptr.size(), 0);
auto d_threads_group_ptr = dh::ToSpan(threads_group_ptr);
// Use max to represent triangle
auto n_threads = common::SegmentedTrapezoidThreads(
d_group_ptr, d_threads_group_ptr, std::numeric_limits<size_t>::max());
auto n_threads = common::SegmentedTrapezoidThreads(ctx, d_group_ptr, d_threads_group_ptr,
std::numeric_limits<std::size_t>::max());
CHECK_LT(n_threads, std::numeric_limits<int32_t>::max());
// get the coordinate in nested summation
auto get_i_j = [=]XGBOOST_DEVICE(size_t idx, size_t query_group_idx) {
Expand Down Expand Up @@ -591,8 +591,8 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
}
return {}; // discard
});
dh::InclusiveScan(
in, out,
common::InclusiveScan(
ctx, in, out,
[] XGBOOST_DEVICE(RankScanItem const &l, RankScanItem const &r) {
if (l.group_id != r.group_id) {
return r;
Expand Down Expand Up @@ -774,7 +774,7 @@ std::pair<double, uint32_t> GPURankingPRAUCImpl(Context const *ctx,
auto get_group_id = [=] XGBOOST_DEVICE(size_t idx) {
return dh::SegmentId(d_group_ptr, idx);
};
SegmentedFPTP(d_fptp, get_group_id);
SegmentedFPTP(ctx, d_fptp, get_group_id);

// scatter unique FP_PREV/TP_PREV values
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
Expand Down
8 changes: 4 additions & 4 deletions src/metric/elementwise_metric.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <cmath>
#include <numeric> // for accumulate

#include "../common/common.h" // for AssertGPUSupport
#include "../common/math.h"
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/pseudo_huber.h"
Expand All @@ -28,7 +27,9 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform_reduce.h>

#include "../common/device_helpers.cuh"
#include "../common/cuda_context.cuh" // for CUDAContext
#else
#include "../common/common.h" // for AssertGPUSupport
#endif // XGBOOST_USE_CUDA

namespace xgboost::metric {
Expand All @@ -48,11 +49,10 @@ PackedReduceResult Reduce(Context const* ctx, MetaInfo const& info, Fn&& loss) {
auto labels = info.labels.View(ctx->Device());
if (ctx->IsCUDA()) {
#if defined(XGBOOST_USE_CUDA)
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::counting_iterator<size_t> begin(0);
thrust::counting_iterator<size_t> end = begin + labels.Size();
result = thrust::transform_reduce(
thrust::cuda::par(alloc), begin, end,
ctx->CUDACtx()->CTP(), begin, end,
[=] XGBOOST_DEVICE(size_t i) {
auto idx = linalg::UnravelIndex(i, labels.Shape());
auto sample_id = std::get<0>(idx);
Expand Down
13 changes: 7 additions & 6 deletions src/tree/constraints.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>

#include <string>
#include <set>
#include <string>

#include "xgboost/logging.h"
#include "xgboost/span.h"
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh"
#include "constraints.cuh"
#include "param.h"
#include "../common/device_helpers.cuh"
#include "xgboost/logging.h"
#include "xgboost/span.h"

namespace xgboost {

Expand Down Expand Up @@ -130,9 +131,9 @@ FeatureInteractionConstraintDevice::FeatureInteractionConstraintDevice(
this->Configure(param, n_features);
}

void FeatureInteractionConstraintDevice::Reset() {
void FeatureInteractionConstraintDevice::Reset(Context const* ctx) {
for (auto& node : node_constraints_storage_) {
thrust::fill(node.begin(), node.end(), 0);
thrust::fill(ctx->CUDACtx()->CTP(), node.begin(), node.end(), 0);
}
}

Expand Down
Loading

0 comments on commit d94f667

Please sign in to comment.