Fix multi-output with alternating strategies. (#9933)
---------

Co-authored-by: Philip Hyunsu Cho <[email protected]>
trivialfis and hcho3 authored Jan 4, 2024
1 parent 5f7b5a6 commit 621348a
Showing 6 changed files with 123 additions and 73 deletions.
11 changes: 11 additions & 0 deletions python-package/xgboost/testing/updater.py
@@ -394,3 +394,14 @@ def train_result(
     assert booster.feature_types == dmat.feature_types
 
     return result
+
+
+class ResetStrategy(xgb.callback.TrainingCallback):
+    """Callback for testing multi-output."""
+
+    def after_iteration(self, model: xgb.Booster, epoch: int, evals_log: dict) -> bool:
+        if epoch % 2 == 0:
+            model.set_param({"multi_strategy": "multi_output_tree"})
+        else:
+            model.set_param({"multi_strategy": "one_output_per_tree"})
+        return False
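
The callback added above flips multi_strategy on every iteration, so a single training run interleaves multi-output trees with one-tree-per-target rounds. A minimal sketch of how it can be driven (the synthetic data and parameter values are illustrative, not part of this commit):

import numpy as np
import xgboost as xgb
from xgboost.testing.updater import ResetStrategy

# Hypothetical multi-target regression data: three targets.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 16))
y = rng.normal(size=(256, 3))

booster = xgb.train(
    {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
    xgb.DMatrix(X, label=y),
    num_boost_round=8,
    callbacks=[ResetStrategy()],  # alternate strategies between rounds
)
assert booster.predict(xgb.DMatrix(X)).shape == y.shape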
8 changes: 4 additions & 4 deletions src/tree/updater_quantile_hist.cc
@@ -545,12 +545,12 @@ class QuantileHistMaker : public TreeUpdater {
   }
 
   bool UpdatePredictionCache(const DMatrix *data, linalg::MatrixView<float> out_preds) override {
-    if (p_impl_) {
-      return p_impl_->UpdatePredictionCache(data, out_preds);
-    } else if (p_mtimpl_) {
+    if (out_preds.Shape(1) > 1) {
+      CHECK(p_mtimpl_);
       return p_mtimpl_->UpdatePredictionCache(data, out_preds);
     } else {
-      return false;
+      CHECK(p_impl_);
+      return p_impl_->UpdatePredictionCache(data, out_preds);
     }
   }

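Note on the fix: the old code dispatched on which updater implementation happened to be initialized, and p_impl_ (the single-target updater) always took precedence once both existed — presumably the state reached when a callback alternates multi_strategy between rounds. The new code keys the dispatch on the number of prediction columns instead, so multi-output predictions always go through the multi-target implementation.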
2 changes: 2 additions & 0 deletions tests/ci_build/lint_python.py
@@ -22,6 +22,7 @@ class LintersPaths:
         "tests/python/test_dmatrix.py",
         "tests/python/test_dt.py",
         "tests/python/test_demos.py",
+        "tests/python/test_multi_target.py",
         "tests/python/test_predict.py",
         "tests/python/test_quantile_dmatrix.py",
         "tests/python/test_tree_regularization.py",
@@ -79,6 +80,7 @@ class LintersPaths:
         "tests/python/test_dt.py",
         "tests/python/test_demos.py",
         "tests/python/test_data_iterator.py",
+        "tests/python/test_multi_target.py",
         "tests/python-gpu/test_gpu_data_iterator.py",
         "tests/python-gpu/load_pickle.py",
         "tests/test_distributed/test_with_spark/test_data.py",
6 changes: 1 addition & 5 deletions tests/python/test_basic_models.py
@@ -8,6 +8,7 @@
 
 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.testing.updater import ResetStrategy
 
 dpath = tm.data_dir(__file__)
 
@@ -653,11 +654,6 @@ def test_slice_multi(self) -> None:
         num_parallel_tree = 4
         num_boost_round = 16
 
-        class ResetStrategy(xgb.callback.TrainingCallback):
-            def after_iteration(self, model, epoch: int, evals_log) -> bool:
-                model.set_param({"multi_strategy": "multi_output_tree"})
-                return False
-
         booster = xgb.train(
             {
                 "num_parallel_tree": num_parallel_tree,
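The inline callback removed here always forced multi_output_tree; it is superseded by the shared ResetStrategy imported above from xgboost.testing.updater, which alternates between the two strategies.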
105 changes: 105 additions & 0 deletions tests/python/test_multi_target.py
@@ -0,0 +1,105 @@
from typing import Any, Dict

from hypothesis import given, note, settings, strategies

import xgboost as xgb
from xgboost import testing as tm
from xgboost.testing.params import (
exact_parameter_strategy,
hist_cache_strategy,
hist_multi_parameter_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import ResetStrategy, train_result


class TestTreeMethodMulti:
@given(
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
)
@settings(deadline=None, print_blob=True)
def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
if dataset.name.endswith("-l1"):
return
param["tree_method"] = "exact"
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result["train"][dataset.metric])

@given(
exact_parameter_strategy,
hist_parameter_strategy,
hist_cache_strategy,
strategies.integers(1, 20),
tm.multi_dataset_strategy,
)
@settings(deadline=None, print_blob=True)
def test_approx(
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
cache_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
param["tree_method"] = "approx"
param = dataset.set_params(param)
param.update(hist_param)
param.update(cache_param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])

@given(
exact_parameter_strategy,
hist_multi_parameter_strategy,
hist_cache_strategy,
strategies.integers(1, 20),
tm.multi_dataset_strategy,
)
@settings(deadline=None, print_blob=True)
def test_hist(
self,
param: Dict[str, Any],
hist_param: Dict[str, Any],
cache_param: Dict[str, Any],
num_rounds: int,
dataset: tm.TestDataset,
) -> None:
if dataset.name.endswith("-l1"):
return
param["tree_method"] = "hist"
param = dataset.set_params(param)
param.update(hist_param)
param.update(cache_param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(str(result))
assert tm.non_increasing(result["train"][dataset.metric])


def test_multiclass() -> None:
X, y = tm.datasets.make_classification(
128, n_features=12, n_informative=10, n_classes=4
)
clf = xgb.XGBClassifier(
multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
)
clf.fit(X, y, eval_set=[(X, y)])
assert clf.objective == "multi:softprob"
assert tm.non_increasing(clf.evals_result()["validation_0"]["mlogloss"])

proba = clf.predict_proba(X)
assert proba.shape == (y.shape[0], 4)


def test_multilabel() -> None:
X, y = tm.datasets.make_multilabel_classification(128)
clf = xgb.XGBClassifier(
multi_strategy="multi_output_tree", callbacks=[ResetStrategy()], n_estimators=10
)
clf.fit(X, y, eval_set=[(X, y)])
assert clf.objective == "binary:logistic"
assert tm.non_increasing(clf.evals_result()["validation_0"]["logloss"])

proba = clf.predict_proba(X)
assert proba.shape == y.shape
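
For context on the two multi_strategy values exercised throughout this new file: one_output_per_tree boosts a separate tree per target, while multi_output_tree grows a single tree with vector leaves covering all targets. A minimal sketch, independent of the tests above (synthetic data, illustrative parameters):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 8))
y = rng.normal(size=(128, 3))  # three regression targets

for strategy in ("one_output_per_tree", "multi_output_tree"):
    # multi_output_tree is only supported by the hist tree method.
    reg = xgb.XGBRegressor(tree_method="hist", multi_strategy=strategy, n_estimators=8)
    reg.fit(X, y)
    assert reg.predict(X).shape == y.shape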
64 changes: 0 additions & 64 deletions tests/python/test_updaters.py
@@ -12,7 +12,6 @@
     cat_parameter_strategy,
     exact_parameter_strategy,
     hist_cache_strategy,
-    hist_multi_parameter_strategy,
     hist_parameter_strategy,
 )
 from xgboost.testing.updater import (
@@ -25,69 +24,6 @@
 )
 
 
-class TestTreeMethodMulti:
-    @given(
-        exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
-    )
-    @settings(deadline=None, print_blob=True)
-    def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
-        if dataset.name.endswith("-l1"):
-            return
-        param["tree_method"] = "exact"
-        param = dataset.set_params(param)
-        result = train_result(param, dataset.get_dmat(), num_rounds)
-        assert tm.non_increasing(result["train"][dataset.metric])
-
-    @given(
-        exact_parameter_strategy,
-        hist_parameter_strategy,
-        hist_cache_strategy,
-        strategies.integers(1, 20),
-        tm.multi_dataset_strategy,
-    )
-    @settings(deadline=None, print_blob=True)
-    def test_approx(
-        self, param: Dict[str, Any],
-        hist_param: Dict[str, Any],
-        cache_param: Dict[str, Any],
-        num_rounds: int,
-        dataset: tm.TestDataset,
-    ) -> None:
-        param["tree_method"] = "approx"
-        param = dataset.set_params(param)
-        param.update(hist_param)
-        param.update(cache_param)
-        result = train_result(param, dataset.get_dmat(), num_rounds)
-        note(str(result))
-        assert tm.non_increasing(result["train"][dataset.metric])
-
-    @given(
-        exact_parameter_strategy,
-        hist_multi_parameter_strategy,
-        hist_cache_strategy,
-        strategies.integers(1, 20),
-        tm.multi_dataset_strategy,
-    )
-    @settings(deadline=None, print_blob=True)
-    def test_hist(
-        self,
-        param: Dict[str, Any],
-        hist_param: Dict[str, Any],
-        cache_param: Dict[str, Any],
-        num_rounds: int,
-        dataset: tm.TestDataset,
-    ) -> None:
-        if dataset.name.endswith("-l1"):
-            return
-        param["tree_method"] = "hist"
-        param = dataset.set_params(param)
-        param.update(hist_param)
-        param.update(cache_param)
-        result = train_result(param, dataset.get_dmat(), num_rounds)
-        note(str(result))
-        assert tm.non_increasing(result["train"][dataset.metric])
-
-
 class TestTreeMethod:
     USE_ONEHOT = np.iinfo(np.int32).max
     USE_PART = 1
