Skip to content

Commit

Permalink
[R] Add missing DMatrix functions (#9929)
Browse files Browse the repository at this point in the history
* `XGDMatrixGetQuantileCut`
* `XGDMatrixNumNonMissing`
* `XGDMatrixGetDataAsCSR`

---------

Co-authored-by: Jiaming Yuan <[email protected]>
  • Loading branch information
david-cortes and trivialfis authored Jan 3, 2024
1 parent 4924745 commit 3c004a4
Show file tree
Hide file tree
Showing 14 changed files with 438 additions and 9 deletions.
5 changes: 5 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
export(xgb.gblinear.history)
export(xgb.get.DMatrix.data)
export(xgb.get.DMatrix.num.non.missing)
export(xgb.get.DMatrix.qcut)
export(xgb.get.config)
export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
Expand All @@ -60,6 +63,7 @@ export(xgb.unserialize)
export(xgboost)
import(methods)
importClassesFrom(Matrix,dgCMatrix)
importClassesFrom(Matrix,dgRMatrix)
importClassesFrom(Matrix,dgeMatrix)
importFrom(Matrix,colSums)
importFrom(Matrix,sparse.model.matrix)
Expand All @@ -83,6 +87,7 @@ importFrom(graphics,points)
importFrom(graphics,title)
importFrom(jsonlite,fromJSON)
importFrom(jsonlite,toJSON)
importFrom(methods,new)
importFrom(stats,median)
importFrom(stats,predict)
importFrom(utils,head)
Expand Down
105 changes: 105 additions & 0 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,111 @@ setinfo.xgb.DMatrix <- function(object, name, info) {
stop("setinfo: unknown info name ", name)
}

#' @title Get Quantile Cuts from DMatrix
#' @description Get the quantile cuts (a.k.a. borders) from an `xgb.DMatrix`
#' that has been quantized for the histogram method (`tree_method="hist"`).
#'
#' These cuts are used in order to assign observations to bins - i.e. these are ordered
#' boundaries which are used to determine assignment condition `border_low < x < border_high`.
#' As such, the first and last bin will be outside of the range of the data, so as to include
#' all of the observations there.
#'
#' If a given column has 'n' bins, then there will be 'n+1' cuts / borders for that column,
#' which will be output in sorted order from lowest to highest.
#'
#' Different columns can have different numbers of bins according to their range.
#' @param dmat An `xgb.DMatrix` object, as returned by \link{xgb.DMatrix}.
#' @param output Output format for the quantile cuts. Possible options are:\itemize{
#' \item `"list"` will return the output as a list with one entry per column, where
#' each column will have a numeric vector with the cuts. The list will be named if
#' `dmat` has column names assigned to it.
#' \item `"arrays"` will return a list with entries `indptr` (base-0 indexing) and
#' `data`. Here, the cuts for column 'i' are obtained by slicing 'data' from entries
#' `indptr[i]+1` to `indptr[i+1]`.
#' }
#' @return The quantile cuts, in the format specified by parameter `output`.
#' @examples
#' library(xgboost)
#' data(mtcars)
#' y <- mtcars$mpg
#' x <- as.matrix(mtcars[, -1])
#' dm <- xgb.DMatrix(x, label = y, nthread = 1)
#'
#' # DMatrix is not quantized right away, but will be once a hist model is generated
#' model <- xgb.train(
#' data = dm,
#' params = list(
#' tree_method = "hist",
#' max_bin = 8,
#' nthread = 1
#' ),
#' nrounds = 3
#' )
#'
#' # Now can get the quantile cuts
#' xgb.get.DMatrix.qcut(dm)
#' @export
xgb.get.DMatrix.qcut <- function(dmat, output = c("list", "arrays")) { # nolint
stopifnot(inherits(dmat, "xgb.DMatrix"))
output <- head(output, 1L)
stopifnot(output %in% c("list", "arrays"))
res <- .Call(XGDMatrixGetQuantileCut_R, dmat)
if (output == "arrays") {
return(res)
} else {
feature_names <- getinfo(dmat, "feature_name")
ncols <- length(res$indptr) - 1
out <- lapply(
seq(1, ncols),
function(col) {
st <- res$indptr[col]
end <- res$indptr[col + 1]
if (end <= st) {
return(numeric())
}
return(res$data[seq(1 + st, end)])
}
)
if (NROW(feature_names)) {
names(out) <- feature_names
}
return(out)
}
}

#' @title Get Number of Non-Missing Entries in DMatrix
#' @param dmat An `xgb.DMatrix` object, as returned by \link{xgb.DMatrix}.
#' @return The number of non-missing entries in the DMatrix
#' @export
xgb.get.DMatrix.num.non.missing <- function(dmat) { # nolint
stopifnot(inherits(dmat, "xgb.DMatrix"))
return(.Call(XGDMatrixNumNonMissing_R, dmat))
}

#' @title Get DMatrix Data
#' @param dmat An `xgb.DMatrix` object, as returned by \link{xgb.DMatrix}.
#' @return The data held in the DMatrix, as a sparse CSR matrix (class `dgRMatrix`
#' from package `Matrix`). If it had feature names, these will be added as column names
#' in the output.
#' @export
xgb.get.DMatrix.data <- function(dmat) {
stopifnot(inherits(dmat, "xgb.DMatrix"))
res <- .Call(XGDMatrixGetDataAsCSR_R, dmat)
out <- methods::new("dgRMatrix")
nrows <- as.integer(length(res$indptr) - 1)
out@p <- res$indptr
out@j <- res$indices
out@x <- res$data
out@Dim <- as.integer(c(nrows, res$ncols))

feature_names <- getinfo(dmat, "feature_name")
dim_names <- list(NULL, NULL)
if (NROW(feature_names)) {
dim_names[[2L]] <- feature_names
}
out@Dimnames <- dim_names
return(out)
}

#' Get a new DMatrix containing the specified rows of
#' original xgb.DMatrix object
Expand Down
3 changes: 2 additions & 1 deletion R-package/R/xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ NULL
NULL

# Various imports
#' @importClassesFrom Matrix dgCMatrix dgeMatrix
#' @importClassesFrom Matrix dgCMatrix dgeMatrix dgRMatrix
#' @importFrom Matrix colSums
#' @importFrom Matrix sparse.model.matrix
#' @importFrom Matrix sparseVector
Expand All @@ -98,6 +98,7 @@ NULL
#' @importFrom data.table setnames
#' @importFrom jsonlite fromJSON
#' @importFrom jsonlite toJSON
#' @importFrom methods new
#' @importFrom utils object.size str tail
#' @importFrom stats predict
#' @importFrom stats median
Expand Down
19 changes: 19 additions & 0 deletions R-package/man/xgb.get.DMatrix.data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions R-package/man/xgb.get.DMatrix.num.non.missing.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions R-package/man/xgb.get.DMatrix.qcut.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ OBJECTS= \
$(PKGROOT)/src/gbm/gblinear.o \
$(PKGROOT)/src/gbm/gblinear_model.o \
$(PKGROOT)/src/data/adapter.o \
$(PKGROOT)/src/data/array_interface.o \
$(PKGROOT)/src/data/simple_dmatrix.o \
$(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \
Expand Down
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ OBJECTS= \
$(PKGROOT)/src/gbm/gblinear.o \
$(PKGROOT)/src/gbm/gblinear_model.o \
$(PKGROOT)/src/data/adapter.o \
$(PKGROOT)/src/data/array_interface.o \
$(PKGROOT)/src/data/simple_dmatrix.o \
$(PKGROOT)/src/data/data.o \
$(PKGROOT)/src/data/sparse_page_raw_format.o \
Expand Down
6 changes: 6 additions & 0 deletions R-package/src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ extern SEXP XGDMatrixCreateFromDF_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
extern SEXP XGDMatrixNumCol_R(SEXP);
extern SEXP XGDMatrixNumRow_R(SEXP);
extern SEXP XGDMatrixGetQuantileCut_R(SEXP);
extern SEXP XGDMatrixNumNonMissing_R(SEXP);
extern SEXP XGDMatrixGetDataAsCSR_R(SEXP);
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
Expand Down Expand Up @@ -84,6 +87,9 @@ static const R_CallMethodDef CallEntries[] = {
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
{"XGDMatrixGetQuantileCut_R", (DL_FUNC) &XGDMatrixGetQuantileCut_R, 1},
{"XGDMatrixNumNonMissing_R", (DL_FUNC) &XGDMatrixNumNonMissing_R, 1},
{"XGDMatrixGetDataAsCSR_R", (DL_FUNC) &XGDMatrixGetDataAsCSR_R, 1},
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
Expand Down
Loading

0 comments on commit 3c004a4

Please sign in to comment.