Skip to content

Commit

Permalink
build fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Razdoburdin committed Sep 9, 2024
1 parent 51e2725 commit 45b16ca
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 0 deletions.
27 changes: 27 additions & 0 deletions plugin/sycl/common/hist_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ template void InitHist(::sycl::queue qu,
GHistRow<double, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);

/*!
 * \brief Copy the contents of one device histogram into another.
 *
 * The histograms are treated as flat arrays of GradientSumT and
 * 2 * size values are copied element by element in a SYCL kernel.
 * The call returns only after the kernel has finished.
 *
 * \param qu   queue the copy kernel is submitted to
 * \param dst  histogram that receives the data
 * \param src  histogram that is read
 * \param size number of histogram entries; 2 * size GradientSumT values are copied
 */
template<typename GradientSumT>
void CopyHist(::sycl::queue qu,
              GHistRow<GradientSumT, MemoryType::on_device>* dst,
              const GHistRow<GradientSumT, MemoryType::on_device>& src,
              size_t size) {
  // View both histograms as raw GradientSumT arrays.
  const GradientSumT* src_values = reinterpret_cast<const GradientSumT*>(src.DataConst());
  GradientSumT* dst_values = reinterpret_cast<GradientSumT*>(dst->Data());

  const size_t n_values = 2 * size;

  ::sycl::event copy_done = qu.submit([&](::sycl::handler& cgh) {
    cgh.parallel_for<>(::sycl::range<1>(n_values), [=](::sycl::item<1> pid) {
      const size_t idx = pid.get_id(0);
      dst_values[idx] = src_values[idx];
    });
  });
  // Block until the copy has completed.
  copy_done.wait();
}
// Explicit instantiations of CopyHist for the float and double gradient sum types.
template void CopyHist(::sycl::queue qu,
GHistRow<float, MemoryType::on_device>* dst,
const GHistRow<float, MemoryType::on_device>& src,
size_t size);
template void CopyHist(::sycl::queue qu,
GHistRow<double, MemoryType::on_device>* dst,
const GHistRow<double, MemoryType::on_device>& src,
size_t size);

/*!
* \brief Compute Subtraction: dst = src1 - src2
*/
Expand Down
9 changes: 9 additions & 0 deletions plugin/sycl/common/hist_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ void InitHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* hist,
size_t size, ::sycl::event* event);

/*!
* \brief Copy histogram from src to dst
*
* The definition in hist_util.cc copies 2 * size values of GradientSumT
* and waits for the copy to finish before returning.
*
* \param qu   SYCL queue the copy is submitted to
* \param dst  destination histogram
* \param src  source histogram
* \param size number of histogram entries to copy
*/
template<typename GradientSumT>
void CopyHist(::sycl::queue qu,
GHistRow<GradientSumT, MemoryType::on_device>* dst,
const GHistRow<GradientSumT, MemoryType::on_device>& src,
size_t size);

/*!
* \brief Compute subtraction: dst = src1 - src2
*/
Expand Down
27 changes: 27 additions & 0 deletions plugin/sycl/tree/hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,32 @@ using ::sycl::ext::oneapi::plus;
using ::sycl::ext::oneapi::minimum;
using ::sycl::ext::oneapi::maximum;

/*!
 * \brief Sum the histograms listed in sync_ids with collective::Allreduce.
 *
 * Every selected histogram is copied from hist_ into the contiguous
 * reduce_buffer_, the buffer is reduced with Op::kSum, and the reduced
 * values are copied back into the same histograms.
 *
 * \param sync_ids indices into hist_ of the histograms to reduce
 * \param nbins    number of GradientPairT entries copied per histogram
 */
template <typename GradientSumT>
void HistUpdater<GradientSumT>::ReduceHists(const std::vector<int>& sync_ids,
                                            size_t nbins) {
  const size_t n_hists = sync_ids.size();
  const size_t n_pairs = n_hists * nbins;
  const size_t hist_bytes = nbins * sizeof(GradientPairT);

  // The staging buffer only ever grows; a larger buffer from a previous call is reused.
  if (reduce_buffer_.size() < n_pairs) {
    reduce_buffer_.resize(n_pairs);
  }

  // Gather: histogram k goes into slot [k * nbins, (k + 1) * nbins) of the buffer.
  for (size_t k = 0; k < n_hists; k++) {
    auto& node_hist = hist_[sync_ids[k]];
    const GradientPairT* hist_data =
        reinterpret_cast<const GradientPairT*>(node_hist.DataConst());
    GradientPairT* slot = reduce_buffer_.data() + k * nbins;
    qu_.memcpy(slot, hist_data, hist_bytes).wait();
  }

  // The buffer is handed to Allreduce as a flat vector of GradientSumT,
  // two values per GradientPairT entry.
  auto flat_view = linalg::MakeVec(reinterpret_cast<GradientSumT*>(reduce_buffer_.data()),
                                   2 * n_pairs);
  auto status = collective::Allreduce(ctx_, flat_view, collective::Op::kSum);
  SafeColl(status);

  // Scatter: write each reduced slot back into its histogram.
  for (size_t k = 0; k < n_hists; k++) {
    auto& node_hist = hist_[sync_ids[k]];
    GradientPairT* hist_data = reinterpret_cast<GradientPairT*>(node_hist.Data());
    const GradientPairT* slot = reduce_buffer_.data() + k * nbins;
    qu_.memcpy(hist_data, slot, hist_bytes).wait();
  }
}

template <typename GradientSumT>
void HistUpdater<GradientSumT>::SetHistSynchronizer(
HistSynchronizer<GradientSumT> *sync) {
Expand Down Expand Up @@ -492,6 +518,7 @@ void HistUpdater<GradientSumT>::InitData(
// initialize histogram collection
uint32_t nbins = gmat.cut.Ptrs().back();
hist_.Init(qu_, nbins);
hist_local_worker_.Init(qu_, nbins);

hist_buffer_.Init(qu_, nbins);
size_t buffer_size = kBufferSize;
Expand Down
6 changes: 6 additions & 0 deletions plugin/sycl/tree/hist_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ class HistUpdater {
RegTree* p_tree,
const USMVector<GradientPair, MemoryType::on_device>& gpair);

/*!
* \brief Reduce the histograms selected by sync_ids with collective::Allreduce (Op::kSum).
* \param sync_ids indices into hist_ of the histograms to reduce
* \param nbins number of GradientPairT entries copied per histogram
*/
void ReduceHists(const std::vector<int>& sync_ids, size_t nbins);

inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
if (lhs.GetLossChange() == rhs.GetLossChange()) {
return lhs.GetNodeId() > rhs.GetNodeId(); // favor small timestamp
Expand Down Expand Up @@ -233,6 +235,8 @@ class HistUpdater {
common::ParallelGHistBuilder<GradientSumT> hist_buffer_;
/*! \brief cumulative histogram of gradients. */
common::HistCollection<GradientSumT, MemoryType::on_device> hist_;
/*! \brief cumulative local parent histogram of gradients. */
common::HistCollection<GradientSumT, MemoryType::on_device> hist_local_worker_;

/*! \brief TreeNode Data: statistics for each constructed node */
std::vector<NodeEntry<GradientSumT>> snode_host_;
Expand Down Expand Up @@ -261,6 +265,8 @@ class HistUpdater {
USMVector<bst_float, MemoryType::on_device> out_preds_buf_;
bst_float* out_pred_ptr = nullptr;

/*! \brief staging buffer that ReduceHists fills from hist_ and passes to collective::Allreduce. */
std::vector<GradientPairT> reduce_buffer_;

::sycl::queue qu_;
};

Expand Down

0 comments on commit 45b16ca

Please sign in to comment.