Skip to content

Commit

Permalink
Check support status for categorical features. (#9946)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 4, 2024
1 parent db396ee commit c03a4d5
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 40 deletions.
17 changes: 11 additions & 6 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2015-2023 by XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
Expand Down Expand Up @@ -158,15 +158,15 @@ class MetaInfo {
void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;

/*
* \brief Extend with other MetaInfo.
/**
* @brief Extend with other MetaInfo.
*
* \param that The other MetaInfo object.
* @param that The other MetaInfo object.
*
* \param accumulate_rows Whether rows need to be accumulated in this function. If
* @param accumulate_rows Whether rows need to be accumulated in this function. If
* client code knows number of rows in advance, set this
* parameter to false.
* \param check_column Whether the extend method should check the consistency of
* @param check_column Whether the extend method should check the consistency of
* columns.
*/
void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
Expand Down Expand Up @@ -203,13 +203,18 @@ class MetaInfo {
* learning where labels are only available on worker 0.
*/
bool ShouldHaveLabels() const;
/**
* @brief Flag for whether the DMatrix has categorical features.
*/
bool HasCategorical() const { return has_categorical_; }

private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);

/*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache_;
bool has_categorical_{false};
};

/*! \brief Element from a sparse vector */
Expand Down
6 changes: 5 additions & 1 deletion src/common/error_msg.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*
* \brief Common error message for various checks.
*/
Expand Down Expand Up @@ -99,5 +99,9 @@ constexpr StringView InvalidCUDAOrdinal() {
void MismatchedDevices(Context const* booster, Context const* data);

inline auto NoFederated() { return "XGBoost is not compiled with federated learning support."; }

inline auto NoCategorical(std::string name) {
return name + " doesn't support categorical features.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_
45 changes: 35 additions & 10 deletions src/data/data.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2015-2023 by XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file data.cc
*/
#include "xgboost/data.h"
Expand Down Expand Up @@ -260,9 +260,14 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
}

void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<FeatureType>* types) {
/**
* @brief Load feature type info from names, returns whether there's categorical features.
*/
[[nodiscard]] bool LoadFeatureType(std::vector<std::string> const& type_names,
std::vector<FeatureType>* types) {
types->clear();
for (auto const &elem : type_names) {
bool has_cat{false};
for (auto const& elem : type_names) {
if (elem == "int") {
types->emplace_back(FeatureType::kNumerical);
} else if (elem == "float") {
Expand All @@ -273,10 +278,12 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
types->emplace_back(FeatureType::kNumerical);
} else if (elem == "c") {
types->emplace_back(FeatureType::kCategorical);
has_cat = true;
} else {
LOG(FATAL) << "All feature_types must be one of {int, float, i, q, c}.";
}
}
return has_cat;
}

const std::vector<size_t>& MetaInfo::LabelAbsSort(Context const* ctx) const {
Expand Down Expand Up @@ -340,7 +347,8 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
LoadVectorField(fi, u8"feature_names", DataType::kStr, &feature_names);
LoadVectorField(fi, u8"feature_types", DataType::kStr, &feature_type_names);
LoadVectorField(fi, u8"feature_weights", DataType::kFloat32, &feature_weights);
LoadFeatureType(feature_type_names, &feature_types.HostVector());

this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector());
}

template <typename T>
Expand Down Expand Up @@ -639,6 +647,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
CHECK(info);
}

if (!std::strcmp(key, "feature_type")) {
feature_type_names.clear();
for (size_t i = 0; i < size; ++i) {
Expand All @@ -651,7 +660,7 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
<< "Length of " << key << " must be equal to number of columns.";
}
auto& h_feature_types = feature_types.HostVector();
LoadFeatureType(feature_type_names, &h_feature_types);
this->has_categorical_ = LoadFeatureType(feature_type_names, &h_feature_types);
} else if (!std::strcmp(key, "feature_name")) {
if (IsColumnSplit()) {
std::vector<std::string> local_feature_names{};
Expand All @@ -674,9 +683,8 @@ void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulon
}
}

void MetaInfo::GetFeatureInfo(const char *field,
std::vector<std::string> *out_str_vecs) const {
auto &str_vecs = *out_str_vecs;
void MetaInfo::GetFeatureInfo(const char* field, std::vector<std::string>* out_str_vecs) const {
auto& str_vecs = *out_str_vecs;
if (!std::strcmp(field, "feature_type")) {
str_vecs.resize(feature_type_names.size());
std::copy(feature_type_names.cbegin(), feature_type_names.cend(), str_vecs.begin());
Expand All @@ -689,6 +697,9 @@ void MetaInfo::GetFeatureInfo(const char *field,
}

void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_column) {
/**
* shape
*/
if (accumulate_rows) {
this->num_row_ += that.num_row_;
}
Expand All @@ -702,6 +713,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
this->num_col_ = that.num_col_;

/**
* info with n_samples
*/
linalg::Stack(&this->labels, that.labels);

this->weights_.SetDevice(that.weights_.Device());
Expand All @@ -715,6 +729,9 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col

linalg::Stack(&this->base_margin_, that.base_margin_);

/**
* group
*/
if (this->group_ptr_.size() == 0) {
this->group_ptr_ = that.group_ptr_;
} else {
Expand All @@ -727,17 +744,25 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
group_ptr.end());
}

/**
* info with n_features
*/
if (!that.feature_names.empty()) {
this->feature_names = that.feature_names;
}

if (!that.feature_type_names.empty()) {
this->feature_type_names = that.feature_type_names;
auto &h_feature_types = feature_types.HostVector();
LoadFeatureType(this->feature_type_names, &h_feature_types);
auto& h_feature_types = feature_types.HostVector();
this->has_categorical_ = LoadFeatureType(this->feature_type_names, &h_feature_types);
} else if (!that.feature_types.Empty()) {
// FIXME(jiamingy): https://github.com/dmlc/xgboost/pull/9171/files#r1440188612
this->feature_types.Resize(that.feature_types.Size());
this->feature_types.Copy(that.feature_types);
auto const& ft = this->feature_types.ConstHostVector();
this->has_categorical_ = std::any_of(ft.cbegin(), ft.cend(), common::IsCatOp{});
}

if (!that.feature_weights.Empty()) {
this->feature_weights.Resize(that.feature_weights.Size());
this->feature_weights.SetDevice(that.feature_weights.Device());
Expand Down
2 changes: 1 addition & 1 deletion src/data/iterative_dmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class IterativeDMatrix : public DMatrix {
return nullptr;
}
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
LOG(FATAL) << "Not implemented for `QuantileDMatrix`.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
}
BatchSet<CSCPage> GetColumnBatches(Context const *) override {
Expand Down
26 changes: 13 additions & 13 deletions src/gbm/gblinear.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023, XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file gblinear.cc
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
* the update rule is parallel coordinate descent (shotgun)
Expand All @@ -8,25 +8,24 @@
#include <dmlc/omp.h>
#include <dmlc/parameter.h>

#include <vector>
#include <string>
#include <sstream>
#include <algorithm>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

#include "../common/common.h"
#include "../common/error_msg.h" // NoCategorical, DeprecatedFunc
#include "../common/threading_utils.h"
#include "../common/timer.h"
#include "gblinear_model.h"
#include "xgboost/gbm.h"
#include "xgboost/json.h"
#include "xgboost/predictor.h"
#include "xgboost/linear_updater.h"
#include "xgboost/logging.h"
#include "xgboost/learner.h"
#include "xgboost/linalg.h"

#include "gblinear_model.h"
#include "../common/timer.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#include "../common/error_msg.h"
#include "xgboost/linear_updater.h"
#include "xgboost/logging.h"
#include "xgboost/predictor.h"

namespace xgboost::gbm {
DMLC_REGISTRY_FILE_TAG(gblinear);
Expand Down Expand Up @@ -145,6 +144,7 @@ class GBLinear : public GradientBooster {
ObjFunction const*) override {
monitor_.Start("DoBoost");

CHECK(!p_fmat->Info().HasCategorical()) << error::NoCategorical("`gblinear`");
model_.LazyInitModel();
this->LazySumWeights(p_fmat);

Expand Down
22 changes: 13 additions & 9 deletions src/tree/updater_colmaker.cc
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file updater_colmaker.cc
* \brief use columnwise update to construct a tree
* \author Tianqi Chen
*/
#include <vector>
#include <cmath>
#include <algorithm>
#include <cmath>
#include <vector>

#include "xgboost/parameter.h"
#include "xgboost/tree_updater.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "param.h"
#include "constraints.h"
#include "../common/error_msg.h" // for NoCategorical
#include "../common/random.h"
#include "constraints.h"
#include "param.h"
#include "split_evaluator.h"
#include "xgboost/json.h"
#include "xgboost/logging.h"
#include "xgboost/parameter.h"
#include "xgboost/tree_updater.h"

namespace xgboost::tree {

Expand Down Expand Up @@ -102,6 +103,9 @@ class ColMaker: public TreeUpdater {
LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
"support external memory training.";
}
if (dmat->Info().HasCategorical()) {
LOG(FATAL) << error::NoCategorical("Updater `grow_colmaker` or `exact` tree method");
}
this->LazyGetColumnDensity(dmat);
// rescale learning rate according to size of trees
interaction_constraints_.Configure(*param, dmat->Info().num_row_);
Expand Down
38 changes: 38 additions & 0 deletions tests/python/test_data_iterator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import tempfile
import weakref
from typing import Any, Callable, Dict, List

Expand Down Expand Up @@ -195,3 +197,39 @@ def mock(*args: Any, **kwargs: Any) -> Any:
assert called == 1

xgb.data._proxy_transform = transform


def test_cat_check() -> None:
n_batches = 3
n_features = 2
n_samples_per_batch = 16

batches = []

for i in range(n_batches):
X, y = tm.make_categorical(
n_samples=n_samples_per_batch,
n_features=n_features,
n_categories=3,
onehot=False,
)
batches.append((X, y))

X, y = list(zip(*batches))
it = tm.IteratorForTest(X, y, None, cache=None)
Xy: xgb.DMatrix = xgb.QuantileDMatrix(it, enable_categorical=True)

with pytest.raises(ValueError, match="categorical features"):
xgb.train({"tree_method": "exact"}, Xy)

Xy = xgb.DMatrix(X[0], y[0], enable_categorical=True)
with pytest.raises(ValueError, match="categorical features"):
xgb.train({"tree_method": "exact"}, Xy)

with tempfile.TemporaryDirectory() as tmpdir:
cache_path = os.path.join(tmpdir, "cache")

it = tm.IteratorForTest(X, y, None, cache=cache_path)
Xy = xgb.DMatrix(it, enable_categorical=True)
with pytest.raises(ValueError, match="categorical features"):
xgb.train({"booster": "gblinear"}, Xy)

0 comments on commit c03a4d5

Please sign in to comment.