Handle np integer in model slice and prediction. #10007

Merged
merged 1 commit on Jan 25, 2024
Changes from all commits
5 changes: 5 additions & 0 deletions python-package/xgboost/_typing.py
@@ -36,6 +36,11 @@

FloatCompatible = Union[float, np.float32, np.float64]

# typing.SupportsInt is not suitable here since floating point values are convertible to
# integers as well.
Integer = Union[int, np.integer]
IterationRange = Tuple[Integer, Integer]

# callables
FPreProcCallable = Callable

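A quick illustration of the comment in this hunk: `typing.SupportsInt` only requires `__int__`, which `float` implements, so it would silently admit fractional values as tree indices; the explicit `Union` does not. A minimal sketch, independent of the rest of the package:

```python
from typing import SupportsInt

import numpy as np

# float implements __int__, so it satisfies SupportsInt at runtime
# (the protocol is runtime-checkable) and for static checkers alike.
assert isinstance(3.5, SupportsInt)

# The narrower Union[int, np.integer] accepts Python ints and NumPy
# integer scalars, but rejects floats.
assert isinstance(np.int32(7), (int, np.integer))
assert not isinstance(3.5, (int, np.integer))
```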
35 changes: 22 additions & 13 deletions python-package/xgboost/core.py
@@ -48,6 +48,8 @@
FeatureInfo,
FeatureNames,
FeatureTypes,
Integer,
IterationRange,
ModelIn,
NumpyOrCupy,
TransformedData,
@@ -1812,19 +1814,25 @@ def __setstate__(self, state: Dict) -> None:
state["handle"] = handle
self.__dict__.update(state)

def __getitem__(self, val: Union[int, tuple, slice]) -> "Booster":
def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
"""Get a slice of the tree-based model.

.. versionadded:: 1.3.0

"""
if isinstance(val, int):
val = slice(val, val + 1)
# convert to slice for all other types
if isinstance(val, (np.integer, int)):
val = slice(int(val), int(val + 1))
if isinstance(val, type(Ellipsis)):
val = slice(0, 0)
if isinstance(val, tuple):
raise ValueError("Only supports slicing through 1 dimension.")
# All supported types are now slice
# FIXME(jiamingy): Use `types.EllipsisType` once Python 3.10 is used.
if not isinstance(val, slice):
msg = _expect((int, slice), type(val))
msg = _expect((int, slice, np.integer, type(Ellipsis)), type(val))
raise TypeError(msg)

if isinstance(val.start, type(Ellipsis)) or val.start is None:
start = 0
else:
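For context on the normalization above: NumPy integer indices and `Ellipsis` are both reduced to a plain `slice` before validation, with a bare `Ellipsis` mapping to `slice(0, 0)`, the same "all trees" sentinel used by `iteration_range`. A hedged usage sketch (synthetic data, illustrative names):

```python
import numpy as np

import xgboost as xgb

rng = np.random.default_rng(0)
dtrain = xgb.DMatrix(rng.random((64, 4)), label=rng.random(64))
booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=8)

# A NumPy integer index is coerced to slice(int(i), int(i) + 1).
assert len(booster[np.int32(3)].get_dump()) == 1

# NumPy bounds inside a slice behave like plain ints: trees 3..6.
assert len(booster[np.int64(3):7].get_dump()) == 4

# Tuples are still rejected; slicing is one-dimensional only.
try:
    booster[1, 2]
except ValueError as err:
    print(err)  # Only supports slicing through 1 dimension.
```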
@@ -2246,12 +2254,13 @@ def predict(
pred_interactions: bool = False,
validate_features: bool = True,
training: bool = False,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
strict_shape: bool = False,
) -> np.ndarray:
"""Predict with data. The full model will be used unless `iteration_range` is specified,
meaning user have to either slice the model or use the ``best_iteration``
attribute to get prediction from best model returned from early stopping.
"""Predict with data. The full model will be used unless `iteration_range` is
specified, meaning user have to either slice the model or use the
``best_iteration`` attribute to get prediction from best model returned from
early stopping.

.. note::

@@ -2336,8 +2345,8 @@ def predict(
args = {
"type": 0,
"training": training,
"iteration_begin": iteration_range[0],
"iteration_end": iteration_range[1],
"iteration_begin": int(iteration_range[0]),
"iteration_end": int(iteration_range[1]),
"strict_shape": strict_shape,
}
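The `int(...)` casts above are not cosmetic: this dict is serialized with `json.dumps` before crossing into the C++ core, and NumPy scalars are not JSON-serializable. A small sketch of the failure mode being fixed (reusing the `booster` and `dtrain` from the slicing sketch above):

```python
import json

import numpy as np

# json.dumps rejects NumPy scalars, which is why iteration_range
# entries are coerced to plain int before being packed into the args.
try:
    json.dumps({"iteration_begin": np.int64(0)})
except TypeError as err:
    print(err)  # Object of type int64 is not JSON serializable

# With the casts in place, a NumPy-typed range round-trips fine:
predt = booster.predict(dtrain, iteration_range=(np.int64(0), np.int64(4)))
```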

@@ -2373,7 +2382,7 @@ def assign_type(t: int) -> None:
def inplace_predict(
self,
data: DataType,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
predict_type: str = "value",
missing: float = np.nan,
validate_features: bool = True,
@@ -2439,8 +2448,8 @@ def inplace_predict(
args = make_jcargs(
type=1 if predict_type == "margin" else 0,
training=False,
iteration_begin=iteration_range[0],
iteration_end=iteration_range[1],
iteration_begin=int(iteration_range[0]),
iteration_end=int(iteration_range[1]),
missing=missing,
strict_shape=strict_shape,
cache_id=0,
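`inplace_predict` builds its arguments through `make_jcargs`, which is the same JSON round-trip, hence the identical casts here. A short usage sketch (reusing `rng` and `booster` from the slicing sketch):

```python
import numpy as np

# inplace_predict takes the raw array directly, skipping DMatrix
# construction; the NumPy-typed range is coerced to int internally.
predt = booster.inplace_predict(
    rng.random((16, 4)), iteration_range=(np.int32(0), np.int32(4))
)
```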
24 changes: 12 additions & 12 deletions python-package/xgboost/dask/__init__.py
@@ -61,7 +61,7 @@
import numpy

from xgboost import collective, config
from xgboost._typing import _T, FeatureNames, FeatureTypes
from xgboost._typing import _T, FeatureNames, FeatureTypes, IterationRange
from xgboost.callback import TrainingCallback
from xgboost.compat import DataFrame, LazyLoader, concat, lazy_isinstance
from xgboost.core import (
@@ -1263,7 +1263,7 @@ async def _predict_async(
approx_contribs: bool,
pred_interactions: bool,
validate_features: bool,
iteration_range: Tuple[int, int],
iteration_range: IterationRange,
strict_shape: bool,
) -> _DaskCollection:
_booster = await _get_model_future(client, model)
@@ -1410,7 +1410,7 @@ def predict( # pylint: disable=unused-argument
approx_contribs: bool = False,
pred_interactions: bool = False,
validate_features: bool = True,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
strict_shape: bool = False,
) -> Any:
"""Run prediction with a trained booster.
@@ -1458,7 +1458,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
global_config: Dict[str, Any],
model: Union[Booster, Dict, "distributed.Future"],
data: _DataT,
iteration_range: Tuple[int, int],
iteration_range: IterationRange,
predict_type: str,
missing: float,
validate_features: bool,
@@ -1516,7 +1516,7 @@ def inplace_predict( # pylint: disable=unused-argument
client: Optional["distributed.Client"],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DataT,
iteration_range: Tuple[int, int] = (0, 0),
iteration_range: IterationRange = (0, 0),
predict_type: str = "value",
missing: float = numpy.nan,
validate_features: bool = True,
@@ -1624,7 +1624,7 @@ async def _predict_async(
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
iteration_range: Optional[IterationRange],
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
if self._can_use_inplace_predict():
@@ -1664,7 +1664,7 @@ def predict(
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
_assert_dask_support()
return self.client.sync(
@@ -1679,7 +1679,7 @@ def predict(
async def _apply_async(
self,
X: _DataT,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = await DaskDMatrix(
@@ -1700,7 +1700,7 @@ async def _apply_async(
def apply(
self,
X: _DataT,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
_assert_dask_support()
return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
@@ -1962,7 +1962,7 @@ async def _predict_proba_async(
X: _DataT,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
iteration_range: Optional[IterationRange],
) -> _DaskCollection:
if self.objective == "multi:softmax":
raise ValueError(
@@ -1987,7 +1987,7 @@ def predict_proba(
X: _DaskCollection,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> Any:
_assert_dask_support()
return self._client_sync(
@@ -2006,7 +2006,7 @@ async def _predict_async(
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
iteration_range: Optional[IterationRange],
) -> _DaskCollection:
pred_probs = await super()._predict_async(
data, output_margin, validate_features, base_margin, iteration_range
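The Dask layer only needed its annotations widened to the new alias: `iteration_range` is forwarded as-is to the core `Booster`, which now performs the coercion. A hedged end-to-end sketch (assumes a local cluster; all data and names are illustrative):

```python
import dask.array as da
import numpy as np
from distributed import Client, LocalCluster

import xgboost as xgb

if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1024, 8), chunks=(128, 8))
        y = da.random.random(1024, chunks=128)
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        output = xgb.dask.train(
            client, {"tree_method": "hist"}, dtrain, num_boost_round=8
        )
        # The NumPy-typed bounds flow through unchanged to Booster.predict.
        predt = xgb.dask.predict(
            client, output, dtrain, iteration_range=(np.int32(0), np.int32(4))
        )
```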
18 changes: 9 additions & 9 deletions python-package/xgboost/sklearn.py
@@ -22,7 +22,7 @@
import numpy as np
from scipy.special import softmax

from ._typing import ArrayLike, FeatureNames, FeatureTypes, ModelIn
from ._typing import ArrayLike, FeatureNames, FeatureTypes, IterationRange, ModelIn
from .callback import TrainingCallback

# Do not use class names on scikit-learn directly. Re-define the classes on
@@ -1039,8 +1039,8 @@ def _can_use_inplace_predict(self) -> bool:
return False

def _get_iteration_range(
self, iteration_range: Optional[Tuple[int, int]]
) -> Tuple[int, int]:
self, iteration_range: Optional[IterationRange]
) -> IterationRange:
if iteration_range is None or iteration_range[1] == 0:
# Use best_iteration if defined.
try:
@@ -1057,7 +1057,7 @@ def predict(
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
"""Predict with `X`. If the model is trained with early stopping, then
:py:attr:`best_iteration` is used automatically. The estimator uses
@@ -1129,7 +1129,7 @@ def predict(
def apply(
self,
X: ArrayLike,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> np.ndarray:
"""Return the predicted leaf every tree for each sample. If the model is trained
with early stopping, then :py:attr:`best_iteration` is used automatically.
@@ -1465,7 +1465,7 @@ def predict(
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
with config_context(verbosity=self.verbosity):
class_probs = super().predict(
@@ -1500,7 +1500,7 @@ def predict_proba(
X: ArrayLike,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> np.ndarray:
"""Predict the probability of each `X` example being of a given class. If the
model is trained with early stopping, then :py:attr:`best_iteration` is used
@@ -1942,7 +1942,7 @@ def predict(
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
X, _ = _get_qid(X, None)
return super().predict(
@@ -1956,7 +1956,7 @@ def apply(
def apply(
self,
X: ArrayLike,
iteration_range: Optional[Tuple[int, int]] = None,
iteration_range: Optional[IterationRange] = None,
) -> ArrayLike:
X, _ = _get_qid(X, None)
return super().apply(X, iteration_range)
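On the scikit-learn side the behavior is unchanged: `_get_iteration_range` still substitutes `best_iteration` when the range is unset or its end is 0, and the widened alias simply admits NumPy integers. A sketch under those assumptions (synthetic data, illustrative names):

```python
import numpy as np
from sklearn.datasets import make_regression

import xgboost as xgb

X, y = make_regression(n_samples=256, n_features=8, random_state=0)
reg = xgb.XGBRegressor(
    n_estimators=50, early_stopping_rounds=5, eval_metric="rmse"
)
reg.fit(X, y, eval_set=[(X, y)])

# With iteration_range unset, predict() falls back to best_iteration.
full = reg.predict(X)

# A NumPy-typed range is also accepted after this change.
part = reg.predict(X, iteration_range=(np.int32(0), np.int32(4)))
```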
38 changes: 25 additions & 13 deletions tests/python/test_basic_models.py
@@ -7,6 +7,7 @@

import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import Integer
from xgboost.testing.updater import ResetStrategy

dpath = tm.data_dir(__file__)
@@ -97,15 +98,15 @@ def my_logloss(preds, dtrain):
def test_boost_from_prediction(self):
# Re-construct dtrain here to avoid modification
margined, _ = tm.load_agaricus(__file__)
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
bst = xgb.train({"tree_method": "hist"}, margined, 1)
predt_0 = bst.predict(margined, output_margin=True)
margined.set_base_margin(predt_0)
bst = xgb.train({'tree_method': 'hist'}, margined, 1)
bst = xgb.train({"tree_method": "hist"}, margined, 1)
predt_1 = bst.predict(margined)

assert np.any(np.abs(predt_1 - predt_0) > 1e-6)
dtrain, _ = tm.load_agaricus(__file__)
bst = xgb.train({'tree_method': 'hist'}, dtrain, 2)
bst = xgb.train({"tree_method": "hist"}, dtrain, 2)
predt_2 = bst.predict(dtrain)
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)

@@ -331,10 +332,15 @@ def run_slice(
dtrain: xgb.DMatrix,
num_parallel_tree: int,
num_classes: int,
num_boost_round: int
num_boost_round: int,
use_np_type: bool,
):
beg = 3
end = 7
if use_np_type:
end: Integer = np.int32(7)
else:
end = 7

sliced: xgb.Booster = booster[beg:end]
assert sliced.feature_types == booster.feature_types

@@ -345,7 +351,7 @@ def run_slice(
sliced = booster[beg:end:2]
assert sliced_trees == len(sliced.get_dump())

sliced = booster[beg: ...]
sliced = booster[beg:]
sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())

@@ -357,7 +363,7 @@ def run_slice(
sliced_trees = end * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())

sliced = booster[...: end]
sliced = booster[: end]
sliced_trees = end * num_parallel_tree * num_classes
assert sliced_trees == len(sliced.get_dump())

@@ -383,14 +389,14 @@ def run_slice(
assert len(trees) == num_boost_round

with pytest.raises(TypeError):
booster["wrong type"]
booster["wrong type"] # type: ignore
with pytest.raises(IndexError):
booster[: num_boost_round + 1]
with pytest.raises(ValueError):
booster[1, 2] # too many dims
# setitem is not implemented as model is immutable during slicing.
with pytest.raises(TypeError):
booster[...: end] = booster
booster[:end] = booster # type: ignore

sliced_0 = booster[1:3]
np.testing.assert_allclose(
@@ -446,15 +452,21 @@ def test_slice(self, booster):

assert len(booster.get_dump()) == total_trees

self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
self.run_slice(
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
)

bytesarray = booster.save_raw(raw_format="ubj")
booster = xgb.Booster(model_file=bytesarray)
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
self.run_slice(
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
)

bytesarray = booster.save_raw(raw_format="deprecated")
booster = xgb.Booster(model_file=bytesarray)
self.run_slice(booster, dtrain, num_parallel_tree, num_classes, num_boost_round)
self.run_slice(
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, True
)

def test_slice_multi(self) -> None:
from sklearn.datasets import make_classification
@@ -479,7 +491,7 @@ def test_slice_multi(self) -> None:
},
num_boost_round=num_boost_round,
dtrain=Xy,
callbacks=[ResetStrategy()]
callbacks=[ResetStrategy()],
)
sliced = [t for t in booster]
assert len(sliced) == 16