Skip to content

Commit

Permalink
replace misused num_class
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes committed Dec 29, 2023
1 parent cada62e commit 35d6c7c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
15 changes: 8 additions & 7 deletions R-package/R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ cb.save.model <- function(save_period = 0, save_name = "xgboost.model") {
#' \code{data},
#' \code{end_iteration},
#' \code{params},
#' \code{num_class}.
#'
#' @return
#' Predictions are returned inside of the \code{pred} element, which is either a vector or a matrix,
Expand All @@ -488,19 +487,21 @@ cb.cv.predict <- function(save_models = FALSE) {
stop("'cb.cv.predict' callback requires 'basket' and 'bst_folds' lists in its calling frame")

N <- nrow(env$data)
pred <-
if (env$num_class > 1) {
matrix(NA_real_, N, env$num_class)
} else {
rep(NA_real_, N)
}
pred <- NULL

iterationrange <- c(1, NVL(env$basket$best_iteration, env$end_iteration))
if (NVL(env$params[['booster']], '') == 'gblinear') {
iterationrange <- "all"
}
for (fd in env$bst_folds) {
pr <- predict(fd$bst, fd$watchlist[[2]], iterationrange = iterationrange, reshape = TRUE)
if (is.null(pred)) {
if (NCOL(pr) > 1L) {
pred <- matrix(NA_real_, N, ncol(pr))
} else {
pred <- matrix(NA_real_, N)
}
}
if (is.matrix(pred)) {
pred[fd$index, ] <- pr
} else {
Expand Down
1 change: 0 additions & 1 deletion R-package/man/cb.cv.predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions R-package/tests/testthat/test_callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -363,3 +363,44 @@ test_that("prediction in xgb.cv for softprob works", {
expect_equal(dim(cv$pred), c(nrow(iris), 3))
expect_lt(diff(range(rowSums(cv$pred))), 1e-6)
})

test_that("prediction in xgb.cv works for multi-quantile", {
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
dm <- xgb.DMatrix(x, label = y, nthread = 1)
cv <- xgb.cv(
data = dm,
params = list(
objective = "reg:quantileerror",
quantile_alpha = c(0.1, 0.2, 0.5, 0.8, 0.9),
nthread = 1
),
nrounds = 5,
nfold = 3,
prediction = TRUE,
verbose = 0
)
expect_equal(dim(cv$pred), c(nrow(x), 5))
})

test_that("prediction in xgb.cv works for multi-output", {
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
dm <- xgb.DMatrix(x, label = cbind(y, -y), nthread = 1)
cv <- xgb.cv(
data = dm,
params = list(
tree_method = "hist",
multi_strategy = "multi_output_tree",
objective = "reg:squarederror",
nthread = n_threads
),
nrounds = 5,
nfold = 3,
prediction = TRUE,
verbose = 0
)
expect_equal(dim(cv$pred), c(nrow(x), 2))
})

0 comments on commit 35d6c7c

Please sign in to comment.