diff --git a/src/common/common.h b/src/common/common.h index 4b20ce7c2156..950dee5210b1 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -66,8 +66,20 @@ inline std::vector Split(const std::string& s, char delim) { return ret; } +/** + * @brief Add escapes for a UTF-8 string. + */ void EscapeU8(std::string const &string, std::string *p_buffer); +/** + * @brief Add escapes for a UTF-8 string with newly created buffer as return. + */ +inline std::string EscapeU8(std::string const &str) { + std::string buffer; + EscapeU8(str, &buffer); + return buffer; +} + template XGBOOST_DEVICE T Max(T a, T b) { return a < b ? b : a; diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index d37be14b894d..f18b519264a0 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,5 +1,5 @@ /** - * Copyright 2015-2023 by Contributors + * Copyright 2015-2023, XGBoost Contributors * \file tree_model.cc * \brief model structure for tree */ @@ -15,9 +15,9 @@ #include #include "../common/categorical.h" -#include "../common/common.h" +#include "../common/common.h" // for EscapeU8 #include "../predictor/predict_fn.h" -#include "io_utils.h" // GetElem +#include "io_utils.h" // for GetElem #include "param.h" #include "xgboost/base.h" #include "xgboost/data.h" @@ -207,8 +207,9 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const& __make_ ## TreeGenReg ## _ ## UniqueId ## __ = \ ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name) -std::vector GetSplitCategories(RegTree const &tree, int32_t nidx) { - auto const &csr = tree.GetCategoriesMatrix(); +namespace { +std::vector GetSplitCategories(RegTree const& tree, int32_t nidx) { + auto const& csr = tree.GetCategoriesMatrix(); auto seg = csr.node_ptr[nidx]; auto split = common::KCatBitField{csr.categories.subspan(seg.beg, seg.size)}; @@ -221,7 +222,7 @@ std::vector GetSplitCategories(RegTree const &tree, int32_t nidx) { return cats; } -std::string PrintCatsAsSet(std::vector const &cats) { +std::string PrintCatsAsSet(std::vector const& cats) { std::stringstream ss; ss << "{"; for (size_t i = 0; i < cats.size(); ++i) { @@ -234,6 +235,15 @@ std::string PrintCatsAsSet(std::vector const &cats) { return ss.str(); } +std::string GetFeatureName(FeatureMap const& fmap, bst_feature_t split_index) { + CHECK_LE(fmap.Size(), std::numeric_limits::max()); + auto fname = split_index < static_cast(fmap.Size()) + ? fmap.Name(split_index) + : ('f' + std::to_string(split_index)); + return common::EscapeU8(fname); +} +} // anonymous namespace + class TextGenerator : public TreeGenerator { using SuperT = TreeGenerator; @@ -263,7 +273,7 @@ class TextGenerator : public TreeGenerator { std::string result = SuperT::Match( kIndicatorTemplate, {{"{nid}", std::to_string(nid)}, - {"{fname}", fmap_.Name(split_index)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, {"{yes}", std::to_string(nyes)}, {"{no}", std::to_string(tree[nid].DefaultChild())}}); return result; @@ -277,8 +287,7 @@ class TextGenerator : public TreeGenerator { template_str, {{"{tabs}", SuperT::Tabs(depth)}, {"{nid}", std::to_string(nid)}, - {"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) : - std::to_string(split_index)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, {"{cond}", cond}, {"{left}", std::to_string(tree[nid].LeftChild())}, {"{right}", std::to_string(tree[nid].RightChild())}, @@ -308,7 +317,7 @@ class TextGenerator : public TreeGenerator { std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { auto cond = tree[nid].SplitCond(); static std::string const kNodeTemplate = - "{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}"; + "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); } @@ -376,7 +385,7 @@ class JsonGenerator : public TreeGenerator { return result; } - std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override { + std::string LeafNode(RegTree const& tree, bst_node_t nid, uint32_t) const override { static std::string const kLeafTemplate = R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L"; static std::string const kStatTemplate = @@ -392,26 +401,22 @@ class JsonGenerator : public TreeGenerator { return result; } - std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) const override { + std::string Indicator(RegTree const& tree, bst_node_t nid, uint32_t depth) const override { int32_t nyes = tree[nid].DefaultLeft() ? tree[nid].RightChild() : tree[nid].LeftChild(); static std::string const kIndicatorTemplate = R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID"; auto split_index = tree[nid].SplitIndex(); - auto fname = fmap_.Name(split_index); - std::string qfname; // quoted - common::EscapeU8(fname, &qfname); - auto result = SuperT::Match( - kIndicatorTemplate, - {{"{nid}", std::to_string(nid)}, - {"{depth}", std::to_string(depth)}, - {"{fname}", qfname}, - {"{yes}", std::to_string(nyes)}, - {"{no}", std::to_string(tree[nid].DefaultChild())}}); + auto result = + SuperT::Match(kIndicatorTemplate, {{"{nid}", std::to_string(nid)}, + {"{depth}", std::to_string(depth)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, + {"{yes}", std::to_string(nyes)}, + {"{no}", std::to_string(tree[nid].DefaultChild())}}); return result; } - std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override { + std::string Categorical(RegTree const& tree, bst_node_t nid, uint32_t depth) const override { auto cats = GetSplitCategories(tree, nid); static std::string const kCategoryTemplate = R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" @@ -429,22 +434,17 @@ class JsonGenerator : public TreeGenerator { return results; } - std::string SplitNodeImpl(RegTree const &tree, int32_t nid, - std::string const &template_str, std::string cond, - uint32_t depth) const { + std::string SplitNodeImpl(RegTree const& tree, bst_node_t nid, std::string const& template_str, + std::string cond, uint32_t depth) const { auto split_index = tree[nid].SplitIndex(); - auto fname = split_index < fmap_.Size() ? fmap_.Name(split_index) : std::to_string(split_index); - std::string qfname; // quoted - common::EscapeU8(fname, &qfname); - std::string const result = SuperT::Match( - template_str, - {{"{nid}", std::to_string(nid)}, - {"{depth}", std::to_string(depth)}, - {"{fname}", qfname}, - {"{cond}", cond}, - {"{left}", std::to_string(tree[nid].LeftChild())}, - {"{right}", std::to_string(tree[nid].RightChild())}, - {"{missing}", std::to_string(tree[nid].DefaultChild())}}); + std::string const result = + SuperT::Match(template_str, {{"{nid}", std::to_string(nid)}, + {"{depth}", std::to_string(depth)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, + {"{cond}", cond}, + {"{left}", std::to_string(tree[nid].LeftChild())}, + {"{right}", std::to_string(tree[nid].RightChild())}, + {"{missing}", std::to_string(tree[nid].DefaultChild())}}); return result; } @@ -605,9 +605,8 @@ class GraphvizGenerator : public TreeGenerator { auto const& extra = kwargs["graph_attrs"]; static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n"; for (auto const& kv : extra) { - param_.graph_attrs += SuperT::Match(kGraphTemplate, - {{"{key}", kv.first}, - {"{value}", kv.second}}); + param_.graph_attrs += + SuperT::Match(kGraphTemplate, {{"{key}", kv.first}, {"{value}", kv.second}}); } kwargs.erase("graph_attrs"); @@ -646,20 +645,18 @@ class GraphvizGenerator : public TreeGenerator { // Only indicator is different, so we combine all different node types into this // function. std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override { - auto split = tree[nid].SplitIndex(); + auto split_index = tree[nid].SplitIndex(); auto cond = tree[nid].SplitCond(); - static std::string const kNodeTemplate = - " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n"; - - // Indicator only has fname. - bool has_less = (split >= fmap_.Size()) || fmap_.TypeOf(split) != FeatureMap::kIndicator; - std::string result = SuperT::Match(kNodeTemplate, { - {"{nid}", std::to_string(nid)}, - {"{fname}", split < fmap_.Size() ? fmap_.Name(split) : - 'f' + std::to_string(split)}, - {"{<}", has_less ? "<" : ""}, - {"{cond}", has_less ? SuperT::ToStr(cond) : ""}, - {"{params}", param_.condition_node_params}}); + static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n"; + + bool has_less = + (split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator; + std::string result = + SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, + {"{<}", has_less ? "<" : ""}, + {"{cond}", has_less ? SuperT::ToStr(cond) : ""}, + {"{params}", param_.condition_node_params}}); result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); result += BuildEdge(tree, nid, tree[nid].RightChild(), false); @@ -672,14 +669,13 @@ class GraphvizGenerator : public TreeGenerator { " {nid} [ label=\"{fname}:{cond}\" {params}]\n"; auto cats = GetSplitCategories(tree, nid); auto cats_str = PrintCatsAsSet(cats); - auto split = tree[nid].SplitIndex(); - std::string result = SuperT::Match( - kLabelTemplate, - {{"{nid}", std::to_string(nid)}, - {"{fname}", split < fmap_.Size() ? fmap_.Name(split) - : 'f' + std::to_string(split)}, - {"{cond}", cats_str}, - {"{params}", param_.condition_node_params}}); + auto split_index = tree[nid].SplitIndex(); + + std::string result = + SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)}, + {"{fname}", GetFeatureName(fmap_, split_index)}, + {"{cond}", cats_str}, + {"{params}", param_.condition_node_params}}); result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); result += BuildEdge(tree, nid, tree[nid].RightChild(), false); diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index 44708ebd1163..2dc1893dd645 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -404,7 +404,7 @@ TEST(Tree, DumpText) { } ASSERT_EQ(n_conditions, 3ul); - ASSERT_NE(str.find("[f0<0]"), std::string::npos); + ASSERT_NE(str.find("[f0<0]"), std::string::npos) << str; ASSERT_NE(str.find("[f1<1]"), std::string::npos); ASSERT_NE(str.find("[f2<2]"), std::string::npos); diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index f0c80124d905..45bef1f25f5c 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -28,10 +28,11 @@ def json_model(model_path: str, parameters: dict) -> dict: if model_path.endswith("ubj"): import ubjson + with open(model_path, "rb") as ubjfd: model = ubjson.load(ubjfd) else: - with open(model_path, 'r') as fd: + with open(model_path, "r") as fd: model = json.load(fd) return model @@ -439,25 +440,34 @@ def validate_model(parameters): 'objective': 'multi:softmax'} validate_model(parameters) - def test_special_model_dump_characters(self): + def test_special_model_dump_characters(self) -> None: params = {"objective": "reg:squarederror", "max_depth": 3} - feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"] + feature_names = ['"feature 0"', "\tfeature\n1", """feature "2"."""] X, y, w = tm.make_regression(n_samples=128, n_features=3, use_cupy=False) Xy = xgb.DMatrix(X, label=y, feature_names=feature_names) booster = xgb.train(params, Xy, num_boost_round=3) + json_dump = booster.get_dump(dump_format="json") assert len(json_dump) == 3 - def validate(obj: dict) -> None: + def validate_json(obj: dict) -> None: for k, v in obj.items(): if k == "split": assert v in feature_names elif isinstance(v, dict): - validate(v) + validate_json(v) for j_tree in json_dump: loaded = json.loads(j_tree) - validate(loaded) + validate_json(loaded) + + dot_dump = booster.get_dump(dump_format="dot") + for d in dot_dump: + assert d.find(r"feature \"2\"") != -1 + + text_dump = booster.get_dump(dump_format="text") + for d in text_dump: + assert d.find(r"feature \"2\"") != -1 def test_categorical_model_io(self): X, y = tm.make_categorical(256, 16, 71, False) @@ -485,6 +495,7 @@ def test_categorical_model_io(self): @pytest.mark.skipif(**tm.no_sklearn()) def test_attributes(self): from sklearn.datasets import load_iris + X, y = load_iris(return_X_y=True) cls = xgb.XGBClassifier(n_estimators=2) cls.fit(X, y, early_stopping_rounds=1, eval_set=[(X, y)]) @@ -674,6 +685,7 @@ def after_iteration(self, model, epoch: int, evals_log) -> bool: @pytest.mark.skipif(**tm.no_pandas()) def test_feature_info(self): import pandas as pd + rows = 100 cols = 10 X = rng.randn(rows, cols)