diff --git a/plugin/sycl/tree/hist_row_adder.h b/plugin/sycl/tree/hist_row_adder.h
index 968bcca737dc..93650d5d0746 100644
--- a/plugin/sycl/tree/hist_row_adder.h
+++ b/plugin/sycl/tree/hist_row_adder.h
@@ -39,6 +39,44 @@ class BatchHistRowsAdder: public HistRowsAdder<GradientSumT> {
   }
 };
 
+
+template <typename GradientSumT>
+class DistributedHistRowsAdder: public HistRowsAdder<GradientSumT> {
+ public:
+  void AddHistRows(HistUpdater<GradientSumT>* builder,
+                   std::vector<int>* sync_ids, RegTree *p_tree) override {
+    builder->builder_monitor_.Start("AddHistRows");
+    const size_t explicit_size = builder->nodes_for_explicit_hist_build_.size();
+    const size_t subtraction_size = builder->nodes_for_subtraction_trick_.size();
+    std::vector<int> merged_node_ids(explicit_size + subtraction_size);
+    for (size_t i = 0; i < explicit_size; ++i) {
+      merged_node_ids[i] = builder->nodes_for_explicit_hist_build_[i].nid;
+    }
+    for (size_t i = 0; i < subtraction_size; ++i) {
+      merged_node_ids[explicit_size + i] =
+          builder->nodes_for_subtraction_trick_[i].nid;
+    }
+    std::sort(merged_node_ids.begin(), merged_node_ids.end());
+    sync_ids->clear();
+    // Only left children are synced across workers; right children are
+    // restored locally by subtraction (see DistributedHistSynchronizer).
+    for (auto const& nid : merged_node_ids) {
+      if ((*p_tree)[nid].IsLeftChild()) {
+        builder->hist_.AddHistRow(nid);
+        builder->hist_local_worker_.AddHistRow(nid);
+        sync_ids->push_back(nid);
+      }
+    }
+    for (auto const& nid : merged_node_ids) {
+      if (!((*p_tree)[nid].IsLeftChild())) {
+        builder->hist_.AddHistRow(nid);
+        builder->hist_local_worker_.AddHistRow(nid);
+      }
+    }
+    builder->builder_monitor_.Stop("AddHistRows");
+  }
+};
+
 }  // namespace tree
 }  // namespace sycl
 }  // namespace xgboost
diff --git a/plugin/sycl/tree/hist_synchronizer.h b/plugin/sycl/tree/hist_synchronizer.h
index 2275a51dba37..c89215cf85d2 100644
--- a/plugin/sycl/tree/hist_synchronizer.h
+++ b/plugin/sycl/tree/hist_synchronizer.h
@@ -61,6 +61,70 @@ class BatchHistSynchronizer: public HistSynchronizer<GradientSumT> {
   std::vector<::sycl::event> hist_sync_events_;
 };
 
+template <typename GradientSumT>
+class DistributedHistSynchronizer: public HistSynchronizer<GradientSumT> {
+ public:
+  void SyncHistograms(HistUpdater<GradientSumT>* builder,
+                      const std::vector<int>& sync_ids,
+                      RegTree *p_tree) override {
+    builder->builder_monitor_.Start("SyncHistograms");
+    const size_t nbins = builder->hist_builder_.GetNumBins();
+    for (size_t node = 0; node < builder->nodes_for_explicit_hist_build_.size(); node++) {
+      const auto entry = builder->nodes_for_explicit_hist_build_[node];
+      auto& this_hist = builder->hist_[entry.nid];
+      // Keep a local copy: this node may serve as a parent on the next level.
+      auto& this_local = builder->hist_local_worker_[entry.nid];
+      common::CopyHist(builder->qu_, &this_local, this_hist, nbins);
+
+      if (!(*p_tree)[entry.nid].IsRoot()) {
+        const size_t parent_id = (*p_tree)[entry.nid].Parent();
+        auto sibling_nid = entry.GetSiblingId(p_tree, parent_id);
+        auto& parent_hist = builder->hist_local_worker_[parent_id];
+        auto& sibling_hist = builder->hist_[sibling_nid];
+        common::SubtractionHist(builder->qu_, &sibling_hist, parent_hist,
+                                this_hist, nbins, ::sycl::event());
+        builder->qu_.wait_and_throw();
+        // Keep a local copy: the sibling may serve as a parent on the next level.
+        auto& sibling_local = builder->hist_local_worker_[sibling_nid];
+        common::CopyHist(builder->qu_, &sibling_local, sibling_hist, nbins);
+      }
+    }
+    builder->ReduceHists(sync_ids, nbins);
+
+    ParallelSubtractionHist(builder, builder->nodes_for_explicit_hist_build_, p_tree);
+    ParallelSubtractionHist(builder, builder->nodes_for_subtraction_trick_, p_tree);
+
+    builder->builder_monitor_.Stop("SyncHistograms");
+  }
+
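+  // Restores right-child histograms locally after the collective reduction:
+  // only left children were reduced across workers, so each right child is
+  // recomputed as (parent histogram) - (sibling histogram).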
+  void ParallelSubtractionHist(HistUpdater<GradientSumT>* builder,
+                               const std::vector<ExpandEntry>& nodes,
+                               const RegTree* p_tree) {
+    const size_t nbins = builder->hist_builder_.GetNumBins();
+    for (size_t node = 0; node < nodes.size(); node++) {
+      const auto entry = nodes[node];
+      if (!((*p_tree)[entry.nid].IsLeftChild())) {
+        auto& this_hist = builder->hist_[entry.nid];
+
+        if (!(*p_tree)[entry.nid].IsRoot()) {
+          const size_t parent_id = (*p_tree)[entry.nid].Parent();
+          auto& parent_hist = builder->hist_[parent_id];
+          auto& sibling_hist = builder->hist_[entry.GetSiblingId(p_tree, parent_id)];
+          common::SubtractionHist(builder->qu_, &this_hist, parent_hist,
+                                  sibling_hist, nbins, ::sycl::event());
+          builder->qu_.wait_and_throw();
+        }
+      }
+    }
+  }
+
+ private:
+  std::vector<::sycl::event> hist_sync_events_;
+};
+
 }  // namespace tree
 }  // namespace sycl
 }  // namespace xgboost
diff --git a/plugin/sycl/tree/hist_updater.h b/plugin/sycl/tree/hist_updater.h
index fd5fdda9433d..bb99dac471c6 100644
--- a/plugin/sycl/tree/hist_updater.h
+++ b/plugin/sycl/tree/hist_updater.h
@@ -87,7 +87,10 @@ class HistUpdater {
 
  protected:
   friend class BatchHistSynchronizer<GradientSumT>;
+  friend class DistributedHistSynchronizer<GradientSumT>;
+
   friend class BatchHistRowsAdder<GradientSumT>;
+  friend class DistributedHistRowsAdder<GradientSumT>;
 
   struct SplitQuery {
     bst_node_t nid;
diff --git a/plugin/sycl/tree/updater_quantile_hist.cc b/plugin/sycl/tree/updater_quantile_hist.cc
index ee7a7ad0f101..030e850f4cd2 100644
--- a/plugin/sycl/tree/updater_quantile_hist.cc
+++ b/plugin/sycl/tree/updater_quantile_hist.cc
@@ -51,7 +51,8 @@ void QuantileHistMaker::SetPimpl(std::unique_ptr<HistUpdater<GradientSumT>>* pim
                                       param_,
                                       int_constraint_, dmat));
   if (collective::IsDistributed()) {
-    LOG(FATAL) << "Distributed mode is not yet upstreamed for sycl";
+    (*pimpl)->SetHistSynchronizer(new DistributedHistSynchronizer<GradientSumT>());
+    (*pimpl)->SetHistRowsAdder(new DistributedHistRowsAdder<GradientSumT>());
   } else {
     (*pimpl)->SetHistSynchronizer(new BatchHistSynchronizer<GradientSumT>());
     (*pimpl)->SetHistRowsAdder(new BatchHistRowsAdder<GradientSumT>());
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 8f6e560e4a8c..c8a87edc0449 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -306,11 +306,13 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
         raise TypeError(msg)
 
     if device and device.find(":") != -1:
-        raise ValueError(
-            "Distributed training doesn't support selecting device ordinal as GPUs are"
-            " managed by the distributed frameworks. use `device=cuda` or `device=gpu`"
-            " instead."
-        )
+        # `sycl:gpu` is the only ordinal-style device allowed in distributed mode.
+        if device != "sycl:gpu":
+            raise ValueError(
+                "Distributed training doesn't support selecting device ordinal as GPUs are"
+                " managed by the distributed frameworks. use `device=cuda` or `device=gpu`"
+                " instead."
+            )
 
     if kwargs.get("booster", None) == "gblinear":
         raise NotImplementedError(
diff --git a/tests/python-sycl/test_sycl_simple_dask.py b/tests/python-sycl/test_sycl_simple_dask.py
new file mode 100644
index 000000000000..19eebebee3e5
--- /dev/null
+++ b/tests/python-sycl/test_sycl_simple_dask.py
@@ -0,0 +1,44 @@
+from xgboost import dask as dxgb
+from xgboost import testing as tm
+
+import dask.array as da
+import dask.distributed
+
+
+def train_result(client, param, dtrain, num_rounds):
+    result = dxgb.train(
+        client,
+        param,
+        dtrain,
+        num_rounds,
+        verbose_eval=False,
+        evals=[(dtrain, "train")],
+    )
+    return result
+
+
+class TestSYCLDask:
+    # The simplest test: verifies single-node training only.
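+    # `device="sycl"` routes training to the SYCL plugin updater; the
+    # assertion checks that train-rmse is non-increasing over the rounds.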
+    def test_simple(self):
+        with dask.distributed.LocalCluster(n_workers=1) as cluster:
+            with dask.distributed.Client(cluster) as client:
+                param = {
+                    "tree_method": "hist",
+                    "device": "sycl",
+                    "verbosity": 0,
+                    "objective": "reg:squarederror",
+                }
+
+                # X and y must be Dask arrays or dataframes.
+                num_obs = int(1e4)
+                num_features = 20
+                X = da.random.random(
+                    size=(num_obs, num_features), chunks=(1000, num_features)
+                )
+                y = da.random.random(size=(num_obs, 1), chunks=(1000, 1))
+                dtrain = dxgb.DaskDMatrix(client, X, y)
+
+                result = train_result(client, param, dtrain, 10)
+                assert tm.non_increasing(result["history"]["train"]["rmse"])
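
Usage note: with the core.py change above, distributed training now accepts the
explicit `sycl:gpu` device string. A minimal multi-worker sketch (hypothetical
setup; assumes a working SYCL runtime on every Dask worker and uses only the
dask API already exercised by the new test):

    import dask.array as da
    import dask.distributed
    from xgboost import dask as dxgb

    with dask.distributed.LocalCluster(n_workers=2) as cluster:
        with dask.distributed.Client(cluster) as client:
            X = da.random.random((10000, 20), chunks=(1000, 20))
            y = da.random.random((10000, 1), chunks=(1000, 1))
            dtrain = dxgb.DaskDMatrix(client, X, y)
            result = dxgb.train(
                client,
                {"tree_method": "hist", "device": "sycl:gpu"},
                dtrain,
                num_boost_round=10,
                evals=[(dtrain, "train")],
            )
            # result["booster"] holds the trained model,
            # result["history"] the per-round eval metrics.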