Skip to content

Commit

Permalink
[SYCL] Fix for sycl support with sklearn estimators (#10806)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Dmitry Razdoburdin
  • Loading branch information
razdoburdin authored Sep 9, 2024
1 parent 5f7f31d commit bba6aa7
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 5 deletions.
6 changes: 5 additions & 1 deletion python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,14 +1568,15 @@ def inplace_predict( # pylint: disable=unused-argument

async def _async_wrap_evaluation_matrices(
client: Optional["distributed.Client"],
device: Optional[str],
tree_method: Optional[str],
max_bin: Optional[int],
**kwargs: Any,
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
"""A switch function for async environment."""

def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
if _can_use_qdm(tree_method):
if _can_use_qdm(tree_method, device):
return DaskQuantileDMatrix(
client=client, ref=ref, max_bin=max_bin, **kwargs
)
Expand Down Expand Up @@ -1776,6 +1777,7 @@ async def _fit_async(
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
client=self.client,
device=self.device,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
Expand Down Expand Up @@ -1865,6 +1867,7 @@ async def _fit_async(
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
self.client,
device=self.device,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
Expand Down Expand Up @@ -2067,6 +2070,7 @@ async def _fit_async(
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
self.client,
device=self.device,
tree_method=self.tree_method,
max_bin=self.max_bin,
X=X,
Expand Down
7 changes: 4 additions & 3 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ def _check_rf_callback(
)


def _can_use_qdm(tree_method: Optional[str]) -> bool:
return tree_method in ("hist", "gpu_hist", None, "auto")
def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
not_sycl = (device is None) or (not device.startswith("sycl"))
return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl


class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -1031,7 +1032,7 @@ def _duplicated(parameter: str) -> None:

def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
# Use `QuantileDMatrix` to save memory.
if _can_use_qdm(self.tree_method) and self.booster != "gblinear":
if _can_use_qdm(self.tree_method, self.device) and self.booster != "gblinear":
try:
return QuantileDMatrix(
**kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
Expand Down
5 changes: 4 additions & 1 deletion python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,10 @@ def _train_booster(
context = BarrierTaskContext.get()

dev_ordinal = None
use_qdm = _can_use_qdm(booster_params.get("tree_method", None))
use_qdm = _can_use_qdm(
booster_params.get("tree_method", None),
booster_params.get("device", None),
)
verbosity = booster_params.get("verbosity", 1)
msg = "Training on CPUs"
if run_on_gpu:
Expand Down
37 changes: 37 additions & 0 deletions tests/python-sycl/test_sycl_with_sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import xgboost as xgb
import pytest
import sys
import numpy as np

from xgboost import testing as tm

# Extend the import search path before the import below so that
# `test_with_sklearn` can be loaded as a plain module. The path is relative,
# so this presumably assumes pytest is launched from the repository root —
# TODO confirm.
sys.path.append("tests/python")
import test_with_sklearn as twskl # noqa

# Module-wide marker: skip every test in this file when the condition built
# from `tm.no_sklearn()` holds (presumably when scikit-learn is unavailable).
pytestmark = pytest.mark.skipif(**tm.no_sklearn())

# Seeded NumPy random state; passed to KFold below so shuffling is reproducible.
rng = np.random.RandomState(1994)


def test_sycl_binary_classification():
    """Train sklearn-style classifiers with ``device="sycl"`` and check accuracy.

    For both ``XGBClassifier`` and ``XGBRFClassifier``, a model is fitted on
    each fold of a 2-fold split of the two-class digits dataset. The
    misclassification rate on the held-out fold must stay below 10%.
    """
    from sklearn.datasets import load_digits
    from sklearn.model_selection import KFold

    digits = load_digits(n_class=2)
    y = digits["target"]
    X = digits["data"]
    kf = KFold(n_splits=2, shuffle=True, random_state=rng)
    for cls in (xgb.XGBClassifier, xgb.XGBRFClassifier):
        for train_index, test_index in kf.split(X, y):
            xgb_model = cls(random_state=42, device="sycl", n_estimators=4).fit(
                X[train_index], y[train_index]
            )
            preds = xgb_model.predict(X[test_index])
            labels = y[test_index]
            # Fraction of held-out samples whose thresholded prediction
            # differs from the label, computed in one vectorised pass.
            err = float(np.mean((preds > 0.5).astype(int) != labels))
            # Diagnostics live in the assertion message, so they only show
            # up when the check actually fails.
            assert err < 0.1, (
                f"{cls.__name__}: error rate {err} is not below 0.1; "
                f"preds={preds}, labels={labels}"
            )

0 comments on commit bba6aa7

Please sign in to comment.