diff --git a/dev/change_scala_version.py b/dev/change_scala_version.py
index d9438f76adf7..b83265f8c5d1 100644
--- a/dev/change_scala_version.py
+++ b/dev/change_scala_version.py
@@ -62,6 +62,17 @@ def main(args):
)
if nsubs > 0:
replaced_scala_binver = True
+            # Replace the final name of the shaded jar
+ if "
- * LabeledPointGroupIterator organizes data in a tuple format:
- * (isFistGroup || isLastGroup, Array[XGBLabeledPoint]).
- * The edge groups across partitions can be stitched together later.
- * @param base collection of XGBLabeledPoint
- */
-private[spark] class LabeledPointGroupIterator(base: Iterator[XGBLabeledPoint])
- extends AbstractIterator[XGBLabeledPointGroup] {
-
- private var firstPointOfNextGroup: XGBLabeledPoint = null
- private var isNewGroup = false
-
- override def hasNext: Boolean = {
- base.hasNext || isNewGroup
- }
-
- override def next(): XGBLabeledPointGroup = {
- val builder = mutable.ArrayBuilder.make[XGBLabeledPoint]
- var isFirstGroup = true
- if (firstPointOfNextGroup != null) {
- builder += firstPointOfNextGroup
- isFirstGroup = false
- }
-
- isNewGroup = false
- while (!isNewGroup && base.hasNext) {
- val point = base.next()
- val groupId = if (firstPointOfNextGroup != null) firstPointOfNextGroup.group else point.group
- firstPointOfNextGroup = point
- if (point.group == groupId) {
- // add to current group
- builder += point
- } else {
- // start a new group
- isNewGroup = true
- }
- }
-
- val isLastGroup = !isNewGroup
- val result = builder.result()
- val group = XGBLabeledPointGroup(result(0).group, result, isFirstGroup || isLastGroup)
-
- group
- }
-}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoostProvider.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoostProvider.scala
deleted file mode 100644
index 4c4dbdec1e53..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/PreXGBoostProvider.scala
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- Copyright (c) 2021-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.scala.spark.params.XGBoostEstimatorCommon
-
-import org.apache.spark.ml.{Estimator, Model, PipelineStage}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Dataset}
-
-/**
- * PreXGBoost implementation provider
- */
-private[scala] trait PreXGBoostProvider {
-
- /**
- * Whether the provider is enabled or not
- * @param dataset the input dataset
- * @return Boolean
- */
- def providerEnabled(dataset: Option[Dataset[_]]): Boolean = false
-
- /**
- * Transform schema
- * @param xgboostEstimator supporting XGBoostClassifier/XGBoostClassificationModel and
- * XGBoostRegressor/XGBoostRegressionModel
- * @param schema the input schema
- * @return the transformed schema
- */
- def transformSchema(xgboostEstimator: XGBoostEstimatorCommon, schema: StructType): StructType
-
- /**
- * Convert the Dataset[_] to RDD[() => Watches] which will be fed to XGBoost
- *
- * @param estimator supports XGBoostClassifier and XGBoostRegressor
- * @param dataset the training data
- * @param params all user defined and defaulted params
- * @return [[XGBoostExecutionParams]] => (RDD[[() => Watches]], Option[ RDD[_] ])
- * RDD[() => Watches] will be used as the training input to build DMatrix
- * Option[ RDD[_] ] is the optional cached RDD
- */
- def buildDatasetToRDD(
- estimator: Estimator[_],
- dataset: Dataset[_],
- params: Map[String, Any]):
- XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]])
-
- /**
- * Transform Dataset
- *
- * @param model supporting [[XGBoostClassificationModel]] and [[XGBoostRegressionModel]]
- * @param dataset the input Dataset to transform
- * @return the transformed DataFrame
- */
- def transformDataset(model: Model[_], dataset: Dataset[_]): DataFrame
-
-}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
similarity index 54%
rename from jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala
rename to jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
index 710dd9adcc1a..cae44ab9aef1 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/Utils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/Utils.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,12 +14,49 @@
limitations under the License.
*/
-package ml.dmlc.xgboost4j.scala.spark.util
+package ml.dmlc.xgboost4j.scala.spark
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.json4s.{DefaultFormats, FullTypeHints, JField, JValue, NoTypeHints, TypeHints}
-// based on org.apache.spark.util copy /paste
-object Utils {
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+
+private[scala] object Utils {
+
+ private[spark] implicit class XGBLabeledPointFeatures(
+ val labeledPoint: XGBLabeledPoint
+ ) extends AnyVal {
+ /** Converts the point to [[MLLabeledPoint]]. */
+ private[spark] def asML: MLLabeledPoint = {
+ MLLabeledPoint(labeledPoint.label, labeledPoint.features)
+ }
+
+ /**
+     * Returns the features of the point as [[org.apache.spark.ml.linalg.Vector]].
+ */
+ def features: Vector = if (labeledPoint.indices == null) {
+ Vectors.dense(labeledPoint.values.map(_.toDouble))
+ } else {
+ Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
+ }
+ }
+
+ private[spark] implicit class MLVectorToXGBLabeledPoint(val v: Vector) extends AnyVal {
+ /**
+ * Converts a [[Vector]] to a data point with a dummy label.
+ *
+ * This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]
+ * for prediction.
+ */
+    def asXGB: XGBLabeledPoint = v match {
+      case v: DenseVector =>
+        XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
+      case v: SparseVector =>
+        // Keep the sparse representation: the indices array must pair with
+        // the sparse values array, not a densified copy of it.
+        XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
+    }
+ }
def getSparkClassLoader: ClassLoader = getClass.getClassLoader
@@ -27,6 +64,7 @@ object Utils {
Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
// scalastyle:off classforname
+
/** Preferred alternative to Class.forName(className) */
def classForName(className: String): Class[_] = {
Class.forName(className, true, getContextOrSparkClassLoader)
@@ -35,9 +73,10 @@ object Utils {
/**
* Get the TypeHints according to the value
+ *
* @param value the instance of class to be serialized
* @return if value is null,
- * return NoTypeHints
+ * return NoTypeHints
* else return the FullTypeHints.
*
* The FullTypeHints will save the full class name into the "jsonClass" of the json,
@@ -53,6 +92,7 @@ object Utils {
/**
* Get the TypeHints according to the saved jsonClass field
+ *
* @param json
* @return TypeHints
*/
@@ -68,4 +108,17 @@ object Utils {
FullTypeHints(List(Utils.classForName(className)))
}.getOrElse(NoTypeHints)
}
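These type-hint helpers exist so that polymorphic parameter values survive a JSON round trip: serializing with `FullTypeHints` writes the concrete class name into the `jsonClass` field, and `getTypeHints` on the way back recovers the hints from that field. A small json4s sketch of the write side, assuming a hypothetical `MyParam` case class:

```scala
import org.json4s.FullTypeHints
import org.json4s.jackson.Serialization

case class MyParam(alpha: Double)

// FullTypeHints embeds the concrete class name under the default
// "jsonClass" field, which getTypeHints(json) later reads back.
implicit val formats = Serialization.formats(FullTypeHints(List(classOf[MyParam])))
val json = Serialization.write(MyParam(0.5))
// json: {"jsonClass":"...MyParam","alpha":0.5}
```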
+
+ val TRAIN_NAME = "train"
+ val VALIDATION_NAME = "eval"
+
+
+ /** Executes the provided code block and then closes the resource */
+ def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
+ try {
+ block(r)
+ } finally {
+ r.close()
+ }
+ }
}
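Taken together, the pieces above cover both directions of the bridge: `asXGB`/`features` convert between Spark ML vectors and `XGBLabeledPoint`, while `withResource` guarantees that the native handle behind an `AutoCloseable` such as `DMatrix` is released even if the block throws. A minimal sketch of how they compose, assuming the code sits inside the `ml.dmlc.xgboost4j.scala` package tree (both `Utils` and the implicit classes are package-private):

```scala
import org.apache.spark.ml.linalg.Vectors

import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.Utils._

// Convert an ML vector into a dummy-labeled point (the prediction path),
// build a one-row DMatrix from it, and free the native handle afterwards.
val point = Vectors.dense(1.0, 0.0, 3.5).asXGB
val nRows: Long = withResource(new DMatrix(Iterator(point))) { dm =>
  dm.rowNum // the native handle is only valid inside this block
}
```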
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 10c4b5a72992..baf579d779ec 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -18,227 +18,30 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
-import scala.collection.mutable
-import scala.util.Random
-import scala.collection.JavaConverters._
-
-import ml.dmlc.xgboost4j.java.{Communicator, ITracker, XGBoostError, RabitTracker}
-import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager
-import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.FileSystem
-
+import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}
-import org.apache.spark.{SparkConf, SparkContext, TaskContext}
-import org.apache.spark.sql.SparkSession
-
-/**
- * Rabit tracker configurations.
- *
- * @param timeout The number of seconds before timeout waiting for workers to connect. and
- * for the tracker to shutdown.
- * @param hostIp The Rabit Tracker host IP address.
- * This is only needed if the host IP cannot be automatically guessed.
- * @param port The port number for the tracker to listen to. Use a system allocated one by
- * default.
- */
-case class TrackerConf(timeout: Int, hostIp: String = "", port: Int = 0)
-object TrackerConf {
- def apply(): TrackerConf = TrackerConf(0)
-}
-
-private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long)
+import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker, XGBoostError}
+import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
-private[scala] case class XGBoostExecutionParams(
+private[spark] case class RuntimeParams(
numWorkers: Int,
numRounds: Int,
- useExternalMemory: Boolean,
- obj: ObjectiveTrait,
- eval: EvalTrait,
- missing: Float,
- allowNonZeroForMissing: Boolean,
trackerConf: TrackerConf,
- checkpointParam: Option[ExternalCheckpointParams],
- xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingRounds: Int,
- cacheTrainingSet: Boolean,
- device: Option[String],
+ device: String,
isLocal: Boolean,
- featureNames: Option[Array[String]],
- featureTypes: Option[Array[String]],
- runOnGpu: Boolean) {
-
- private var rawParamMap: Map[String, Any] = _
-
- def setRawParamMap(inputMap: Map[String, Any]): Unit = {
- rawParamMap = inputMap
- }
-
- def toMap: Map[String, Any] = {
- rawParamMap
- }
-}
-
-private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], sc: SparkContext){
-
- private val logger = LogFactory.getLog("XGBoostSpark")
-
- private val isLocal = sc.isLocal
-
- private val overridedParams = overrideParams(rawParams, sc)
-
- validateSparkSslConf()
-
- /**
- * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
- * If so, throw an exception unless this safety measure has been explicitly overridden
- * via conf `xgboost.spark.ignoreSsl`.
- */
- private def validateSparkSslConf(): Unit = {
- val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
- SparkSession.getActiveSession match {
- case Some(ss) =>
- (ss.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean,
- ss.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean)
- case None =>
- (sc.getConf.getBoolean("spark.ssl.enabled", false),
- sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
- }
- if (sparkSslEnabled) {
- if (xgboostSparkIgnoreSsl) {
- logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
- s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
- } else {
- throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
- "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
- "To override this protection and still use xgboost-spark at your own risk, " +
- "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
- }
- }
- }
-
- /**
- * we should not include any nested structure in the output of this function as the map is
- * eventually to be feed to xgboost4j layer
- */
- private def overrideParams(
- params: Map[String, Any],
- sc: SparkContext): Map[String, Any] = {
- val coresPerTask = sc.getConf.getInt("spark.task.cpus", 1)
- var overridedParams = params
- if (overridedParams.contains("nthread")) {
- val nThread = overridedParams("nthread").toString.toInt
- require(nThread <= coresPerTask,
- s"the nthread configuration ($nThread) must be no larger than " +
- s"spark.task.cpus ($coresPerTask)")
- } else {
- overridedParams = overridedParams + ("nthread" -> coresPerTask)
- }
-
- val numEarlyStoppingRounds = overridedParams.getOrElse(
- "num_early_stopping_rounds", 0).asInstanceOf[Int]
- overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds
- if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) {
- throw new IllegalArgumentException("custom_eval does not support early stopping")
- }
- overridedParams
- }
-
- /**
- * The Map parameters accepted by estimator's constructor may have string type,
- * Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
- * kind of parameters into the correct type in the function.
- *
- * @return XGBoostExecutionParams
- */
- def buildXGBRuntimeParams: XGBoostExecutionParams = {
-
- val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
- val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
- if (obj != null) {
- require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
- "is not defined, you have to specify the objective type as classification or regression" +
- " with a customized objective function")
- }
-
- var trainTestRatio = 1.0
- if (overridedParams.contains("train_test_ratio")) {
- logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
- " pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
- "'eval_set_names'")
- trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
- }
-
- val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
- val round = overridedParams("num_round").asInstanceOf[Int]
- val useExternalMemory = overridedParams
- .getOrElse("use_external_memory", false).asInstanceOf[Boolean]
-
- val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
- val allowNonZeroForMissing = overridedParams
- .getOrElse("allow_non_zero_for_missing", false)
- .asInstanceOf[Boolean]
-
- val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
- val device: Option[String] = overridedParams.get("device").map(_.toString)
- val deviceIsGpu = device.exists(_ == "cuda")
-
- require(!(treeMethod.exists(_ == "approx") && deviceIsGpu),
- "The tree method \"approx\" is not yet supported for Spark GPU cluster")
-
- // back-compatible with "gpu_hist"
- val runOnGpu = treeMethod.exists(_ == "gpu_hist") || deviceIsGpu
-
- val trackerConf = overridedParams.get("tracker_conf") match {
- case None => TrackerConf()
- case Some(conf: TrackerConf) => conf
- case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
- "instance of TrackerConf.")
- }
-
- val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)
-
- val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
- val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)
-
- val earlyStoppingRounds = overridedParams.getOrElse(
- "num_early_stopping_rounds", 0).asInstanceOf[Int]
-
- val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
- .asInstanceOf[Boolean]
-
- val featureNames = if (overridedParams.contains("feature_names")) {
- Some(overridedParams("feature_names").asInstanceOf[Array[String]])
- } else None
- val featureTypes = if (overridedParams.contains("feature_types")){
- Some(overridedParams("feature_types").asInstanceOf[Array[String]])
- } else None
-
- val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
- missing, allowNonZeroForMissing, trackerConf,
- checkpointParam,
- inputParams,
- earlyStoppingRounds,
- cacheTrainingSet,
- device,
- isLocal,
- featureNames,
- featureTypes,
- runOnGpu
- )
- xgbExecParam.setRawParamMap(overridedParams)
- xgbExecParam
- }
-}
+ runOnGpu: Boolean,
+ obj: Option[ObjectiveTrait] = None,
+ eval: Option[EvalTrait] = None)
/**
* A trait to manage stage-level scheduling
*/
-private[spark] trait XGBoostStageLevel extends Serializable {
+private[spark] trait StageLevelScheduling extends Serializable {
private val logger = LogFactory.getLog("XGBoostSpark")
private[spark] def isStandaloneOrLocalCluster(conf: SparkConf): Boolean = {
@@ -255,10 +58,9 @@ private[spark] trait XGBoostStageLevel extends Serializable {
* @param conf spark configurations
* @return Boolean to skip stage-level scheduling or not
*/
- private[spark] def skipStageLevelScheduling(
- sparkVersion: String,
- runOnGpu: Boolean,
- conf: SparkConf): Boolean = {
+ private[spark] def skipStageLevelScheduling(sparkVersion: String,
+ runOnGpu: Boolean,
+ conf: SparkConf): Boolean = {
if (runOnGpu) {
if (sparkVersion < "3.4.0") {
logger.info("Stage-level scheduling in xgboost requires spark version 3.4.0+")
@@ -313,14 +115,13 @@ private[spark] trait XGBoostStageLevel extends Serializable {
* on a single executor simultaneously.
*
* @param sc the spark context
- * @param rdd which rdd to be applied with new resource profile
- * @return the original rdd or the changed rdd
+ * @param rdd the rdd to be applied with new resource profile
+ * @return the original rdd or the modified rdd
*/
- private[spark] def tryStageLevelScheduling(
- sc: SparkContext,
- xgbExecParams: XGBoostExecutionParams,
- rdd: RDD[(Booster, Map[String, Array[Float]])]
- ): RDD[(Booster, Map[String, Array[Float]])] = {
+ private[spark] def tryStageLevelScheduling[T](sc: SparkContext,
+ xgbExecParams: RuntimeParams,
+ rdd: RDD[T]
+ ): RDD[T] = {
val conf = sc.getConf
if (skipStageLevelScheduling(sc.version, xgbExecParams.runOnGpu, conf)) {
@@ -360,7 +161,7 @@ private[spark] trait XGBoostStageLevel extends Serializable {
}
}
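The resource-profile construction itself lives in unchanged lines outside this hunk. Conceptually, when stage-level scheduling is not skipped, the training RDD is rebound to a profile that claims every task CPU slot plus a full GPU, so two training tasks can never share one executor's GPU. A hedged sketch of such a profile using the `ResourceProfileBuilder` and `TaskResourceRequests` imports this file already has (the exact amounts in the real code may differ):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}

// Bind an RDD to a task-level resource profile that occupies all CPU slots
// and a whole GPU, so at most one training task runs per executor.
def withExclusiveGpuProfile[T](sc: SparkContext, rdd: RDD[T]): RDD[T] = {
  val executorCores = sc.getConf.getInt("spark.executor.cores", 1)
  val taskReqs = new TaskResourceRequests()
    .cpus(executorCores)  // claim every CPU slot on the executor
    .resource("gpu", 1.0) // claim the entire GPU for this task
  rdd.withResources(new ResourceProfileBuilder().require(taskReqs).build())
}
```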
-object XGBoost extends XGBoostStageLevel {
+private[spark] object XGBoost extends StageLevelScheduling {
private val logger = LogFactory.getLog("XGBoostSpark")
def getGPUAddrFromResources: Int = {
@@ -383,46 +184,30 @@ object XGBoost extends XGBoostStageLevel {
}
}
- private def buildWatchesAndCheck(buildWatchesFun: () => Watches): Watches = {
- val watches = buildWatchesFun()
- // to workaround the empty partitions in training dataset,
- // this might not be the best efficient implementation, see
- // (https://github.com/dmlc/xgboost/issues/1277)
- if (!watches.toMap.contains("train")) {
- throw new XGBoostError(
- s"detected an empty partition in the training data, partition ID:" +
- s" ${TaskContext.getPartitionId()}")
- }
- watches
- }
-
- private def buildDistributedBooster(
- buildWatches: () => Watches,
- xgbExecutionParam: XGBoostExecutionParams,
- rabitEnv: java.util.Map[String, Object],
- obj: ObjectiveTrait,
- eval: EvalTrait,
- prevBooster: Booster): Iterator[(Booster, Map[String, Array[Float]])] = {
- var watches: Watches = null
- val taskId = TaskContext.getPartitionId().toString
+  private def trainBooster(watches: Watches,
+                           runtimeParams: RuntimeParams,
+                           xgboostParams: Map[String, Any],
+                           rabitEnv: java.util.Map[String, Object],
+                           metrics: Array[Array[Float]]
+                          ): Booster = {
+ val partitionId = TaskContext.getPartitionId()
val attempt = TaskContext.get().attemptNumber.toString
- rabitEnv.put("DMLC_TASK_ID", taskId)
- val numRounds = xgbExecutionParam.numRounds
- val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
+ rabitEnv.put("DMLC_TASK_ID", partitionId.toString)
try {
- Communicator.init(rabitEnv)
-
- watches = buildWatchesAndCheck(buildWatches)
+ try {
+ Communicator.init(rabitEnv)
+ } catch {
+        case e: Throwable => logger.error("Failed to initialize the communicator", e)
+ }
+ val numEarlyStoppingRounds = runtimeParams.earlyStoppingRounds
+      // The metrics arrays are allocated by the caller and filled in by the
+      // native training call, so the evaluation history outlives this method.
- val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds
- val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds))
- val externalCheckpointParams = xgbExecutionParam.checkpointParam
+ var params = xgboostParams
- var params = xgbExecutionParam.toMap
- if (xgbExecutionParam.runOnGpu) {
- val gpuId = if (xgbExecutionParam.isLocal) {
+ if (runtimeParams.runOnGpu) {
+ val gpuId = if (runtimeParams.isLocal) {
// For local mode, force gpu id to primary device
0
} else {
@@ -431,126 +216,88 @@ object XGBoost extends XGBoostStageLevel {
logger.info("Leveraging gpu device " + gpuId + " to train")
params = params + ("device" -> s"cuda:$gpuId")
}
-
- val booster = if (makeCheckpoint) {
- SXGBoost.trainAndSaveCheckpoint(
- watches.toMap("train"), params, numRounds,
- watches.toMap, metrics, obj, eval,
- earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
- } else {
- SXGBoost.train(watches.toMap("train"), params, numRounds,
- watches.toMap, metrics, obj, eval,
- earlyStoppingRound = numEarlyStoppingRounds, prevBooster)
- }
- if (TaskContext.get().partitionId() == 0) {
- Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
- } else {
- Iterator.empty
- }
+ SXGBoost.train(watches.toMap("train"), params, runtimeParams.numRounds, watches.toMap,
+        metrics, runtimeParams.obj.orNull, runtimeParams.eval.orNull,
+ earlyStoppingRound = numEarlyStoppingRounds)
} catch {
case xgbException: XGBoostError =>
- logger.error(s"XGBooster worker $taskId has failed $attempt times due to ", xgbException)
+ logger.error(s"XGBooster worker $partitionId has failed $attempt " +
+ s"times due to ", xgbException)
throw xgbException
} finally {
Communicator.shutdown()
- if (watches != null) watches.delete()
- }
- }
-
- // Executes the provided code block inside a tracker and then stops the tracker
- private def withTracker[T](nWorkers: Int, conf: TrackerConf)(block: ITracker => T): T = {
- val tracker = new RabitTracker(nWorkers, conf.hostIp, conf.port, conf.timeout)
- require(tracker.start(), "FAULT: Failed to start tracker")
- try {
- block(tracker)
- } finally {
- tracker.stop()
}
}
/**
- * @return A tuple of the booster and the metrics used to build training summary
+   * Train an XGBoost booster on the given dataset
+   *
+   * @param input the training dataset, with one Watches per partition
+   * @param runtimeParams the runtime parameters for the JVM layer
+   * @param xgboostParams the parameters passed to the native xgboost library
+   * @return the booster and the evaluation metrics
*/
- @throws(classOf[XGBoostError])
- private[spark] def trainDistributed(
- sc: SparkContext,
- buildTrainingData: XGBoostExecutionParams => (RDD[() => Watches], Option[RDD[_]]),
- params: Map[String, Any]):
- (Booster, Map[String, Array[Float]]) = {
+ def train(input: RDD[Watches],
+ runtimeParams: RuntimeParams,
+ xgboostParams: Map[String, Any]): (Booster, Map[String, Array[Float]]) = {
- logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
+ val sc = input.sparkContext
+ logger.info(s"Running XGBoost ${spark.VERSION} with parameters: $xgboostParams")
- val xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc)
- val runtimeParams = xgbParamsFactory.buildXGBRuntimeParams
+ // TODO Rabit tracker exception handling.
+ val trackerConf = runtimeParams.trackerConf
- val prevBooster = runtimeParams.checkpointParam.map { checkpointParam =>
- val checkpointManager = new ExternalCheckpointManager(
- checkpointParam.checkpointPath,
- FileSystem.get(sc.hadoopConfiguration))
- checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds)
- checkpointManager.loadCheckpointAsScalaBooster()
- }.orNull
-
- // Get the training data RDD and the cachedRDD
- val (trainingRDD, optionalCachedRDD) = buildTrainingData(runtimeParams)
+ val tracker = new RabitTracker(runtimeParams.numWorkers,
+ trackerConf.hostIp, trackerConf.port, trackerConf.timeout)
+ require(tracker.start(), "FAULT: Failed to start tracker")
try {
- val (booster, metrics) = withTracker(
- runtimeParams.numWorkers,
- runtimeParams.trackerConf
- ) { tracker =>
- val rabitEnv = tracker.getWorkerArgs()
-
- val boostersAndMetrics = trainingRDD.barrier().mapPartitions { iter =>
- var optionWatches: Option[() => Watches] = None
-
- // take the first Watches to train
- if (iter.hasNext) {
- optionWatches = Some(iter.next())
+ val rabitEnv = tracker.getWorkerArgs()
+
+ val boostersAndMetrics = input.barrier().mapPartitions { iter =>
+          require(iter.hasNext, "Detected an empty partition in the training data")
+ val watches = iter.next()
+
+ val metrics = Array.tabulate(watches.size)(_ =>
+ Array.ofDim[Float](runtimeParams.numRounds))
+ try {
+            val booster = trainBooster(watches, runtimeParams, xgboostParams, rabitEnv, metrics)
+ if (TaskContext.getPartitionId() == 0) {
+ Iterator(booster -> watches.toMap.keys.zip(metrics).toMap)
+ } else {
+ Iterator.empty
+ }
+ } finally {
+ if (watches != null) {
+ watches.delete()
}
-
- optionWatches.map { buildWatches =>
- buildDistributedBooster(buildWatches,
- runtimeParams, rabitEnv, runtimeParams.obj, runtimeParams.eval, prevBooster)
- }.getOrElse(throw new RuntimeException("No Watches to train"))
}
-
- val boostersAndMetricsWithRes = tryStageLevelScheduling(sc, runtimeParams,
- boostersAndMetrics)
- // The repartition step is to make training stage as ShuffleMapStage, so that when one
- // of the training task fails the training stage can retry. ResultStage won't retry when
- // it fails.
- val (booster, metrics) = boostersAndMetricsWithRes.repartition(1).collect()(0)
- (booster, metrics)
}
- // we should delete the checkpoint directory after a successful training
- runtimeParams.checkpointParam.foreach {
- cpParam =>
- if (!runtimeParams.checkpointParam.get.skipCleanCheckpoint) {
- val checkpointManager = new ExternalCheckpointManager(
- cpParam.checkpointPath,
- FileSystem.get(sc.hadoopConfiguration))
- checkpointManager.cleanPath()
- }
- }
+ val rdd = tryStageLevelScheduling(sc, runtimeParams, boostersAndMetrics)
+    // The repartition step turns the training stage into a ShuffleMapStage, so that
+    // when one of the training tasks fails, the stage can be retried. A ResultStage
+    // won't be retried when it fails.
+ val (booster, metrics) = rdd.repartition(1).collect()(0)
(booster, metrics)
} catch {
case t: Throwable =>
// if the job was aborted due to an exception
- logger.error("the job was aborted due to ", t)
+ logger.error("XGBoost job was aborted due to ", t)
throw t
} finally {
- optionalCachedRDD.foreach(_.unpersist())
+ try {
+ tracker.stop()
+ } catch {
+        case t: Throwable => logger.error("Failed to stop the tracker", t)
+ }
}
}
-
}
-class Watches private[scala] (
- val datasets: Array[DMatrix],
- val names: Array[String],
- val cacheDirName: Option[String]) {
+class Watches private[scala](val datasets: Array[DMatrix],
+ val names: Array[String],
+ val cacheDirName: Option[String]) {
def toMap: Map[String, DMatrix] = {
names.zip(datasets).toMap.filter { case (_, matrix) => matrix.rowNum > 0 }
@@ -568,211 +315,14 @@ class Watches private[scala] (
override def toString: String = toMap.toString
}
-private object Watches {
-
- private def fromBaseMarginsToArray(baseMargins: Iterator[Float]): Option[Array[Float]] = {
- val builder = new mutable.ArrayBuilder.ofFloat()
- var nTotal = 0
- var nUndefined = 0
- while (baseMargins.hasNext) {
- nTotal += 1
- val baseMargin = baseMargins.next()
- if (baseMargin.isNaN) {
- nUndefined += 1 // don't waste space for all-NaNs.
- } else {
- builder += baseMargin
- }
- }
- if (nUndefined == nTotal) {
- None
- } else if (nUndefined == 0) {
- Some(builder.result())
- } else {
- throw new IllegalArgumentException(
- s"Encountered a partition with $nUndefined NaN base margin values. " +
- s"If you want to specify base margin, ensure all values are non-NaN.")
- }
- }
-
- def buildWatches(
- nameAndLabeledPointSets: Iterator[(String, Iterator[XGBLabeledPoint])],
- cachedDirName: Option[String]): Watches = {
- val dms = nameAndLabeledPointSets.map {
- case (name, labeledPoints) =>
- val baseMargins = new mutable.ArrayBuilder.ofFloat
- val duplicatedItr = labeledPoints.map(labeledPoint => {
- baseMargins += labeledPoint.baseMargin
- labeledPoint
- })
- val dMatrix = new DMatrix(duplicatedItr, cachedDirName.map(_ + s"/$name").orNull)
- val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
- if (baseMargin.isDefined) {
- dMatrix.setBaseMargin(baseMargin.get)
- }
- (name, dMatrix)
- }.toArray
- new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
- }
-
- def buildWatches(
- xgbExecutionParams: XGBoostExecutionParams,
- labeledPoints: Iterator[XGBLabeledPoint],
- cacheDirName: Option[String]): Watches = {
- val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
- val seed = xgbExecutionParams.xgbInputParams.seed
- val r = new Random(seed)
- val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
- val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
- val testBaseMargins = new mutable.ArrayBuilder.ofFloat
- val trainPoints = labeledPoints.filter { labeledPoint =>
- val accepted = r.nextDouble() <= trainTestRatio
- if (!accepted) {
- testPoints += labeledPoint
- testBaseMargins += labeledPoint.baseMargin
- } else {
- trainBaseMargins += labeledPoint.baseMargin
- }
- accepted
- }
- val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
- val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
-
- val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
- val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
- if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
- if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
-
- if (xgbExecutionParams.featureNames.isDefined) {
- trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
- testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
- }
-
- if (xgbExecutionParams.featureTypes.isDefined) {
- trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
- testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
- }
-
- new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
- }
-
- def buildWatchesWithGroup(
- nameAndlabeledPointGroupSets: Iterator[(String, Iterator[Array[XGBLabeledPoint]])],
- cachedDirName: Option[String]): Watches = {
- val dms = nameAndlabeledPointGroupSets.map {
- case (name, labeledPointsGroups) =>
- val baseMargins = new mutable.ArrayBuilder.ofFloat
- val groupsInfo = new mutable.ArrayBuilder.ofInt
- val weights = new mutable.ArrayBuilder.ofFloat
- val iter = labeledPointsGroups.filter(labeledPointGroup => {
- var groupWeight = -1.0f
- var groupSize = 0
- labeledPointGroup.map { labeledPoint => {
- if (groupWeight < 0) {
- groupWeight = labeledPoint.weight
- } else if (groupWeight != labeledPoint.weight) {
- throw new IllegalArgumentException("the instances in the same group have to be" +
- s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
- }
- baseMargins += labeledPoint.baseMargin
- groupSize += 1
- labeledPoint
- }
- }
- weights += groupWeight
- groupsInfo += groupSize
- true
- })
- val dMatrix = new DMatrix(iter.flatMap(_.iterator), cachedDirName.map(_ + s"/$name").orNull)
- val baseMargin = fromBaseMarginsToArray(baseMargins.result().iterator)
- if (baseMargin.isDefined) {
- dMatrix.setBaseMargin(baseMargin.get)
- }
- dMatrix.setGroup(groupsInfo.result())
- dMatrix.setWeight(weights.result())
- (name, dMatrix)
- }.toArray
- new Watches(dms.map(_._2), dms.map(_._1), cachedDirName)
- }
-
- def buildWatchesWithGroup(
- xgbExecutionParams: XGBoostExecutionParams,
- labeledPointGroups: Iterator[Array[XGBLabeledPoint]],
- cacheDirName: Option[String]): Watches = {
- val trainTestRatio = xgbExecutionParams.xgbInputParams.trainTestRatio
- val seed = xgbExecutionParams.xgbInputParams.seed
- val r = new Random(seed)
- val testPoints = mutable.ArrayBuilder.make[XGBLabeledPoint]
- val trainBaseMargins = new mutable.ArrayBuilder.ofFloat
- val testBaseMargins = new mutable.ArrayBuilder.ofFloat
-
- val trainGroups = new mutable.ArrayBuilder.ofInt
- val testGroups = new mutable.ArrayBuilder.ofInt
-
- val trainWeights = new mutable.ArrayBuilder.ofFloat
- val testWeights = new mutable.ArrayBuilder.ofFloat
-
- val trainLabelPointGroups = labeledPointGroups.filter { labeledPointGroup =>
- val accepted = r.nextDouble() <= trainTestRatio
- if (!accepted) {
- var groupWeight = -1.0f
- var groupSize = 0
- labeledPointGroup.foreach(labeledPoint => {
- testPoints += labeledPoint
- testBaseMargins += labeledPoint.baseMargin
- if (groupWeight < 0) {
- groupWeight = labeledPoint.weight
- } else if (labeledPoint.weight != groupWeight) {
- throw new IllegalArgumentException("the instances in the same group have to be" +
- s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
- }
- groupSize += 1
- })
- testWeights += groupWeight
- testGroups += groupSize
- } else {
- var groupWeight = -1.0f
- var groupSize = 0
- labeledPointGroup.foreach { labeledPoint => {
- if (groupWeight < 0) {
- groupWeight = labeledPoint.weight
- } else if (labeledPoint.weight != groupWeight) {
- throw new IllegalArgumentException("the instances in the same group have to be" +
- s" assigned with the same weight (unexpected weight ${labeledPoint.weight}")
- }
- trainBaseMargins += labeledPoint.baseMargin
- groupSize += 1
- }}
- trainWeights += groupWeight
- trainGroups += groupSize
- }
- accepted
- }
-
- val trainPoints = trainLabelPointGroups.flatMap(_.iterator)
- val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
- trainMatrix.setGroup(trainGroups.result())
- trainMatrix.setWeight(trainWeights.result())
-
- val testMatrix = new DMatrix(testPoints.result().iterator, cacheDirName.map(_ + "/test").orNull)
- if (trainTestRatio < 1.0) {
- testMatrix.setGroup(testGroups.result())
- testMatrix.setWeight(testWeights.result())
- }
-
- val trainMargin = fromBaseMarginsToArray(trainBaseMargins.result().iterator)
- val testMargin = fromBaseMarginsToArray(testBaseMargins.result().iterator)
- if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
- if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
-
- if (xgbExecutionParams.featureNames.isDefined) {
- trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
- testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
- }
- if (xgbExecutionParams.featureTypes.isDefined) {
- trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
- testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
- }
-
- new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
- }
-}
+/**
+ * Rabit tracker configurations.
+ *
+ * @param timeout The number of seconds to wait for workers to connect, and for the
+ *                tracker to shut down, before timing out.
+ * @param hostIp The Rabit Tracker host IP address.
+ * This is only needed if the host IP cannot be automatically guessed.
+ * @param port The port number for the tracker to listen to. Use a system allocated one by
+ * default.
+ */
+private[spark] case class TrackerConf(timeout: Int = 0, hostIp: String = "", port: Int = 0)
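Putting this file together: `XGBoost.train` consumes a prepared `RDD[Watches]` (exactly one `Watches` per partition), starts the tracker described by `TrackerConf`, and returns the booster along with the evaluation history. A hedged usage sketch; everything here is package-private, and `buildWatchesRDD` is a hypothetical stand-in for the estimator code that actually materializes the Watches:

```scala
// Hypothetical upstream step yielding one Watches per partition.
val watchesRDD: RDD[Watches] = buildWatchesRDD(trainingData)

val runtimeParams = RuntimeParams(
  numWorkers = 4,
  numRounds = 100,
  trackerConf = TrackerConf(), // default timeout, host IP and port
  earlyStoppingRounds = 0,
  device = "cpu",
  isLocal = false,
  runOnGpu = false)

// Returns the trained booster and the per-dataset metrics history.
val (booster, metrics) = XGBoost.train(
  watchesRDD,
  runtimeParams,
  Map("objective" -> "binary:logistic", "max_depth" -> 6))
```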
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
index ec8766e407f9..2a4caedeae5f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,490 +16,190 @@
package ml.dmlc.xgboost4j.scala.spark
-import ml.dmlc.xgboost4j.scala.spark.params._
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, EvalTrait, ObjectiveTrait, XGBoost => SXGBoost}
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.ml.classification._
-import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.util._
-import org.apache.spark.sql._
-import org.apache.spark.sql.functions._
-import scala.collection.{Iterator, mutable}
+import scala.collection.mutable
+import org.apache.spark.ml.classification.{ProbabilisticClassificationModel, ProbabilisticClassifier}
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
-import org.apache.spark.sql.types.StructType
-
-class XGBoostClassifier (
- override val uid: String,
- private[spark] val xgboostParams: Map[String, Any])
- extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
- with XGBoostClassifierParams with DefaultParamsWritable {
-
- def this() = this(Identifiable.randomUID("xgbc"), Map[String, Any]())
-
- def this(uid: String) = this(uid, Map[String, Any]())
-
- def this(xgboostParams: Map[String, Any]) = this(
- Identifiable.randomUID("xgbc"), xgboostParams)
-
- XGBoost2MLlibParams(xgboostParams)
-
- def setWeightCol(value: String): this.type = set(weightCol, value)
-
- def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
-
- def setNumClass(value: Int): this.type = set(numClass, value)
-
- // setters for general params
- def setNumRound(value: Int): this.type = set(numRound, value)
-
- def setNumWorkers(value: Int): this.type = set(numWorkers, value)
-
- def setNthread(value: Int): this.type = set(nthread, value)
-
- def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
-
- def setSilent(value: Int): this.type = set(silent, value)
-
- def setMissing(value: Float): this.type = set(missing, value)
-
- def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
-
- def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
-
- def setSeed(value: Long): this.type = set(seed, value)
-
- def setEta(value: Double): this.type = set(eta, value)
-
- def setGamma(value: Double): this.type = set(gamma, value)
-
- def setMaxDepth(value: Int): this.type = set(maxDepth, value)
-
- def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
-
- def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
-
- def setSubsample(value: Double): this.type = set(subsample, value)
-
- def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
-
- def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
-
- def setLambda(value: Double): this.type = set(lambda, value)
-
- def setAlpha(value: Double): this.type = set(alpha, value)
-
- def setTreeMethod(value: String): this.type = set(treeMethod, value)
-
- def setDevice(value: String): this.type = set(device, value)
-
- def setGrowPolicy(value: String): this.type = set(growPolicy, value)
-
- def setMaxBins(value: Int): this.type = set(maxBins, value)
-
- def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)
-
- def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
-
- def setSampleType(value: String): this.type = set(sampleType, value)
-
- def setNormalizeType(value: String): this.type = set(normalizeType, value)
-
- def setRateDrop(value: Double): this.type = set(rateDrop, value)
-
- def setSkipDrop(value: Double): this.type = set(skipDrop, value)
+import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
+import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions.{col, udf}
+import org.json4s.DefaultFormats
- def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{BINARY_CLASSIFICATION_OBJS, MULTICLASSIFICATION_OBJS}
- // setters for learning params
- def setObjective(value: String): this.type = set(objective, value)
-
- def setObjectiveType(value: String): this.type = set(objectiveType, value)
-
- def setBaseScore(value: Double): this.type = set(baseScore, value)
-
- def setEvalMetric(value: String): this.type = set(evalMetric, value)
-
- def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
-
- def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
-
- def setMaximizeEvaluationMetrics(value: Boolean): this.type =
- set(maximizeEvaluationMetrics, value)
-
- def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
-
- def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
+class XGBoostClassifier(override val uid: String,
+ private[spark] val xgboostParams: Map[String, Any])
+ extends ProbabilisticClassifier[Vector, XGBoostClassifier, XGBoostClassificationModel]
+ with XGBoostEstimator[XGBoostClassifier, XGBoostClassificationModel]
+ with XGBProbabilisticClassifierParams[XGBoostClassifier] {
- def setAllowNonZeroForMissing(value: Boolean): this.type = set(
- allowNonZeroForMissing,
- value
- )
+ def this() = this(XGBoostClassifier._uid, Map.empty)
- def setSinglePrecisionHistogram(value: Boolean): this.type =
- set(singlePrecisionHistogram, value)
+ def this(uid: String) = this(uid, Map.empty)
- def setFeatureNames(value: Array[String]): this.type =
- set(featureNames, value)
+ def this(xgboostParams: Map[String, Any]) = this(XGBoostClassifier._uid, xgboostParams)
- def setFeatureTypes(value: Array[String]): this.type =
- set(featureTypes, value)
+ xgboost2SparkParams(xgboostParams)
- // called at the start of fit/train when 'eval_metric' is not defined
- private def setupDefaultEvalMetric(): String = {
- require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
- if ($(objective).startsWith("multi")) {
- // multi
- "mlogloss"
- } else {
- // binary
- "logloss"
- }
- }
+ private var numberClasses = 0
- // Callback from PreXGBoost
- private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- if (isFeaturesColSet(schema)) {
- // User has vectorized the features into VectorUDT.
- super.transformSchema(schema)
+ private def validateObjective(dataset: Dataset[_]): Unit = {
+    // If the objective is set explicitly, it must be one of
+    // BINARY_CLASSIFICATION_OBJS or MULTICLASSIFICATION_OBJS
+ val obj = if (isSet(objective)) {
+ val tmpObj = getObjective
+ val supportedObjs = BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq
+ require(supportedObjs.contains(tmpObj),
+ s"Wrong objective for XGBoostClassifier, supported objs: ${supportedObjs.mkString(",")}")
+ Some(tmpObj)
} else {
- transformSchemaWithFeaturesCols(true, schema)
+ None
}
- }
-
- override def transformSchema(schema: StructType): StructType = {
- PreXGBoost.transformSchema(this, schema)
- }
- override protected def train(dataset: Dataset[_]): XGBoostClassificationModel = {
- val _numClasses = getNumClasses(dataset)
- if (isDefined(numClass) && $(numClass) != _numClasses) {
- throw new Exception("The number of classes in dataset doesn't match " +
- "\'num_class\' in xgboost params.")
+ def inferNumClasses: Int = {
+ var num = getNumClass
+      // Infer the number of classes if num_class is not set explicitly.
+      // Note: when the user sets num_class explicitly, we don't verify it
+      // against the dataset.
+ if (num == 0) {
+ num = SparkUtils.getNumClasses(dataset, getLabelCol)
+ }
+      require(num > 0, s"The number of classes must be positive, but got $num")
+ num
}
- if (_numClasses == 2) {
- if (!isDefined(objective)) {
- // If user doesn't set objective, force it to binary:logistic
- setObjective("binary:logistic")
+ // objective is set explicitly.
+ if (obj.isDefined) {
+ if (MULTICLASSIFICATION_OBJS.contains(getObjective)) {
+ numberClasses = inferNumClasses
+ setNumClass(numberClasses)
+ } else {
+ numberClasses = 2
+        // binary classification doesn't require num_class to be set
+ require(!isSet(numClass), "num_class is not allowed for binary classification")
}
- } else if (_numClasses > 2) {
- if (!isDefined(objective)) {
- // If user doesn't set objective, force it to multi:softprob
+ } else {
+      // infer the objective from the inferred number of classes
+ numberClasses = inferNumClasses
+ if (numberClasses <= 2) {
+ setObjective("binary:logistic")
+ logger.warn("Inferred for binary classification, set the objective to binary:logistic")
+ require(!isSet(numClass), "num_class is not allowed for binary classification")
+ } else {
+ logger.warn("Inferred for multi classification, set the objective to multi:softprob")
setObjective("multi:softprob")
+ setNumClass(numberClasses)
}
}
-
- if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
- set(evalMetric, setupDefaultEvalMetric())
- }
-
- if (isDefined(customObj) && $(customObj) != null) {
- set(objectiveType, "classification")
- }
-
- // Packing with all params plus params user defined
- val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
- val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
- transformSchema(dataset.schema, logging = true)
-
- // All non-null param maps in XGBoostClassifier are in derivedXGBParamMap.
- val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
- buildTrainingData, derivedXGBParamMap)
-
- val model = new XGBoostClassificationModel(uid, _numClasses, _booster)
- val summary = XGBoostTrainingSummary(_metrics)
- model.setSummary(summary)
- model
}
- override def copy(extra: ParamMap): XGBoostClassifier = defaultCopy(extra)
-}
-
-object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
-
- override def load(path: String): XGBoostClassifier = super.load(path)
-}
-
-class XGBoostClassificationModel private[ml](
- override val uid: String,
- override val numClasses: Int,
- private[scala] val _booster: Booster)
- extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
- with XGBoostClassifierParams with InferenceParams
- with MLWritable with Serializable {
-
- import XGBoostClassificationModel._
-
- // only called in copy()
- def this(uid: String) = this(uid, 2, null)
-
- /**
- * Get the native booster instance of this model.
- * This is used to call low-level APIs on native booster, such as "getFeatureScore".
- */
- def nativeBooster: Booster = _booster
-
- private var trainingSummary: Option[XGBoostTrainingSummary] = None
-
/**
- * Returns summary (e.g. train/test objective history) of model on the
- * training set. An exception is thrown if no summary is available.
+   * Validate the parameters before training, throwing an exception if validation fails.
*/
- def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
- throw new IllegalStateException("No training summary available for this XGBoostModel")
- }
-
- private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
- trainingSummary = Some(summary)
- this
+ override protected[spark] def validate(dataset: Dataset[_]): Unit = {
+ super.validate(dataset)
+ validateObjective(dataset)
}
- def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
-
- def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
-
- def setTreeLimit(value: Int): this.type = set(treeLimit, value)
-
- def setMissing(value: Float): this.type = set(missing, value)
-
- def setAllowNonZeroForMissing(value: Boolean): this.type = set(
- allowNonZeroForMissing,
- value
- )
-
- def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
-
- /**
- * Single instance prediction.
- * Note: The performance is not ideal, use it carefully!
- */
- override def predict(features: Vector): Double = {
- import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
- val dm = new DMatrix(processMissingValues(
- Iterator(features.asXGB),
- $(missing),
- $(allowNonZeroForMissing)
- ))
- val probability = _booster.predict(data = dm)(0).map(_.toDouble)
- if (numClasses == 2) {
- math.round(probability(0))
- } else {
- probability2prediction(Vectors.dense(probability))
- }
+ override protected def createModel(booster: Booster, summary: XGBoostTrainingSummary):
+ XGBoostClassificationModel = {
+ new XGBoostClassificationModel(uid, numberClasses, booster, Option(summary))
}
- // Actually we don't use this function at all, to make it pass compiler check.
- override def predictRaw(features: Vector): Vector = {
- throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
- }
+}
- // Actually we don't use this function at all, to make it pass compiler check.
- override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
- throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
- }
+object XGBoostClassifier extends DefaultParamsReadable[XGBoostClassifier] {
+ private val _uid = Identifiable.randomUID("xgbc")
+}
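In practice the objective handling above means callers rarely need to set `objective` by hand: left unset, it is inferred from the number of distinct labels; set to a multi-class objective, `num_class` is inferred and filled in; and `num_class` combined with a binary objective is rejected at validation time. A short sketch (column names are illustrative):

```scala
// Objective left unset: inferred at fit time from the label column.
val auto = new XGBoostClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")

// Explicit multi-class objective: num_class is inferred and set for us.
val multi = new XGBoostClassifier(Map("objective" -> "multi:softprob"))

// Rejected in validate(): num_class together with a binary objective.
// new XGBoostClassifier(Map("objective" -> "binary:logistic", "num_class" -> 3))
```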
- private[scala] def produceResultIterator(
- originalRowItr: Iterator[Row],
- rawPredictionItr: Iterator[Row],
- probabilityItr: Iterator[Row],
- predLeafItr: Iterator[Row],
- predContribItr: Iterator[Row]): Iterator[Row] = {
- // the following implementation is to be improved
- if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
- isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
- originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
- map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
- contribs: Row) =>
- Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
- contribs.toSeq)
- }
- } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
- (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
- originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
- map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
- Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq)
- }
- } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
- isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
- originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr).
- map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) =>
- Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq)
+class XGBoostClassificationModel private[ml](
+ val uid: String,
+ val numClasses: Int,
+ val nativeBooster: Booster,
+ val summary: Option[XGBoostTrainingSummary] = None
+) extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
+ with XGBoostModel[XGBoostClassificationModel]
+ with XGBProbabilisticClassifierParams[XGBoostClassificationModel] {
+
+ def this(uid: String) = this(uid, 0, null)
+
+ override protected[spark] def postTransform(dataset: Dataset[_],
+ pred: PredictedColumns): Dataset[_] = {
+ var output = super.postTransform(dataset, pred)
+
+    // Always derive the prediction from the probability column
+    if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
+ if (getObjective == "multi:softmax") {
+ // For objective=multi:softmax scenario, there is no probability predicted from xgboost.
+ // Instead, the probability column will be filled with real prediction
+ val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
+ probability(0)
}
- } else {
- originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
- case ((originals: Row, rawPrediction: Row), probability: Row) =>
- Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
- }
- }
- }
-
- private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix):
- Array[Iterator[Row]] = {
- val rawPredictionItr = {
- booster.predict(dm, outPutMargin = true, $(treeLimit)).
- map(Row(_)).iterator
- }
- val probabilityItr = {
- booster.predict(dm, outPutMargin = false, $(treeLimit)).
- map(Row(_)).iterator
- }
- val predLeafItr = {
- if (isDefined(leafPredictionCol)) {
- booster.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
+ output = output.withColumn(getPredictionCol, predictUDF(col(TMP_TRANSFORMED_COL)))
} else {
- Iterator()
- }
- }
- val predContribItr = {
- if (isDefined(contribPredictionCol)) {
- booster.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
- } else {
- Iterator()
+ val predCol = udf { probability: mutable.WrappedArray[Float] =>
+ val prob = probability.map(_.toDouble).toArray
+ val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
+ probability2prediction(Vectors.dense(probabilities))
+ }
+ output = output.withColumn(getPredictionCol, predCol(col(TMP_TRANSFORMED_COL)))
}
}
- Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
- }
-
- private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- if (isFeaturesColSet(schema)) {
- // User has vectorized the features into VectorUDT.
- super.transformSchema(schema)
- } else {
- transformSchemaWithFeaturesCols(false, schema)
- }
- }
-
- override def transformSchema(schema: StructType): StructType = {
- PreXGBoost.transformSchema(this, schema)
- }
-
- override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
- if (isDefined(thresholds)) {
- require($(thresholds).length == numClasses, this.getClass.getSimpleName +
- ".transform() called with non-matching numClasses and thresholds.length." +
- s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
- }
-
- // Output selected columns only.
- // This is a bit complicated since it tries to avoid repeated computation.
- var outputData = PreXGBoost.transformDataset(this, dataset)
- var numColsOutput = 0
-
- val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
- val raw = rawPrediction.map(_.toDouble).toArray
- val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
- Vectors.dense(rawPredictions)
- }
- if ($(rawPredictionCol).nonEmpty) {
- outputData = outputData
- .withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
- numColsOutput += 1
- }
-
- if (getObjective.equals("multi:softmax")) {
- // For objective=multi:softmax scenario, there is no probability predicted from xgboost.
- // Instead, the probability column will be filled with real prediction
- val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
- probability(0)
- }
- if ($(predictionCol).nonEmpty) {
- outputData = outputData
- .withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
- numColsOutput += 1
- }
-
- } else {
+ if (isDefinedNonEmpty(probabilityCol) && pred.predTmp) {
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
Vectors.dense(probabilities)
}
- if ($(probabilityCol).nonEmpty) {
- outputData = outputData
- .withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
- numColsOutput += 1
- }
+ output = output.withColumn(TMP_TRANSFORMED_COL,
+ probabilityUDF(output.col(TMP_TRANSFORMED_COL)))
+ .withColumnRenamed(TMP_TRANSFORMED_COL, getProbabilityCol)
+ }
- val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
- // From XGBoost probability to MLlib prediction
- val prob = probability.map(_.toDouble).toArray
- val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
- probability2prediction(Vectors.dense(probabilities))
- }
- if ($(predictionCol).nonEmpty) {
- outputData = outputData
- .withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
- numColsOutput += 1
+ if (pred.predRaw) {
+ val rawPredictionUDF = udf { raw: mutable.WrappedArray[Float] =>
+ val rawF = raw.map(_.toDouble).toArray
+ val rawPredictions = if (numClasses == 2) Array(-rawF(0), rawF(0)) else rawF
+ Vectors.dense(rawPredictions)
}
+ output = output.withColumn(getRawPredictionCol,
+ rawPredictionUDF(output.col(getRawPredictionCol)))
}
- if (numColsOutput == 0) {
- this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
- " since no output columns were set.")
- }
- outputData
- .toDF
- .drop(col(_rawPredictionCol))
- .drop(col(_probabilityCol))
+ output.drop(TMP_TRANSFORMED_COL)
}
override def copy(extra: ParamMap): XGBoostClassificationModel = {
- val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
- newModel.setSummary(summary).setParent(parent)
+ val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses,
+ nativeBooster, summary), extra)
+ newModel.setParent(parent)
}
- override def write: MLWriter =
- new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
-}
-
-object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
-
- private[scala] val _rawPredictionCol = "_rawPrediction"
- private[scala] val _probabilityCol = "_probability"
-
- override def read: MLReader[XGBoostClassificationModel] = new XGBoostClassificationModelReader
-
- override def load(path: String): XGBoostClassificationModel = super.load(path)
-
- private[XGBoostClassificationModel]
- class XGBoostClassificationModelWriter(instance: XGBoostClassificationModel)
- extends XGBoostWriter {
-
- override protected def saveImpl(path: String): Unit = {
- // Save metadata and Params
- DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
-
- // Save model data
- val dataPath = new Path(path, "data").toString
- val internalPath = new Path(dataPath, "XGBoostClassificationModel")
- val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
- instance._booster.saveModel(outputStream, getModelFormat())
- outputStream.close()
- }
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ throw new Exception("XGBoost-Spark does not support \'raw2probabilityInPlace\'")
}
- private class XGBoostClassificationModelReader extends MLReader[XGBoostClassificationModel] {
+ override def predictRaw(features: Vector): Vector =
+ throw new Exception("XGBoost-Spark does not support \'predictRaw\'")
- /** Checked against metadata when loading model */
- private val className = classOf[XGBoostClassificationModel].getName
+}
- override def load(path: String): XGBoostClassificationModel = {
- implicit val sc = super.sparkSession.sparkContext
+object XGBoostClassificationModel extends MLReadable[XGBoostClassificationModel] {
- val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
+ override def read: MLReader[XGBoostClassificationModel] = new ModelReader
- val dataPath = new Path(path, "data").toString
- val internalPath = new Path(dataPath, "XGBoostClassificationModel")
- val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
- val numClasses = DefaultXGBoostParamsReader.getNumClass(metadata, dataInStream)
- val booster = SXGBoost.loadModel(dataInStream)
- val model = new XGBoostClassificationModel(metadata.uid, numClasses, booster)
- DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
+ private class ModelReader extends XGBoostModelReader[XGBoostClassificationModel] {
+ override def load(path: String): XGBoostClassificationModel = {
+ val xgbModel = loadBooster(path)
+ val meta = SparkUtils.loadMetadata(path, sc)
+ implicit val format = DefaultFormats
+ val numClasses = (meta.params \ "numClass").extractOpt[Int].getOrElse(2)
+ val model = new XGBoostClassificationModel(meta.uid, numClasses, xgbModel)
+ meta.getAndSetParams(model)
model
}
}
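
Note on the new persistence layout introduced above (not part of the diff itself): the writer stores the booster bytes under `<path>/data/model`, and the reader recovers `numClass` from the saved metadata JSON, defaulting to 2 for binary models. A minimal round-trip sketch, assuming a local SparkSession, the default `features`/`label` columns, and the `Map`-based constructor this PR gives every estimator:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}

val spark = SparkSession.builder().master("local[2]").appName("xgb-save-load").getOrCreate()
import spark.implicits._

// Toy binary-classification data; column names match the estimator defaults.
val train = Seq(
  (0.0f, Vectors.dense(1.0, 2.0)),
  (1.0f, Vectors.dense(3.0, 4.0))
).toDF("label", "features")

val model = new XGBoostClassifier(Map("objective" -> "binary:logistic")).fit(train)

model.write.overwrite().save("/tmp/xgb-cls") // booster bytes land in /tmp/xgb-cls/data/model
val restored = XGBoostClassificationModel.load("/tmp/xgb-cls") // numClass read back from metadata
restored.transform(train).show()
```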
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
new file mode 100644
index 000000000000..cd5fa0865ea0
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala
@@ -0,0 +1,622 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.util.ServiceLoader
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.commons.logging.LogFactory
+import org.apache.hadoop.fs.Path
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.array_to_vector
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.util.{DefaultParamsWritable, MLReader, MLWritable, MLWriter}
+import org.apache.spark.ml.xgboost.{SparkUtils, XGBProbabilisticClassifierParams}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types._
+
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+import ml.dmlc.xgboost4j.java.{Booster => JBooster}
+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
+import ml.dmlc.xgboost4j.scala.spark.Utils.MLVectorToXGBLabeledPoint
+import ml.dmlc.xgboost4j.scala.spark.params._
+
+/**
+ * Holds the column indices of the input dataset
+ */
+private[spark] case class ColumnIndices(
+ labelId: Int,
+ featureId: Option[Int], // the feature type is VectorUDT or Array
+ featureIds: Option[Seq[Int]], // the feature type is columnar
+ weightId: Option[Int],
+ marginId: Option[Int],
+ groupId: Option[Int])
+
+private[spark] trait NonParamVariables[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]] {
+
+ private var dataset: Option[Dataset[_]] = None
+
+ def setEvalDataset(ds: Dataset[_]): T = {
+ this.dataset = Some(ds)
+ this.asInstanceOf[T]
+ }
+
+ def getEvalDataset(): Option[Dataset[_]] = {
+ this.dataset
+ }
+}
+
+private[spark] trait PluginMixin {
+ // Find the XGBoostPlugin by ServiceLoader
+ private val plugin: Option[XGBoostPlugin] = {
+ val classLoader = Option(Thread.currentThread().getContextClassLoader)
+ .getOrElse(getClass.getClassLoader)
+
+ val serviceLoader = ServiceLoader.load(classOf[XGBoostPlugin], classLoader)
+
+ // For now, we only trust GpuXGBoostPlugin.
+ serviceLoader.asScala.filter(x => x.getClass.getName.equals(
+ "ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin")).toList match {
+ case Nil => None
+ case head :: Nil =>
+ Some(head)
+ case _ => None
+ }
+ }
+
+ /** Visible for testing */
+ protected[spark] def getPlugin: Option[XGBoostPlugin] = plugin
+
+ protected def isPluginEnabled(dataset: Dataset[_]): Boolean = {
+ plugin.map(_.isEnabled(dataset)).getOrElse(false)
+ }
+}
+
+private[spark] trait XGBoostEstimator[
+ Learner <: XGBoostEstimator[Learner, M], M <: XGBoostModel[M]] extends Estimator[M]
+ with XGBoostParams[Learner] with SparkParams[Learner] with ParamUtils[Learner]
+ with NonParamVariables[Learner, M] with ParamMapConversion with DefaultParamsWritable
+ with PluginMixin {
+
+ protected val logger = LogFactory.getLog("XGBoostSpark")
+
+ /**
+ * Cast the named column to the desired data type if needed.
+ *
+ * @param schema the schema of the input dataset
+ * @param name the column to be cast
+ * @param targetType the target data type
+ * @return the column, cast to targetType when its current type differs
+ */
+ private[spark] def castIfNeeded(schema: StructType,
+ name: String,
+ targetType: DataType = FloatType): Column = {
+ if (schema(name).dataType != targetType) {
+ val meta = schema(name).metadata
+ col(name).as(name, meta).cast(targetType)
+ } else {
+ col(name)
+ }
+ }
+
+ /**
+ * Repartition the dataset to numWorkers partitions if needed.
+ *
+ * @param dataset the dataset to be repartitioned
+ * @return the repartitioned dataset
+ */
+ private[spark] def repartitionIfNeeded(dataset: Dataset[_]): Dataset[_] = {
+ val numPartitions = dataset.rdd.getNumPartitions
+ if (getForceRepartition || getNumWorkers != numPartitions) {
+ dataset.repartition(getNumWorkers)
+ } else {
+ dataset
+ }
+ }
+
+ /**
+ * Build the column indices.
+ */
+ private[spark] def buildColumnIndices(schema: StructType): ColumnIndices = {
+ // Get feature id(s)
+ val (featureIds: Option[Seq[Int]], featureId: Option[Int]) =
+ if (getFeaturesCols.length != 0) {
+ (Some(getFeaturesCols.map(schema.fieldIndex).toSeq), None)
+ } else {
+ (None, Some(schema.fieldIndex(getFeaturesCol)))
+ }
+
+ // function to get the column id according to the parameter
+ def columnId(param: Param[String]): Option[Int] = {
+ if (isDefinedNonEmpty(param)) {
+ Some(schema.fieldIndex($(param)))
+ } else {
+ None
+ }
+ }
+
+ // Special handling for the group column
+ val groupId: Option[Int] = this match {
+ case p: HasGroupCol => columnId(p.groupCol)
+ case _ => None
+ }
+
+ ColumnIndices(
+ labelId = columnId(labelCol).get,
+ featureId = featureId,
+ featureIds = featureIds,
+ columnId(weightCol),
+ columnId(baseMarginCol),
+ groupId)
+ }
+
+ /**
+ * Preprocess the dataset to meet the xgboost input requirements.
+ *
+ * @param dataset the input dataset
+ * @return the selected and repartitioned dataset along with its column indices
+ */
+ private[spark] def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = {
+
+ // Columns to be selected for XGBoost training
+ val selectedCols: ArrayBuffer[Column] = ArrayBuffer.empty
+ val schema = dataset.schema
+
+ def selectCol(c: Param[String], targetType: DataType) = {
+ if (isDefinedNonEmpty(c)) {
+ // The features column is a vector type and must not be cast to float.
+ if (c == featuresCol) {
+ selectedCols.append(col($(c)))
+ } else {
+ selectedCols.append(castIfNeeded(schema, $(c), targetType))
+ }
+ }
+ }
+
+ Seq(labelCol, featuresCol, weightCol, baseMarginCol).foreach(p => selectCol(p, FloatType))
+ this match {
+ case p: HasGroupCol => selectCol(p.groupCol, IntegerType)
+ case _ =>
+ }
+ val input = repartitionIfNeeded(dataset.select(selectedCols.toArray: _*))
+
+ val columnIndices = buildColumnIndices(input.schema)
+ (input, columnIndices)
+ }
+
+ /** visible for testing */
+ private[spark] def toXGBLabeledPoint(dataset: Dataset[_],
+ columnIndexes: ColumnIndices): RDD[XGBLabeledPoint] = {
+ dataset.toDF().rdd.map { row =>
+ val features = row.getAs[Vector](columnIndexes.featureId.get)
+ val label = row.getFloat(columnIndexes.labelId)
+ val weight = columnIndexes.weightId.map(row.getFloat).getOrElse(1.0f)
+ val baseMargin = columnIndexes.marginId.map(row.getFloat).getOrElse(Float.NaN)
+ val group = columnIndexes.groupId.map(row.getInt).getOrElse(-1)
+ // To make "0" meaningful, we convert sparse vector if possible to dense to create DMatrix.
+ val values = features.toArray.map(_.toFloat)
+ XGBLabeledPoint(label, values.length, null, values, weight, group, baseMargin)
+ }
+ }
+
+ /**
+ * Convert the dataframe to an RDD of Watches; visible for testing.
+ *
+ * @param dataset the input dataset
+ * @param columnIndices the indices of the columns, including weight/group/base margin ...
+ * @return RDD[Watches]
+ */
+ private[spark] def toRdd(dataset: Dataset[_], columnIndices: ColumnIndices): RDD[Watches] = {
+ val trainRDD = toXGBLabeledPoint(dataset, columnIndices)
+
+ val featureNames = if (getFeatureNames.isEmpty) None else Some(getFeatureNames)
+ val featureTypes = if (getFeatureTypes.isEmpty) None else Some(getFeatureTypes)
+
+ val missing = getMissing
+
+ // Transform the labeled points to collect margins/groups and build the DMatrix.
+ // TODO: support base margin for multi-class classification.
+ // TODO: as an optimization, move this into JNI.
+ def buildDMatrix(iter: Iterator[XGBLabeledPoint]) = {
+ val dmatrix = if (columnIndices.marginId.isDefined || columnIndices.groupId.isDefined) {
+ val margins = new mutable.ArrayBuilder.ofFloat
+ val groups = new mutable.ArrayBuilder.ofInt
+ val groupWeights = new mutable.ArrayBuilder.ofFloat
+ var prevGroup = -101010
+ var prevWeight = -1.0f
+ var groupSize = 0
+ val transformedIter = iter.map { labeledPoint =>
+ if (columnIndices.marginId.isDefined) {
+ margins += labeledPoint.baseMargin
+ }
+ if (columnIndices.groupId.isDefined) {
+ if (prevGroup != labeledPoint.group) {
+ // starting with new group
+ if (prevGroup != -101010) {
+ // write the previous group
+ groups += groupSize
+ groupWeights += prevWeight
+ }
+ groupSize = 1
+ prevWeight = labeledPoint.weight
+ prevGroup = labeledPoint.group
+ } else {
+ // for the same group
+ if (prevWeight != labeledPoint.weight) {
+ throw new IllegalArgumentException("the instances in the same group have to be" +
+ s" assigned the same weight (unexpected weight ${labeledPoint.weight})")
+ }
+ groupSize = groupSize + 1
+ }
+ }
+ labeledPoint
+ }
+ val dm = new DMatrix(transformedIter, null, missing)
+ columnIndices.marginId.foreach(_ => dm.setBaseMargin(margins.result()))
+ if (columnIndices.groupId.isDefined) {
+ if (prevGroup != -101010) {
+ // write the last group
+ groups += groupSize
+ groupWeights += prevWeight
+ }
+ dm.setGroup(groups.result())
+ // new DMatrix() sets a weight per instance, but ranking requires one
+ // weight per group, so the weights must be reset here.
+ // This could be optimized by moving the group/base-margin setup into JNI.
+ dm.setWeight(groupWeights.result())
+ }
+ dm
+ } else {
+ new DMatrix(iter, null, missing)
+ }
+ featureTypes.foreach(dmatrix.setFeatureTypes)
+ featureNames.foreach(dmatrix.setFeatureNames)
+ dmatrix
+ }
+
+ getEvalDataset().map { eval =>
+ val (evalDf, _) = preprocess(eval)
+ val evalRDD = toXGBLabeledPoint(evalDf, columnIndices)
+ trainRDD.zipPartitions(evalRDD) { (left, right) =>
+ val trainDMatrix = buildDMatrix(left)
+ val evalDMatrix = buildDMatrix(right)
+ val watches = new Watches(Array(trainDMatrix, evalDMatrix),
+ Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
+ Iterator.single(watches)
+ }
+ }.getOrElse(
+ trainRDD.mapPartitions { iter =>
+ val dm = buildDMatrix(iter)
+ val watches = new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
+ Iterator.single(watches)
+ }
+ )
+ }
+
+ protected def createModel(booster: Booster, summary: XGBoostTrainingSummary): M
+
+ private[spark] def getRuntimeParameters(isLocal: Boolean): RuntimeParams = {
+ val runOnGpu = getDevice != "cpu" || getTreeMethod == "gpu_hist"
+ RuntimeParams(
+ getNumWorkers,
+ getNumRound,
+ TrackerConf(getRabitTrackerTimeout, getRabitTrackerHostIp, getRabitTrackerPort),
+ getNumEarlyStoppingRounds,
+ getDevice,
+ isLocal,
+ runOnGpu,
+ Option(getCustomObj),
+ Option(getCustomEval)
+ )
+ }
+
+ /**
+ * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
+ * If so, throw an exception unless this safety measure has been explicitly overridden
+ * via conf `xgboost.spark.ignoreSsl`.
+ */
+ private def validateSparkSslConf(spark: SparkSession): Unit = {
+
+ val sparkSslEnabled = spark.conf.getOption("spark.ssl.enabled").getOrElse("false").toBoolean
+ val xgbIgnoreSsl = spark.conf.getOption("xgboost.spark.ignoreSsl").getOrElse("false").toBoolean
+
+ if (sparkSslEnabled) {
+ if (xgbIgnoreSsl) {
+ logger.warn(s"spark-xgboost is being run without encrypting data in transit! " +
+ s"Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.")
+ } else {
+ throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data " +
+ "in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. " +
+ "To override this protection and still use xgboost-spark at your own risk, " +
+ "you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.")
+ }
+ }
+ }
+
+ /**
+ * Validate the parameters before training, throwing an exception if any are invalid.
+ */
+ protected[spark] def validate(dataset: Dataset[_]): Unit = {
+ validateSparkSslConf(dataset.sparkSession)
+ val schema = dataset.schema
+ SparkUtils.checkNumericType(schema, $(labelCol))
+ if (isDefinedNonEmpty(weightCol)) {
+ SparkUtils.checkNumericType(schema, $(weightCol))
+ }
+
+ if (isDefinedNonEmpty(baseMarginCol)) {
+ SparkUtils.checkNumericType(schema, $(baseMarginCol))
+ }
+
+ val taskCpus = dataset.sparkSession.sparkContext.getConf.getInt("spark.task.cpus", 1)
+ if (isDefined(nthread)) {
+ require(getNthread <= taskCpus,
+ s"the nthread configuration ($getNthread) must be no larger than " +
+ s"spark.task.cpus ($taskCpus)")
+ } else {
+ setNthread(taskCpus)
+ }
+ }
+
+ def train(dataset: Dataset[_]): M = {
+ validate(dataset)
+
+ val rdd = if (isPluginEnabled(dataset)) {
+ getPlugin.get.buildRddWatches(this, dataset)
+ } else {
+ val (input, columnIndexes) = preprocess(dataset)
+ toRdd(input, columnIndexes)
+ }
+
+ val xgbParams = getXGBoostParams
+
+ val runtimeParams = getRuntimeParameters(dataset.sparkSession.sparkContext.isLocal)
+
+ val (booster, metrics) = XGBoost.train(rdd, runtimeParams, xgbParams)
+
+ val summary = XGBoostTrainingSummary(metrics)
+ copyValues(createModel(booster, summary))
+ }
+
+ override def copy(extra: ParamMap): Learner = defaultCopy(extra).asInstanceOf[Learner]
+}
+
+/**
+ * Indicates what is to be predicted.
+ *
+ * @param predLeaf predict leaf indices
+ * @param predContrib predict feature contributions
+ * @param predRaw predict raw margins
+ * @param predTmp predict probability for classification, raw prediction for regression
+ */
+private[spark] case class PredictedColumns(
+ predLeaf: Boolean,
+ predContrib: Boolean,
+ predRaw: Boolean,
+ predTmp: Boolean)
+
+/**
+ * XGBoost base model
+ */
+private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with MLWritable
+ with XGBoostParams[M] with SparkParams[M] with ParamUtils[M] with PluginMixin {
+
+ protected val TMP_TRANSFORMED_COL = "_tmp_xgb_transformed_col"
+
+ override def copy(extra: ParamMap): M = defaultCopy(extra).asInstanceOf[M]
+
+ /**
+ * Get the native XGBoost Booster
+ *
+ * @return the native Booster
+ */
+ def nativeBooster: Booster
+
+ def summary: Option[XGBoostTrainingSummary]
+
+ protected[spark] def postTransform(dataset: Dataset[_], pred: PredictedColumns): Dataset[_] = {
+ var output = dataset
+ // Convert leaf/contrib predictions from array to vector.
+ if (pred.predLeaf) {
+ output = output.withColumn(getLeafPredictionCol,
+ array_to_vector(output.col(getLeafPredictionCol)))
+ }
+
+ if (pred.predContrib) {
+ output = output.withColumn(getContribPredictionCol,
+ array_to_vector(output.col(getContribPredictionCol)))
+ }
+ output
+ }
+
+ /**
+ * Preprocess the schema before transforming.
+ *
+ * @return the transformed schema and the columns to be predicted
+ */
+ private[spark] def preprocess(dataset: Dataset[_]): (StructType, PredictedColumns) = {
+ // Be careful about the order of columns
+ var schema = dataset.schema
+
+ /** If the parameter is defined and non-empty, add it to the schema and return true */
+ def addToSchema(param: Param[String], colName: Option[String] = None): Boolean = {
+ if (isDefinedNonEmpty(param)) {
+ val name = colName.getOrElse($(param))
+ schema = schema.add(StructField(name, ArrayType(FloatType)))
+ true
+ } else {
+ false
+ }
+ }
+
+ val predLeaf = addToSchema(leafPredictionCol)
+ val predContrib = addToSchema(contribPredictionCol)
+
+ var predRaw = false
+ // For the classification case, the transformed col is the probability,
+ // while for others, it's the prediction value.
+ var predTmp = false
+ this match {
+ case p: XGBProbabilisticClassifierParams[_] => // classification case
+ predRaw = addToSchema(p.rawPredictionCol)
+ predTmp = addToSchema(p.probabilityCol, Some(TMP_TRANSFORMED_COL))
+
+ if (isDefinedNonEmpty(predictionCol)) {
+ // Use the transformed col to compute the prediction.
+ if (!predTmp) {
+ // Add the transformed col for prediction
+ schema = schema.add(
+ StructField(TMP_TRANSFORMED_COL, ArrayType(FloatType)))
+ predTmp = true
+ }
+ }
+ case _ =>
+ // TMP_TRANSFORMED_COL will be renamed to the prediction column in postTransform.
+ predTmp = addToSchema(predictionCol, Some(TMP_TRANSFORMED_COL))
+ }
+ (schema, PredictedColumns(predLeaf, predContrib, predRaw, predTmp))
+ }
+
+ /** Predict */
+ private[spark] def predictInternal(booster: Booster, dm: DMatrix, pred: PredictedColumns,
+ batchRow: Iterator[Row]): Seq[Row] = {
+ var tmpOut = batchRow.toSeq.map(_.toSeq)
+ val zip = (left: Seq[Seq[_]], right: Array[Array[Float]]) => left.zip(right).map {
+ case (a, b) => a ++ Seq(b)
+ }
+ if (pred.predLeaf) {
+ tmpOut = zip(tmpOut, booster.predictLeaf(dm))
+ }
+ if (pred.predContrib) {
+ tmpOut = zip(tmpOut, booster.predictContrib(dm))
+ }
+ if (pred.predRaw) {
+ tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = true))
+ }
+ if (pred.predTmp) {
+ tmpOut = zip(tmpOut, booster.predict(dm, outPutMargin = false))
+ }
+ tmpOut.map(Row.fromSeq)
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+
+ if (getPlugin.isDefined) {
+ return getPlugin.get.transform(this, dataset)
+ }
+
+ val (schema, pred) = preprocess(dataset)
+ // Broadcast the booster to each executor.
+ val bBooster = dataset.sparkSession.sparkContext.broadcast(nativeBooster)
+ // TODO: make the inference batch size configurable.
+ val inferBatchSize = 32 << 10
+ val featureName = getFeaturesCol
+ val missing = getMissing
+
+ val output = dataset.toDF().mapPartitions { rowIter =>
+ rowIter.grouped(inferBatchSize).flatMap { batchRow =>
+ val features = batchRow.iterator.map(row => row.getAs[Vector](
+ row.fieldIndex(featureName)))
+ // DMatrix used for prediction
+ val dm = new DMatrix(features.map(_.asXGB), null, missing)
+ try {
+ predictInternal(bBooster.value, dm, pred, batchRow.toIterator)
+ } finally {
+ dm.delete()
+ }
+ }
+
+ }(Encoders.row(schema))
+ bBooster.unpersist(blocking = false)
+ postTransform(output, pred).toDF()
+ }
+
+ override def write: MLWriter = new XGBoostModelWriter(this)
+
+ protected def predictSingleInstance(features: Vector): Array[Float] = {
+ if (nativeBooster == null) {
+ throw new IllegalArgumentException("The model has not been trained")
+ }
+ val dm = new DMatrix(Iterator(features.asXGB), null, getMissing)
+ nativeBooster.predict(data = dm)(0)
+ }
+}
+
+/**
+ * Class to write the model
+ *
+ * @param instance model to be written
+ */
+private[spark] class XGBoostModelWriter(instance: XGBoostModel[_]) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ if (Option(instance.nativeBooster).isEmpty) {
+ throw new RuntimeException("The XGBoost model has not been trained")
+ }
+ SparkUtils.saveMetadata(instance, path, sc)
+
+ // Save model data
+ val dataPath = new Path(path, "data").toString
+ val internalPath = new Path(dataPath, "model")
+ val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
+ val format = optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
+ try {
+ instance.nativeBooster.saveModel(outputStream, format)
+ } finally {
+ outputStream.close()
+ }
+ }
+}
+
+private[spark] abstract class XGBoostModelReader[M <: XGBoostModel[M]] extends MLReader[M] {
+
+ protected def loadBooster(path: String): Booster = {
+ val dataPath = new Path(path, "data").toString
+ val internalPath = new Path(dataPath, "model")
+ val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
+ try {
+ SXGBoost.loadModel(dataInStream)
+ } finally {
+ dataInStream.close()
+ }
+ }
+}
+
+// Trait for Ranker and Regressor Model
+private[spark] trait RankerRegressorBaseModel[M <: XGBoostModel[M]] extends XGBoostModel[M] {
+
+ override protected[spark] def postTransform(dataset: Dataset[_],
+ pred: PredictedColumns): Dataset[_] = {
+ var output = super.postTransform(dataset, pred)
+ if (isDefinedNonEmpty(predictionCol) && pred.predTmp) {
+ val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
+ originalPrediction(0).toDouble
+ }
+ output = output
+ .withColumn($(predictionCol), predictUDF(col(TMP_TRANSFORMED_COL)))
+ .drop(TMP_TRANSFORMED_COL)
+ }
+ output
+ }
+
+}
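
The densest logic in this file is the single pass inside `buildDMatrix` that turns rows pre-sorted by group id into one size and one weight per group. A standalone sketch of that pass over plain tuples (the names and the `Int.MinValue` sentinel are illustrative, not the ones used above):

```scala
// Assumes rows arrive sorted by group id, as guaranteed by the ranker's
// sortWithinPartitions, and that all rows of a group share one weight.
final case class RankRow(group: Int, weight: Float)

def groupSizesAndWeights(rows: Iterator[RankRow]): (Array[Int], Array[Float]) = {
  val sizes = Array.newBuilder[Int]
  val weights = Array.newBuilder[Float]
  var prevGroup = Int.MinValue // sentinel: no group seen yet
  var prevWeight = 0.0f
  var size = 0
  for (r <- rows) {
    if (r.group != prevGroup) {
      // Group boundary: flush the finished group before starting the new one.
      if (prevGroup != Int.MinValue) { sizes += size; weights += prevWeight }
      prevGroup = r.group
      prevWeight = r.weight
      size = 1
    } else {
      require(r.weight == prevWeight, "rows within one group must share a weight")
      size += 1
    }
  }
  if (prevGroup != Int.MinValue) { sizes += size; weights += prevWeight } // flush the last group
  (sizes.result(), weights.result())
}

// groupSizesAndWeights(Iterator(RankRow(1, 0.5f), RankRow(1, 0.5f), RankRow(2, 1.0f)))
// => (Array(2, 1), Array(0.5f, 1.0f))
```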
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala
new file mode 100644
index 000000000000..dda82f97968b
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala
@@ -0,0 +1,49 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.Serializable
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+trait XGBoostPlugin extends Serializable {
+ /**
+ * Whether the plugin is enabled or not; if not enabled, fall back
+ * to the regular CPU pipeline.
+ *
+ * @param dataset the input dataset
+ * @return Boolean
+ */
+ def isEnabled(dataset: Dataset[_]): Boolean
+
+ /**
+ * Convert Dataset to RDD[Watches] which will be fed into XGBoost
+ *
+ * @param estimator the estimator to be handled.
+ * @param dataset to be converted.
+ * @return RDD[Watches]
+ */
+ def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
+ estimator: XGBoostEstimator[T, M],
+ dataset: Dataset[_]): RDD[Watches]
+
+ /**
+ * Transform the dataset
+ */
+ def transform[M <: XGBoostModel[M]](model: XGBoostModel[M], dataset: Dataset[_]): DataFrame
+
+}
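
Discovery uses the standard `java.util.ServiceLoader` contract, and `PluginMixin` additionally allow-lists the single class name `ml.dmlc.xgboost4j.scala.spark.GpuXGBoostPlugin`. For illustration only (this hypothetical stub would be found by the loader but rejected by the allow-list), an implementation has the following shape:

```scala
package ml.dmlc.xgboost4j.scala.spark

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}

// Hypothetical no-op plugin, purely to show the required shape of the trait.
class NoopXGBoostPlugin extends XGBoostPlugin {
  override def isEnabled(dataset: Dataset[_]): Boolean = false

  override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
      estimator: XGBoostEstimator[T, M],
      dataset: Dataset[_]): RDD[Watches] =
    throw new UnsupportedOperationException("plugin disabled")

  override def transform[M <: XGBoostModel[M]](
      model: XGBoostModel[M],
      dataset: Dataset[_]): DataFrame =
    throw new UnsupportedOperationException("plugin disabled")
}
```

Registration follows the usual ServiceLoader convention: a provider-configuration file on the classpath at `META-INF/services/ml.dmlc.xgboost4j.scala.spark.XGBoostPlugin` containing the implementing class name, one per line.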
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala
new file mode 100644
index 000000000000..0744f2de9702
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala
@@ -0,0 +1,120 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import org.apache.spark.ml.{PredictionModel, Predictor}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
+import org.apache.spark.ml.xgboost.SparkUtils
+import org.apache.spark.sql.Dataset
+
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker._uid
+import ml.dmlc.xgboost4j.scala.spark.params.HasGroupCol
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS
+
+class XGBoostRanker(override val uid: String,
+ private val xgboostParams: Map[String, Any])
+ extends Predictor[Vector, XGBoostRanker, XGBoostRankerModel]
+ with XGBoostEstimator[XGBoostRanker, XGBoostRankerModel] with HasGroupCol {
+
+ def this() = this(_uid, Map[String, Any]())
+
+ def this(uid: String) = this(uid, Map[String, Any]())
+
+ def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams)
+
+ def setGroupCol(value: String): XGBoostRanker = set(groupCol, value)
+
+ xgboost2SparkParams(xgboostParams)
+
+ /**
+ * Validate the parameters before training, throwing an exception if any are invalid.
+ */
+ override protected[spark] def validate(dataset: Dataset[_]): Unit = {
+ super.validate(dataset)
+
+ // If the objective is set explicitly, it must be a ranking objective (RANKER_OBJS).
+ if (isSet(objective)) {
+ val tmpObj = getObjective
+ require(RANKER_OBJS.contains(tmpObj),
+ s"Wrong objective for XGBoostRanker, supported objs: ${RANKER_OBJS.mkString(",")}")
+ } else {
+ setObjective("rank:ndcg")
+ }
+
+ require(isDefinedNonEmpty(groupCol), "groupCol needs to be set")
+ }
+
+ /**
+ * Preprocess the dataset to meet the xgboost input requirements.
+ *
+ * @param dataset the input dataset
+ * @return the dataset sorted by group within each partition, and the column indices
+ */
+ override private[spark] def preprocess(dataset: Dataset[_]): (Dataset[_], ColumnIndices) = {
+ val (output, columnIndices) = super.preprocess(dataset)
+ (output.sortWithinPartitions(getGroupCol), columnIndices)
+ }
+
+ override protected def createModel(
+ booster: Booster,
+ summary: XGBoostTrainingSummary): XGBoostRankerModel = {
+ new XGBoostRankerModel(uid, booster, Option(summary))
+ }
+}
+
+object XGBoostRanker extends DefaultParamsReadable[XGBoostRanker] {
+ private val _uid = Identifiable.randomUID("xgbranker")
+}
+
+class XGBoostRankerModel private[ml](val uid: String,
+ val nativeBooster: Booster,
+ val summary: Option[XGBoostTrainingSummary] = None)
+ extends PredictionModel[Vector, XGBoostRankerModel]
+ with RankerRegressorBaseModel[XGBoostRankerModel] with HasGroupCol {
+
+ def this(uid: String) = this(uid, null)
+
+ def setGroupCol(value: String): XGBoostRankerModel = set(groupCol, value)
+
+ override def copy(extra: ParamMap): XGBoostRankerModel = {
+ val newModel = copyValues(new XGBoostRankerModel(uid, nativeBooster, summary), extra)
+ newModel.setParent(parent)
+ }
+
+ override def predict(features: Vector): Double = {
+ val values = predictSingleInstance(features)
+ values(0)
+ }
+}
+
+object XGBoostRankerModel extends MLReadable[XGBoostRankerModel] {
+ override def read: MLReader[XGBoostRankerModel] = new ModelReader
+
+ private class ModelReader extends XGBoostModelReader[XGBoostRankerModel] {
+ override def load(path: String): XGBoostRankerModel = {
+ val xgbModel = loadBooster(path)
+ val meta = SparkUtils.loadMetadata(path, sc)
+ val model = new XGBoostRankerModel(meta.uid, xgbModel, None)
+ meta.getAndSetParams(model)
+ model
+ }
+ }
+}
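
Usage sketch for the new ranker (data and column names are illustrative): `groupCol` is mandatory per `validate`, the objective defaults to `rank:ndcg`, and `preprocess` sorts rows by the group column within each partition before the group pass described earlier:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import ml.dmlc.xgboost4j.scala.spark.XGBoostRanker

val spark = SparkSession.builder().master("local[2]").appName("xgb-ranker").getOrCreate()
import spark.implicits._

// (label, query/group id, features); the group column is cast to IntegerType.
val train = Seq(
  (1.0f, 0, Vectors.dense(1.0, 0.0)),
  (0.0f, 0, Vectors.dense(0.0, 1.0)),
  (1.0f, 1, Vectors.dense(0.5, 0.5))
).toDF("label", "group", "features")

val ranker = new XGBoostRanker() // validate() falls back to objective=rank:ndcg
  .setGroupCol("group")          // required: validate() rejects an empty groupCol

val model = ranker.fit(train)
model.transform(train).show()
```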
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
index 986e04c6b047..9c20a499b93a 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,405 +16,84 @@
package ml.dmlc.xgboost4j.scala.spark
-import scala.collection.{Iterator, mutable}
-
-import ml.dmlc.xgboost4j.scala.spark.params._
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
-import org.apache.hadoop.fs.Path
-
+import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.linalg.Vector
-import org.apache.spark.ml.util._
-import org.apache.spark.ml._
-import org.apache.spark.ml.param._
-import org.apache.spark.sql._
-import org.apache.spark.sql.functions._
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
+import org.apache.spark.ml.xgboost.SparkUtils
+import org.apache.spark.sql.Dataset
-import org.apache.spark.ml.util.{DefaultXGBoostParamsReader, DefaultXGBoostParamsWriter, XGBoostWriter}
-import org.apache.spark.sql.types.StructType
+import ml.dmlc.xgboost4j.scala.Booster
+import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor._uid
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.REGRESSION_OBJS
-class XGBoostRegressor (
- override val uid: String,
- private val xgboostParams: Map[String, Any])
+class XGBoostRegressor(override val uid: String,
+ private val xgboostParams: Map[String, Any])
extends Predictor[Vector, XGBoostRegressor, XGBoostRegressionModel]
- with XGBoostRegressorParams with DefaultParamsWritable {
+ with XGBoostEstimator[XGBoostRegressor, XGBoostRegressionModel] {
- def this() = this(Identifiable.randomUID("xgbr"), Map[String, Any]())
+ def this() = this(_uid, Map[String, Any]())
def this(uid: String) = this(uid, Map[String, Any]())
- def this(xgboostParams: Map[String, Any]) = this(
- Identifiable.randomUID("xgbr"), xgboostParams)
-
- XGBoost2MLlibParams(xgboostParams)
-
- def setWeightCol(value: String): this.type = set(weightCol, value)
-
- def setBaseMarginCol(value: String): this.type = set(baseMarginCol, value)
-
- def setGroupCol(value: String): this.type = set(groupCol, value)
-
- // setters for general params
- def setNumRound(value: Int): this.type = set(numRound, value)
-
- def setNumWorkers(value: Int): this.type = set(numWorkers, value)
-
- def setNthread(value: Int): this.type = set(nthread, value)
-
- def setUseExternalMemory(value: Boolean): this.type = set(useExternalMemory, value)
-
- def setSilent(value: Int): this.type = set(silent, value)
-
- def setMissing(value: Float): this.type = set(missing, value)
-
- def setCheckpointPath(value: String): this.type = set(checkpointPath, value)
-
- def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
-
- def setSeed(value: Long): this.type = set(seed, value)
-
- def setEta(value: Double): this.type = set(eta, value)
-
- def setGamma(value: Double): this.type = set(gamma, value)
-
- def setMaxDepth(value: Int): this.type = set(maxDepth, value)
-
- def setMinChildWeight(value: Double): this.type = set(minChildWeight, value)
-
- def setMaxDeltaStep(value: Double): this.type = set(maxDeltaStep, value)
-
- def setSubsample(value: Double): this.type = set(subsample, value)
-
- def setColsampleBytree(value: Double): this.type = set(colsampleBytree, value)
-
- def setColsampleBylevel(value: Double): this.type = set(colsampleBylevel, value)
-
- def setLambda(value: Double): this.type = set(lambda, value)
-
- def setAlpha(value: Double): this.type = set(alpha, value)
-
- def setTreeMethod(value: String): this.type = set(treeMethod, value)
-
- def setDevice(value: String): this.type = set(device, value)
-
- def setGrowPolicy(value: String): this.type = set(growPolicy, value)
-
- def setMaxBins(value: Int): this.type = set(maxBins, value)
-
- def setMaxLeaves(value: Int): this.type = set(maxLeaves, value)
-
- def setScalePosWeight(value: Double): this.type = set(scalePosWeight, value)
+ def this(xgboostParams: Map[String, Any]) = this(_uid, xgboostParams)
- def setSampleType(value: String): this.type = set(sampleType, value)
+ xgboost2SparkParams(xgboostParams)
- def setNormalizeType(value: String): this.type = set(normalizeType, value)
-
- def setRateDrop(value: Double): this.type = set(rateDrop, value)
-
- def setSkipDrop(value: Double): this.type = set(skipDrop, value)
-
- def setLambdaBias(value: Double): this.type = set(lambdaBias, value)
-
- // setters for learning params
- def setObjective(value: String): this.type = set(objective, value)
-
- def setObjectiveType(value: String): this.type = set(objectiveType, value)
-
- def setBaseScore(value: Double): this.type = set(baseScore, value)
-
- def setEvalMetric(value: String): this.type = set(evalMetric, value)
-
- def setTrainTestRatio(value: Double): this.type = set(trainTestRatio, value)
-
- def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
-
- def setMaximizeEvaluationMetrics(value: Boolean): this.type =
- set(maximizeEvaluationMetrics, value)
-
- def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
-
- def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
-
- def setAllowNonZeroForMissing(value: Boolean): this.type = set(
- allowNonZeroForMissing,
- value
- )
-
- def setSinglePrecisionHistogram(value: Boolean): this.type =
- set(singlePrecisionHistogram, value)
-
- def setFeatureNames(value: Array[String]): this.type =
- set(featureNames, value)
-
- def setFeatureTypes(value: Array[String]): this.type =
- set(featureTypes, value)
-
- // called at the start of fit/train when 'eval_metric' is not defined
- private def setupDefaultEvalMetric(): String = {
- require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
- if ($(objective).startsWith("rank")) {
- "map"
- } else {
- "rmse"
- }
- }
-
- private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- if (isFeaturesColSet(schema)) {
- // User has vectorized the features into VectorUDT.
- super.transformSchema(schema)
- } else {
- transformSchemaWithFeaturesCols(false, schema)
+ /**
+ * Validate the parameters before training, throwing an exception if any are invalid.
+ */
+ override protected[spark] def validate(dataset: Dataset[_]): Unit = {
+ super.validate(dataset)
+
+ // If the objective is set explicitly, it must be a regression objective (REGRESSION_OBJS).
+ if (isSet(objective)) {
+ val tmpObj = getObjective
+ require(REGRESSION_OBJS.contains(tmpObj),
+ s"Wrong objective for XGBoostRegressor, supported objs: ${REGRESSION_OBJS.mkString(",")}")
}
}
- override def transformSchema(schema: StructType): StructType = {
- PreXGBoost.transformSchema(this, schema)
- }
-
- override protected def train(dataset: Dataset[_]): XGBoostRegressionModel = {
-
- if (!isDefined(objective)) {
- // If user doesn't set objective, force it to reg:squarederror
- setObjective("reg:squarederror")
- }
-
- if (!isDefined(evalMetric) || $(evalMetric).isEmpty) {
- set(evalMetric, setupDefaultEvalMetric())
- }
-
- if (isDefined(customObj) && $(customObj) != null) {
- set(objectiveType, "regression")
- }
-
- transformSchema(dataset.schema, logging = true)
-
- // Packing with all params plus params user defined
- val derivedXGBParamMap = xgboostParams ++ MLlib2XGBoostParams
- val buildTrainingData = PreXGBoost.buildDatasetToRDD(this, dataset, derivedXGBParamMap)
-
- // All non-null param maps in XGBoostRegressor are in derivedXGBParamMap.
- val (_booster, _metrics) = XGBoost.trainDistributed(dataset.sparkSession.sparkContext,
- buildTrainingData, derivedXGBParamMap)
-
- val model = new XGBoostRegressionModel(uid, _booster)
- val summary = XGBoostTrainingSummary(_metrics)
- model.setSummary(summary)
- model
+ override protected def createModel(
+ booster: Booster,
+ summary: XGBoostTrainingSummary): XGBoostRegressionModel = {
+ new XGBoostRegressionModel(uid, booster, Option(summary))
}
-
- override def copy(extra: ParamMap): XGBoostRegressor = defaultCopy(extra)
}
object XGBoostRegressor extends DefaultParamsReadable[XGBoostRegressor] {
-
- override def load(path: String): XGBoostRegressor = super.load(path)
+ private val _uid = Identifiable.randomUID("xgbr")
}
-class XGBoostRegressionModel private[ml] (
- override val uid: String,
- private[scala] val _booster: Booster)
+class XGBoostRegressionModel private[ml](val uid: String,
+ val nativeBooster: Booster,
+ val summary: Option[XGBoostTrainingSummary] = None)
extends PredictionModel[Vector, XGBoostRegressionModel]
- with XGBoostRegressorParams with InferenceParams
- with MLWritable with Serializable {
+ with RankerRegressorBaseModel[XGBoostRegressionModel] {
- import XGBoostRegressionModel._
-
- // only called in copy()
def this(uid: String) = this(uid, null)
- /**
- * Get the native booster instance of this model.
- * This is used to call low-level APIs on native booster, such as "getFeatureScore".
- */
- def nativeBooster: Booster = _booster
-
- private var trainingSummary: Option[XGBoostTrainingSummary] = None
-
- /**
- * Returns summary (e.g. train/test objective history) of model on the
- * training set. An exception is thrown if no summary is available.
- */
- def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
- throw new IllegalStateException("No training summary available for this XGBoostModel")
- }
-
- private[spark] def setSummary(summary: XGBoostTrainingSummary): this.type = {
- trainingSummary = Some(summary)
- this
+ override def copy(extra: ParamMap): XGBoostRegressionModel = {
+ val newModel = copyValues(new XGBoostRegressionModel(uid, nativeBooster, summary), extra)
+ newModel.setParent(parent)
}
- def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
-
- def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
-
- def setTreeLimit(value: Int): this.type = set(treeLimit, value)
-
- def setMissing(value: Float): this.type = set(missing, value)
-
- def setAllowNonZeroForMissing(value: Boolean): this.type = set(
- allowNonZeroForMissing,
- value
- )
-
- def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
-
- /**
- * Single instance prediction.
- * Note: The performance is not ideal, use it carefully!
- */
override def predict(features: Vector): Double = {
- import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
- val dm = new DMatrix(processMissingValues(
- Iterator(features.asXGB),
- $(missing),
- $(allowNonZeroForMissing)
- ))
- _booster.predict(data = dm)(0)(0)
+ val values = predictSingleInstance(features)
+ values(0)
}
-
- private[scala] def produceResultIterator(
- originalRowItr: Iterator[Row],
- predictionItr: Iterator[Row],
- predLeafItr: Iterator[Row],
- predContribItr: Iterator[Row]): Iterator[Row] = {
- // the following implementation is to be improved
- if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
- isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
- originalRowItr.zip(predictionItr).zip(predLeafItr).zip(predContribItr).
- map { case (((originals: Row, prediction: Row), leaves: Row), contribs: Row) =>
- Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq ++ contribs.toSeq)
- }
- } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
- (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
- originalRowItr.zip(predictionItr).zip(predLeafItr).
- map { case ((originals: Row, prediction: Row), leaves: Row) =>
- Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ leaves.toSeq)
- }
- } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
- isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
- originalRowItr.zip(predictionItr).zip(predContribItr).
- map { case ((originals: Row, prediction: Row), contribs: Row) =>
- Row.fromSeq(originals.toSeq ++ prediction.toSeq ++ contribs.toSeq)
- }
- } else {
- originalRowItr.zip(predictionItr).map {
- case (originals: Row, originalPrediction: Row) =>
- Row.fromSeq(originals.toSeq ++ originalPrediction.toSeq)
- }
- }
- }
-
- private[scala] def producePredictionItrs(booster: Booster, dm: DMatrix):
- Array[Iterator[Row]] = {
- val originalPredictionItr = {
- booster.predict(dm, outPutMargin = false, $(treeLimit)).map(Row(_)).iterator
- }
- val predLeafItr = {
- if (isDefined(leafPredictionCol)) {
- booster.predictLeaf(dm, $(treeLimit)).
- map(Row(_)).iterator
- } else {
- Iterator()
- }
- }
- val predContribItr = {
- if (isDefined(contribPredictionCol)) {
- booster.predictContrib(dm, $(treeLimit)).
- map(Row(_)).iterator
- } else {
- Iterator()
- }
- }
- Array(originalPredictionItr, predLeafItr, predContribItr)
- }
-
- private[spark] def transformSchemaInternal(schema: StructType): StructType = {
- if (isFeaturesColSet(schema)) {
- // User has vectorized the features into VectorUDT.
- super.transformSchema(schema)
- } else {
- transformSchemaWithFeaturesCols(false, schema)
- }
- }
-
- override def transformSchema(schema: StructType): StructType = {
- PreXGBoost.transformSchema(this, schema)
- }
-
- override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
- // Output selected columns only.
- // This is a bit complicated since it tries to avoid repeated computation.
- var outputData = PreXGBoost.transformDataset(this, dataset)
- var numColsOutput = 0
-
- val predictUDF = udf { (originalPrediction: mutable.WrappedArray[Float]) =>
- originalPrediction(0).toDouble
- }
-
- if ($(predictionCol).nonEmpty) {
- outputData = outputData
- .withColumn($(predictionCol), predictUDF(col(_originalPredictionCol)))
- numColsOutput += 1
- }
-
- if (numColsOutput == 0) {
- this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
- " since no output columns were set.")
- }
- outputData.toDF.drop(col(_originalPredictionCol))
- }
-
- override def copy(extra: ParamMap): XGBoostRegressionModel = {
- val newModel = copyValues(new XGBoostRegressionModel(uid, _booster), extra)
- newModel.setSummary(summary).setParent(parent)
- }
-
- override def write: MLWriter =
- new XGBoostRegressionModel.XGBoostRegressionModelWriter(this)
}
object XGBoostRegressionModel extends MLReadable[XGBoostRegressionModel] {
+ override def read: MLReader[XGBoostRegressionModel] = new ModelReader
- private[scala] val _originalPredictionCol = "_originalPrediction"
-
- override def read: MLReader[XGBoostRegressionModel] = new XGBoostRegressionModelReader
-
- override def load(path: String): XGBoostRegressionModel = super.load(path)
-
- private[XGBoostRegressionModel]
- class XGBoostRegressionModelWriter(instance: XGBoostRegressionModel) extends XGBoostWriter {
-
- override protected def saveImpl(path: String): Unit = {
- // Save metadata and Params
- DefaultXGBoostParamsWriter.saveMetadata(instance, path, sc)
- // Save model data
- val dataPath = new Path(path, "data").toString
- val internalPath = new Path(dataPath, "XGBoostRegressionModel")
- val outputStream = internalPath.getFileSystem(sc.hadoopConfiguration).create(internalPath)
- instance._booster.saveModel(outputStream, getModelFormat())
- outputStream.close()
- }
- }
-
- private class XGBoostRegressionModelReader extends MLReader[XGBoostRegressionModel] {
-
- /** Checked against metadata when loading model */
- private val className = classOf[XGBoostRegressionModel].getName
-
+ private class ModelReader extends XGBoostModelReader[XGBoostRegressionModel] {
override def load(path: String): XGBoostRegressionModel = {
- implicit val sc = super.sparkSession.sparkContext
-
- val metadata = DefaultXGBoostParamsReader.loadMetadata(path, sc, className)
-
- val dataPath = new Path(path, "data").toString
- val internalPath = new Path(dataPath, "XGBoostRegressionModel")
- val dataInStream = internalPath.getFileSystem(sc.hadoopConfiguration).open(internalPath)
-
- val booster = SXGBoost.loadModel(dataInStream)
- val model = new XGBoostRegressionModel(metadata.uid, booster)
- DefaultXGBoostParamsReader.getAndSetParams(model, metadata)
+ val xgbModel = loadBooster(path)
+ val meta = SparkUtils.loadMetadata(path, sc)
+ val model = new XGBoostRegressionModel(meta.uid, xgbModel, None)
+ meta.getAndSetParams(model)
model
}
}
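
The net effect for regressor users: the dozens of dedicated setters are gone, and native xgboost parameters flow through the constructor `Map` into `xgboost2SparkParams`. A sketch under that assumption (the exact set of recognized keys is an assumption here; the keys shown are xgboost-native names):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor

val spark = SparkSession.builder().master("local[2]").appName("xgb-regressor").getOrCreate()
import spark.implicits._

val train = Seq(
  (1.0f, Vectors.dense(1.0, 2.0)),
  (2.0f, Vectors.dense(2.0, 3.0))
).toDF("label", "features")

// Native xgboost parameter names go straight into the constructor Map.
val regressor = new XGBoostRegressor(Map(
  "objective" -> "reg:squarederror",
  "max_depth" -> 3,
  "eta" -> 0.3
))

val model = regressor.fit(train)
println(model.predict(Vectors.dense(1.5, 2.5))) // single-instance prediction path
```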
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostTrainingSummary.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostTrainingSummary.scala
index 9454befc2fdc..de62feb2601f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostTrainingSummary.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostTrainingSummary.scala
@@ -22,17 +22,17 @@ class XGBoostTrainingSummary private(
override def toString: String = {
val train = trainObjectiveHistory.mkString(",")
- val vaidationObjectiveHistoryString = {
+ val validationObjectiveHistoryString = {
validationObjectiveHistory.map {
case (name, metrics) =>
s"${name}ObjectiveHistory=${metrics.mkString(",")}"
}.mkString(";")
}
- s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $vaidationObjectiveHistoryString)"
+ s"XGBoostTrainingSummary(trainObjectiveHistory=$train; $validationObjectiveHistoryString)"
}
}
-private[xgboost4j] object XGBoostTrainingSummary {
+private[spark] object XGBoostTrainingSummary {
def apply(metrics: Map[String, Array[Float]]): XGBoostTrainingSummary = {
new XGBoostTrainingSummary(
trainObjectiveHistory = metrics("train"),
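
Since the summary is now carried as an `Option` on the model, rather than a mutable field that threw when unset, callers unwrap it explicitly; for example, with the `model` from the regressor sketch above:

```scala
model.summary.foreach { s =>
  // Prints: XGBoostTrainingSummary(trainObjectiveHistory=...; ...)
  println(s)
}
```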
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
deleted file mode 100644
index b64ad9385a9b..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
+++ /dev/null
@@ -1,295 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import scala.collection.immutable.HashSet
-
-import org.apache.spark.ml.param.{DoubleParam, IntParam, BooleanParam, Param, Params}
-
-private[spark] trait BoosterParams extends Params {
-
- /**
- * step size shrinkage used in update to prevents overfitting. After each boosting step, we
- * can directly get the weights of new features and eta actually shrinks the feature weights
- * to make the boosting process more conservative. [default=0.3] range: [0,1]
- */
- final val eta = new DoubleParam(this, "eta", "step size shrinkage used in update to prevents" +
- " overfitting. After each boosting step, we can directly get the weights of new features." +
- " and eta actually shrinks the feature weights to make the boosting process more conservative.",
- (value: Double) => value >= 0 && value <= 1)
-
- final def getEta: Double = $(eta)
-
- /**
- * minimum loss reduction required to make a further partition on a leaf node of the tree.
- * the larger, the more conservative the algorithm will be. [default=0] range: [0,
- * Double.MaxValue]
- */
- final val gamma = new DoubleParam(this, "gamma", "minimum loss reduction required to make a " +
- "further partition on a leaf node of the tree. the larger, the more conservative the " +
- "algorithm will be.", (value: Double) => value >= 0)
-
- final def getGamma: Double = $(gamma)
-
- /**
- * maximum depth of a tree, increase this value will make model more complex / likely to be
- * overfitting. [default=6] range: [1, Int.MaxValue]
- */
- final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " +
- "value will make model more complex/likely to be overfitting.", (value: Int) => value >= 0)
-
- final def getMaxDepth: Int = $(maxDepth)
-
-
- /**
- * Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
- */
- final val maxLeaves = new IntParam(this, "maxLeaves",
- "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
- (value: Int) => value >= 0)
-
- final def getMaxLeaves: Int = $(maxLeaves)
-
-
- /**
- * minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
- * in a leaf node with the sum of instance weight less than min_child_weight, then the building
- * process will give up further partitioning. In linear regression mode, this simply corresponds
- * to minimum number of instances needed to be in each node. The larger, the more conservative
- * the algorithm will be. [default=1] range: [0, Double.MaxValue]
- */
- final val minChildWeight = new DoubleParam(this, "minChildWeight", "minimum sum of instance" +
- " weight(hessian) needed in a child. If the tree partition step results in a leaf node with" +
- " the sum of instance weight less than min_child_weight, then the building process will" +
- " give up further partitioning. In linear regression mode, this simply corresponds to minimum" +
- " number of instances needed to be in each node. The larger, the more conservative" +
- " the algorithm will be.", (value: Double) => value >= 0)
-
- final def getMinChildWeight: Double = $(minChildWeight)
-
- /**
- * Maximum delta step we allow each tree's weight estimation to be. If the value is set to 0, it
- * means there is no constraint. If it is set to a positive value, it can help making the update
- * step more conservative. Usually this parameter is not needed, but it might help in logistic
- * regression when class is extremely imbalanced. Set it to value of 1-10 might help control the
- * update. [default=0] range: [0, Double.MaxValue]
- */
- final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "Maximum delta step we allow " +
- "each tree's weight" +
- " estimation to be. If the value is set to 0, it means there is no constraint. If it is set" +
- " to a positive value, it can help making the update step more conservative. Usually this" +
- " parameter is not needed, but it might help in logistic regression when class is extremely" +
- " imbalanced. Set it to value of 1-10 might help control the update",
- (value: Double) => value >= 0)
-
- final def getMaxDeltaStep: Double = $(maxDeltaStep)
-
- /**
- * subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly
- * collected half of the data instances to grow trees and this will prevent overfitting.
- * [default=1] range:(0,1]
- */
- final val subsample = new DoubleParam(this, "subsample", "subsample ratio of the training " +
- "instance. Setting it to 0.5 means that XGBoost randomly collected half of the data " +
- "instances to grow trees and this will prevent overfitting.",
- (value: Double) => value <= 1 && value > 0)
-
- final def getSubsample: Double = $(subsample)
-
- /**
- * subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
- */
- final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "subsample ratio of " +
- "columns when constructing each tree.", (value: Double) => value <= 1 && value > 0)
-
- final def getColsampleBytree: Double = $(colsampleBytree)
-
- /**
- * subsample ratio of columns for each split, in each level. [default=1] range: (0,1]
- */
- final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "subsample ratio of " +
- "columns for each split, in each level.", (value: Double) => value <= 1 && value > 0)
-
- final def getColsampleBylevel: Double = $(colsampleBylevel)
-
- /**
- * L2 regularization term on weights, increase this value will make model more conservative.
- * [default=1]
- */
- final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights, " +
- "increase this value will make model more conservative.", (value: Double) => value >= 0)
-
- final def getLambda: Double = $(lambda)
-
- /**
- * L1 regularization term on weights, increase this value will make model more conservative.
- * [default=0]
- */
- final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights, increase " +
- "this value will make model more conservative.", (value: Double) => value >= 0)
-
- final def getAlpha: Double = $(alpha)
-
- /**
- * The tree construction algorithm used in XGBoost. options:
- * {'auto', 'exact', 'approx','gpu_hist'} [default='auto']
- */
- final val treeMethod = new Param[String](this, "treeMethod",
- "The tree construction algorithm used in XGBoost, options: " +
- "{'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
- (value: String) => BoosterParams.supportedTreeMethods.contains(value))
-
- final def getTreeMethod: String = $(treeMethod)
-
- /**
- * The device for running XGBoost algorithms, options: cpu, cuda
- */
- final val device = new Param[String](
- this, "device", "The device for running XGBoost algorithms, options: cpu, cuda",
- (value: String) => BoosterParams.supportedDevices.contains(value)
- )
-
- final def getDevice: String = $(device)
-
- /**
- * growth policy for fast histogram algorithm
- */
- final val growPolicy = new Param[String](this, "growPolicy",
- "Controls a way new nodes are added to the tree. Currently supported only if" +
- " tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes" +
- " closest to the root. lossguide: split at nodes with highest loss change.",
- (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
-
- final def getGrowPolicy: String = $(growPolicy)
-
- /**
- * maximum number of bins in histogram
- */
- final val maxBins = new IntParam(this, "maxBin", "maximum number of bins in histogram",
- (value: Int) => value > 0)
-
- final def getMaxBins: Int = $(maxBins)
-
- /**
- * whether to build histograms using single precision floating point values
- */
- final val singlePrecisionHistogram = new BooleanParam(this, "singlePrecisionHistogram",
- "whether to use single precision to build histograms")
-
- final def getSinglePrecisionHistogram: Boolean = $(singlePrecisionHistogram)
-
- /**
- * Control the balance of positive and negative weights, useful for unbalanced classes. A typical
- * value to consider: sum(negative cases) / sum(positive cases). [default=1]
- */
- final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "Control the balance of " +
- "positive and negative weights, useful for unbalanced classes. A typical value to consider:" +
- " sum(negative cases) / sum(positive cases)")
-
- final def getScalePosWeight: Double = $(scalePosWeight)
-
- // Dart boosters
-
- /**
- * Parameter for Dart booster.
- * Type of sampling algorithm. "uniform": dropped trees are selected uniformly.
- * "weighted": dropped trees are selected in proportion to weight. [default="uniform"]
- */
- final val sampleType = new Param[String](this, "sampleType", "type of sampling algorithm, " +
- "options: {'uniform', 'weighted'}",
- (value: String) => BoosterParams.supportedSampleType.contains(value))
-
- final def getSampleType: String = $(sampleType)
-
- /**
- * Parameter of Dart booster.
- * type of normalization algorithm, options: {'tree', 'forest'}. [default="tree"]
- */
- final val normalizeType = new Param[String](this, "normalizeType", "type of normalization" +
- " algorithm, options: {'tree', 'forest'}",
- (value: String) => BoosterParams.supportedNormalizeType.contains(value))
-
- final def getNormalizeType: String = $(normalizeType)
-
- /**
- * Parameter of Dart booster.
- * dropout rate. [default=0.0] range: [0.0, 1.0]
- */
- final val rateDrop = new DoubleParam(this, "rateDrop", "dropout rate", (value: Double) =>
- value >= 0 && value <= 1)
-
- final def getRateDrop: Double = $(rateDrop)
-
- /**
- * Parameter of Dart booster.
- * probability of skip dropout. If a dropout is skipped, new trees are added in the same manner
- * as gbtree. [default=0.0] range: [0.0, 1.0]
- */
- final val skipDrop = new DoubleParam(this, "skipDrop", "probability of skip dropout. If" +
- " a dropout is skipped, new trees are added in the same manner as gbtree.",
- (value: Double) => value >= 0 && value <= 1)
-
- final def getSkipDrop: Double = $(skipDrop)
-
- // linear booster
- /**
- * Parameter of linear booster
- * L2 regularization term on bias, default 0(no L1 reg on bias because it is not important)
- */
- final val lambdaBias = new DoubleParam(this, "lambdaBias", "L2 regularization term on bias, " +
- "default 0 (no L1 reg on bias because it is not important)", (value: Double) => value >= 0)
-
- final def getLambdaBias: Double = $(lambdaBias)
-
- final val treeLimit = new IntParam(this, name = "treeLimit",
- doc = "number of trees used in the prediction; defaults to 0 (use all trees).")
- setDefault(treeLimit, 0)
-
- final def getTreeLimit: Int = $(treeLimit)
-
- final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
- doc = "a list in length of number of features, 1 indicate monotonic increasing, - 1 means " +
- "decreasing, 0 means no constraint. If it is shorter than number of features, 0 will be " +
- "padded ")
-
- final def getMonotoneConstraints: String = $(monotoneConstraints)
-
- final val interactionConstraints = new Param[String](this,
- name = "interactionConstraints",
- doc = "Constraints for interaction representing permitted interactions. The constraints" +
- " must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]]," +
- " where each inner list is a group of indices of features that are allowed to interact" +
- " with each other. See tutorial for more information")
-
- final def getInteractionConstraints: String = $(interactionConstraints)
-
-}
-
-private[scala] object BoosterParams {
-
- val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
-
- val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
-
- val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
-
- val supportedSampleType = HashSet("uniform", "weighted")
-
- val supportedNormalizeType = HashSet("tree", "forest")
-
- val supportedDevices = HashSet("cpu", "cuda")
-}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala
index f838baac2c9c..2f1cb21b0f1e 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/CustomParams.scala
@@ -16,22 +16,20 @@
package ml.dmlc.xgboost4j.scala.spark.params
-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
-import ml.dmlc.xgboost4j.scala.spark.TrackerConf
-import ml.dmlc.xgboost4j.scala.spark.util.Utils
-
import org.apache.spark.ml.param.{Param, ParamPair, Params}
-import org.json4s.{DefaultFormats, Extraction, NoTypeHints}
+import org.json4s.{DefaultFormats, Extraction}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.jackson.Serialization
+import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+import ml.dmlc.xgboost4j.scala.spark.Utils
+
/**
* General spark parameter that includes TypeHints for (de)serialization using json4s.
*/
-class CustomGeneralParam[T: Manifest](
- parent: Params,
- name: String,
- doc: String) extends Param[T](parent, name, doc) {
+class CustomGeneralParam[T: Manifest](parent: Params,
+ name: String,
+ doc: String) extends Param[T](parent, name, doc) {
/** Creates a param pair with the given value (for Java). */
override def w(value: T): ParamPair[T] = super.w(value)
@@ -52,33 +50,10 @@ class CustomGeneralParam[T: Manifest](
}
}
-class CustomEvalParam(
- parent: Params,
- name: String,
- doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)
+class CustomEvalParam(parent: Params,
+ name: String,
+ doc: String) extends CustomGeneralParam[EvalTrait](parent, name, doc)
-class CustomObjParam(
- parent: Params,
- name: String,
- doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)
-
-class TrackerConfParam(
- parent: Params,
- name: String,
- doc: String) extends Param[TrackerConf](parent, name, doc) {
-
- /** Creates a param pair with the given value (for Java). */
- override def w(value: TrackerConf): ParamPair[TrackerConf] = super.w(value)
-
- override def jsonEncode(value: TrackerConf): String = {
- import org.json4s.jackson.Serialization
- implicit val formats = Serialization.formats(NoTypeHints)
- compact(render(Extraction.decompose(value)))
- }
-
- override def jsonDecode(json: String): TrackerConf = {
- implicit val formats = DefaultFormats
- val parsedValue = parse(json)
- parsedValue.extract[TrackerConf]
- }
-}
+class CustomObjParam(parent: Params,
+ name: String,
+ doc: String) extends CustomGeneralParam[ObjectiveTrait](parent, name, doc)
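
For orientation, the (de)serialization round trip that `CustomGeneralParam` relies on looks like the following. This is a minimal, self-contained json4s sketch, not part of this diff; `Demo` is a made-up payload standing in for an `ObjectiveTrait`/`EvalTrait` implementation:

```scala
import org.json4s.{Extraction, FullTypeHints}
import org.json4s.jackson.JsonMethods.{compact, parse, render}
import org.json4s.jackson.Serialization

object TypeHintsDemo {
  // Made-up payload; real custom objectives/evals are encoded the same way.
  case class Demo(name: String, weight: Double)

  def main(args: Array[String]): Unit = {
    // Type hints embed the concrete class name into the JSON so the
    // subtype can be reconstructed on decode.
    implicit val formats = Serialization.formats(FullTypeHints(List(classOf[Demo])))

    val json = compact(render(Extraction.decompose(Demo("x", 0.5))))
    val back = parse(json).extract[Demo]
    assert(back == Demo("x", 0.5))
  }
}
```
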
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DartBoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DartBoosterParams.scala
new file mode 100644
index 000000000000..e9707999a1a1
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/DartBoosterParams.scala
@@ -0,0 +1,61 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import org.apache.spark.ml.param._
+
+/**
+ * Dart booster parameters, more details can be found at
+ * https://xgboost.readthedocs.io/en/stable/parameter.html#additional-parameters-for-dart-booster-booster-dart
+ */
+private[spark] trait DartBoosterParams extends Params {
+
+ final val sampleType = new Param[String](this, "sample_type", "Type of sampling algorithm, " +
+ "options: {'uniform', 'weighted'}", ParamValidators.inArray(Array("uniform", "weighted")))
+
+ final def getSampleType: String = $(sampleType)
+
+  final val normalizeType = new Param[String](this, "normalize_type", "Type of normalization" +
+ " algorithm, options: {'tree', 'forest'}",
+ ParamValidators.inArray(Array("tree", "forest")))
+
+ final def getNormalizeType: String = $(normalizeType)
+
+ final val rateDrop = new DoubleParam(this, "rate_drop", "Dropout rate (a fraction of previous " +
+ "trees to drop during the dropout)",
+ ParamValidators.inRange(0, 1, true, true))
+
+ final def getRateDrop: Double = $(rateDrop)
+
+ final val oneDrop = new BooleanParam(this, "one_drop", "When this flag is enabled, at least " +
+ "one tree is always dropped during the dropout (allows Binomial-plus-one or epsilon-dropout " +
+ "from the original DART paper)")
+
+ final def getOneDrop: Boolean = $(oneDrop)
+
+ final val skipDrop = new DoubleParam(this, "skip_drop", "Probability of skipping the dropout " +
+ "procedure during a boosting iteration.\nIf a dropout is skipped, new trees are added " +
+ "in the same manner as gbtree.\nNote that non-zero skip_drop has higher priority than " +
+ "rate_drop or one_drop.",
+ ParamValidators.inRange(0, 1, true, true))
+
+ final def getSkipDrop: Double = $(skipDrop)
+
+ setDefault(sampleType -> "uniform", normalizeType -> "tree", rateDrop -> 0, skipDrop -> 0)
+
+}
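
A quick way to see the new `ParamValidators` usage in action: the sketch below (a hypothetical `DemoDartParams`, not part of this diff) mirrors the `rate_drop` definition and shows that out-of-range values are rejected at set time:

```scala
import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params, ParamValidators}
import org.apache.spark.ml.util.Identifiable

// Hypothetical stand-in mirroring the rate_drop definition above.
class DemoDartParams(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("demoDart"))

  final val rateDrop = new DoubleParam(this, "rate_drop", "Dropout rate",
    ParamValidators.inRange(0, 1, lowerInclusive = true, upperInclusive = true))
  setDefault(rateDrop -> 0.0)

  def setRateDrop(value: Double): this.type = set(rateDrop, value)

  override def copy(extra: ParamMap): Params = defaultCopy(extra)
}

// new DemoDartParams().setRateDrop(0.3)  // accepted
// new DemoDartParams().setRateDrop(1.5)  // throws IllegalArgumentException
```
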
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
index fafbd816a265..e013338fa1f9 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala
@@ -16,303 +16,45 @@
package ml.dmlc.xgboost4j.scala.spark.params
-import com.google.common.base.CaseFormat
-import ml.dmlc.xgboost4j.scala.spark.TrackerConf
-
import org.apache.spark.ml.param._
-import scala.collection.mutable
+/**
+ * General xgboost parameters, more details can be found
+ * at https://xgboost.readthedocs.io/en/stable/parameter.html#general-parameters
+ */
private[spark] trait GeneralParams extends Params {
- /**
- * The number of rounds for boosting
- */
- final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
- ParamValidators.gtEq(1))
- setDefault(numRound, 1)
-
- final def getNumRound: Int = $(numRound)
+ final val booster = new Param[String](this, "booster", "Which booster to use. Can be gbtree, " +
+ "gblinear or dart; gbtree and dart use tree based models while gblinear uses linear " +
+ "functions.", ParamValidators.inArray(Array("gbtree", "dart")))
- /**
- * number of workers used to train xgboost model. default: 1
- */
- final val numWorkers = new IntParam(this, "numWorkers", "number of workers used to run xgboost",
- ParamValidators.gtEq(1))
- setDefault(numWorkers, 1)
+ final def getBooster: String = $(booster)
- final def getNumWorkers: Int = $(numWorkers)
+ final val device = new Param[String](this, "device", "Device for XGBoost to run. User can " +
+ "set it to one of the following values: {cpu, cuda, gpu}",
+ ParamValidators.inArray(Array("cpu", "cuda", "gpu")))
- /**
- * number of threads used by per worker. default 1
- */
- final val nthread = new IntParam(this, "nthread", "number of threads used by per worker",
- ParamValidators.gtEq(1))
- setDefault(nthread, 1)
+ final def getDevice: String = $(device)
- final def getNthread: Int = $(nthread)
-
- /**
- * whether to use external memory as cache. default: false
- */
- final val useExternalMemory = new BooleanParam(this, "useExternalMemory",
- "whether to use external memory as cache")
- setDefault(useExternalMemory, false)
-
- final def getUseExternalMemory: Boolean = $(useExternalMemory)
-
- /**
- * Deprecated. Please use verbosity instead.
- * 0 means printing running messages, 1 means silent mode. default: 0
- */
- final val silent = new IntParam(this, "silent",
- "Deprecated. Please use verbosity instead. " +
- "0 means printing running messages, 1 means silent mode.",
- (value: Int) => value >= 0 && value <= 1)
-
- final def getSilent: Int = $(silent)
-
- /**
- * Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), 3 (debug).
- * default: 1
- */
- final val verbosity = new IntParam(this, "verbosity",
- "Verbosity of printing messages. Valid values are 0 (silent), 1 (warning), 2 (info), " +
- "3 (debug).",
- (value: Int) => value >= 0 && value <= 3)
+ final val verbosity = new IntParam(this, "verbosity", "Verbosity of printing messages. Valid " +
+ "values are 0 (silent), 1 (warning), 2 (info), 3 (debug). Sometimes XGBoost tries to change " +
+ "configurations based on heuristics, which is displayed as warning message. If there's " +
+ "unexpected behaviour, please try to increase value of verbosity.",
+ ParamValidators.inRange(0, 3, true, true))
final def getVerbosity: Int = $(verbosity)
- /**
- * customized objective function provided by user. default: null
- */
- final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
- "provided by user")
-
- /**
- * customized evaluation function provided by user. default: null
- */
- final val customEval = new CustomEvalParam(this, "customEval",
- "customized evaluation function provided by user")
-
- /**
- * the value treated as missing. default: Float.NaN
- */
- final val missing = new FloatParam(this, "missing", "the value treated as missing")
- setDefault(missing, Float.NaN)
+  final val validateParameters = new BooleanParam(this, "validate_parameters", "When set to " +
+    "true, XGBoost will perform validation of input parameters to check whether a parameter " +
+    "is used or not. A warning is emitted when there's an unknown parameter.")
- final def getMissing: Float = $(missing)
-
- /**
- * Allows for having a non-zero value for missing when training on prediction
- * on a Sparse or Empty vector.
- */
- final val allowNonZeroForMissing = new BooleanParam(
- this,
- "allowNonZeroForMissing",
- "Allow to have a non-zero value for missing when training or " +
- "predicting on a Sparse or Empty vector. Should only be used if did " +
- "not use Spark's VectorAssembler class to construct the feature vector " +
- "but instead used a method that preserves zeros in your vector."
- )
- setDefault(allowNonZeroForMissing, false)
-
- final def getAllowNonZeroForMissingValue: Boolean = $(allowNonZeroForMissing)
-
- /**
- * The hdfs folder to load and save checkpoint boosters. default: `empty_string`
- */
- final val checkpointPath = new Param[String](this, "checkpointPath", "the hdfs folder to load " +
- "and save checkpoints. If there are existing checkpoints in checkpoint_path. The job will " +
- "load the checkpoint with highest version as the starting point for training. If " +
- "checkpoint_interval is also set, the job will save a checkpoint every a few rounds.")
-
- final def getCheckpointPath: String = $(checkpointPath)
-
- /**
- * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that
- * the trained model will get checkpointed every 10 iterations. Note: `checkpoint_path` must
- * also be set if the checkpoint interval is greater than 0.
- */
- final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval",
- "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the trained " +
- "model will get checkpointed every 10 iterations. Note: `checkpoint_path` must also be " +
- "set if the checkpoint interval is greater than 0.",
- (interval: Int) => interval == -1 || interval >= 1)
-
- final def getCheckpointInterval: Int = $(checkpointInterval)
-
- /**
- * Rabit tracker configurations. The parameter must be provided as an instance of the
- * TrackerConf class, which has the following definition:
- *
- * case class TrackerConf(timeout: Int, hostIp: String, port: Int)
- *
- * See below for detailed explanations.
- *
- * - timeout : The maximum wait time for all workers to connect to the tracker. (in seconds)
- * default: 0 (no timeout)
- *
- * Timeout for constructing the communication group and waiting for the tracker to
- * shutdown when it's instructed to, doesn't apply to communication when tracking
- * is running.
- * The timeout value should take the time of data loading and pre-processing into account,
- * due to potential lazy execution. Alternatively, you may force Spark to
- * perform data transformation before calling XGBoost.train(), so that this timeout truly
- * reflects the connection delay. Set a reasonable timeout value to prevent model
- * training/testing from hanging indefinitely, possible due to network issues.
- * Note that zero timeout value means to wait indefinitely (equivalent to Duration.Inf).
- *
- * - hostIp : The Rabit Tracker host IP address. This is only needed if the host IP
- * cannot be automatically guessed.
- *
- * - port : The port number for the tracker to listen to. Use a system allocated one by
- * default.
- */
- final val trackerConf = new TrackerConfParam(this, "trackerConf", "Rabit tracker configurations")
- setDefault(trackerConf, TrackerConf())
-
- /** Random seed for the C++ part of XGBoost and train/test splitting. */
- final val seed = new LongParam(this, "seed", "random seed")
- setDefault(seed, 0L)
-
- final def getSeed: Long = $(seed)
-
- /** Feature's name, it will be set to DMatrix and Booster, and in the final native json model.
- * In native code, the parameter name is feature_name.
- * */
- final val featureNames = new StringArrayParam(this, "feature_names",
- "an array of feature names")
-
- final def getFeatureNames: Array[String] = $(featureNames)
-
- /** Feature types, q is numeric and c is categorical.
- * In native code, the parameter name is feature_type
- * */
- final val featureTypes = new StringArrayParam(this, "feature_types",
- "an array of feature types")
-
- final def getFeatureTypes: Array[String] = $(featureTypes)
-}
+ final def getValidateParameters: Boolean = $(validateParameters)
-trait HasLeafPredictionCol extends Params {
- /**
- * Param for leaf prediction column name.
- * @group param
- */
- final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
- "name of the predictLeaf results")
-
- /** @group getParam */
- final def getLeafPredictionCol: String = $(leafPredictionCol)
-}
-
-trait HasContribPredictionCol extends Params {
- /**
- * Param for contribution prediction column name.
- * @group param
- */
- final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
- "name of the predictContrib results")
-
- /** @group getParam */
- final def getContribPredictionCol: String = $(contribPredictionCol)
-}
-
-trait HasBaseMarginCol extends Params {
-
- /**
- * Param for initial prediction (aka base margin) column name.
- * @group param
- */
- final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
- "Initial prediction (aka base margin) column name.")
-
- /** @group getParam */
- final def getBaseMarginCol: String = $(baseMarginCol)
-}
-
-trait HasGroupCol extends Params {
-
- /**
- * Param for group column name.
- * @group param
- */
- final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
-
- /** @group getParam */
- final def getGroupCol: String = $(groupCol)
-
-}
-
-trait HasNumClass extends Params {
-
- /**
- * number of classes
- */
- final val numClass = new IntParam(this, "numClass", "number of classes")
-
- /** @group getParam */
- final def getNumClass: Int = $(numClass)
-}
-
-/**
- * Trait for shared param featuresCols.
- */
-trait HasFeaturesCols extends Params {
- /**
- * Param for the names of feature columns.
- * @group param
- */
- final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
- "an array of feature column names.")
-
- /** @group getParam */
- final def getFeaturesCols: Array[String] = $(featuresCols)
-
- /** Check if featuresCols is valid */
- def isFeaturesColsValid: Boolean = {
- isDefined(featuresCols) && $(featuresCols) != Array.empty
- }
-
-}
-
-private[spark] trait ParamMapFuncs extends Params {
+  final val nthread = new IntParam(this, "nthread", "Number of threads used per worker",
+ ParamValidators.gtEq(1))
- def XGBoost2MLlibParams(xgboostParams: Map[String, Any]): Unit = {
- for ((paramName, paramValue) <- xgboostParams) {
- if ((paramName == "booster" && paramValue != "gbtree") ||
- (paramName == "updater" && paramValue != "grow_histmaker,prune" &&
- paramValue != "grow_quantile_histmaker" && paramValue != "grow_gpu_hist")) {
- throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
- s" XGBoost-Spark only supports gbtree as booster type and grow_histmaker or" +
- s" grow_quantile_histmaker or grow_gpu_hist as the updater type")
- }
- val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
- params.find(_.name == name).foreach {
- case _: DoubleParam =>
- set(name, paramValue.toString.toDouble)
- case _: BooleanParam =>
- set(name, paramValue.toString.toBoolean)
- case _: IntParam =>
- set(name, paramValue.toString.toInt)
- case _: FloatParam =>
- set(name, paramValue.toString.toFloat)
- case _: LongParam =>
- set(name, paramValue.toString.toLong)
- case _: Param[_] =>
- set(name, paramValue)
- }
- }
- }
+ final def getNthread: Int = $(nthread)
- def MLlib2XGBoostParams: Map[String, Any] = {
- val xgboostParams = new mutable.HashMap[String, Any]()
- for (param <- params) {
- if (isDefined(param)) {
- val name = CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, param.name)
- xgboostParams += name -> $(param)
- }
- }
- xgboostParams.toMap
- }
+ setDefault(booster -> "gbtree", device -> "cpu", verbosity -> 1, validateParameters -> false,
+ nthread -> 1)
}
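
The defaults above resolve through Spark's usual Param machinery. A small sketch (hypothetical `DemoGeneralParams`, not from this PR) of how the `device` param and its default behave:

```scala
import org.apache.spark.ml.param.{Param, ParamMap, Params, ParamValidators}
import org.apache.spark.ml.util.Identifiable

// Hypothetical stand-in mirroring the device param and its default above.
class DemoGeneralParams(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("demoGeneral"))

  final val device = new Param[String](this, "device",
    "Device for XGBoost to run: {cpu, cuda, gpu}",
    ParamValidators.inArray(Array("cpu", "cuda", "gpu")))
  setDefault(device -> "cpu")

  def getDevice: String = $(device)
  def setDevice(value: String): this.type = set(device, value)

  override def copy(extra: ParamMap): Params = defaultCopy(extra)
}

// val p = new DemoGeneralParams()
// p.getDevice          // "cpu", from setDefault
// p.setDevice("cuda")  // accepted; "tpu" would throw IllegalArgumentException
```
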
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/InferenceParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/InferenceParams.scala
deleted file mode 100644
index 8e57bd9e0cea..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/InferenceParams.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import org.apache.spark.ml.param.{IntParam, Params}
-
-private[spark] trait InferenceParams extends Params {
-
- /**
- * batch size of inference iteration
- */
- final val inferBatchSize = new IntParam(this, "batchSize", "batch size of inference iteration")
-
- /** @group getParam */
- final def getInferBatchSize: Int = $(inferBatchSize)
-
- setDefault(inferBatchSize, 32 << 10)
-}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index b73e6cbaa844..0105ab776ff2 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,98 +20,124 @@ import scala.collection.immutable.HashSet
import org.apache.spark.ml.param._
+/**
+ * Specify the learning task and the corresponding learning objective.
+ * More details can be found at
+ * https://xgboost.readthedocs.io/en/stable/parameter.html#learning-task-parameters
+ */
private[spark] trait LearningTaskParams extends Params {
- /**
- * Specify the learning task and the corresponding learning objective.
- * options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw,
- * count:poisson, multi:softmax, multi:softprob, rank:ndcg, reg:gamma.
- * default: reg:squarederror
- */
final val objective = new Param[String](this, "objective",
- "objective function used for training")
+ "Objective function used for training",
+ ParamValidators.inArray(LearningTaskParams.SUPPORTED_OBJECTIVES.toArray))
final def getObjective: String = $(objective)
- /**
- * The learning objective type of the specified custom objective and eval.
- * Corresponding type will be assigned if custom objective is defined
- * options: regression, classification. default: null
- */
- final val objectiveType = new Param[String](this, "objectiveType", "objective type used for " +
- s"training, options: {${LearningTaskParams.supportedObjectiveType.mkString(",")}",
- (value: String) => LearningTaskParams.supportedObjectiveType.contains(value))
-
- final def getObjectiveType: String = $(objectiveType)
+ final val numClass = new IntParam(this, "num_class", "Number of classes, used by " +
+ "multi:softmax and multi:softprob objectives", ParamValidators.gtEq(0))
+
+  final def getNumClass: Int = $(numClass)
- /**
- * the initial prediction score of all instances, global bias. default=0.5
- */
- final val baseScore = new DoubleParam(this, "baseScore", "the initial prediction score of all" +
- " instances, global bias")
+ final val baseScore = new DoubleParam(this, "base_score", "The initial prediction score of " +
+ "all instances, global bias. The parameter is automatically estimated for selected " +
+ "objectives before training. To disable the estimation, specify a real number argument. " +
+ "For sufficient number of iterations, changing this value will not have too much effect.")
final def getBaseScore: Double = $(baseScore)
- /**
- * evaluation metrics for validation data, a default metric will be assigned according to
- * objective(rmse for regression, and error for classification, mean average precision for
- * ranking). options: rmse, rmsle, mae, mape, logloss, error, merror, mlogloss, auc, aucpr, ndcg,
- * map, gamma-deviance
- */
- final val evalMetric = new Param[String](this, "evalMetric", "evaluation metrics for " +
- "validation data, a default metric will be assigned according to objective " +
- "(rmse for regression, and error for classification, mean average precision for ranking)")
+  final val evalMetric = new Param[String](this, "eval_metric", "Evaluation metrics for " +
+    "validation data, a default metric will be assigned according to objective (rmse for " +
+    "regression, logloss for classification, mean average precision for rank:map, etc.). " +
+    "User can add multiple evaluation metrics. Remember to pass the metrics in as a list " +
+    "of parameter pairs instead of a map, so that the latter eval_metric won't override " +
+    "previous ones", ParamValidators.inArray(LearningTaskParams.SUPPORTED_EVAL_METRICS.toArray))
final def getEvalMetric: String = $(evalMetric)
- /**
- * Fraction of training points to use for testing.
- */
- @Deprecated
- final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
- "fraction of training points to use for testing",
- ParamValidators.inRange(0, 1))
- setDefault(trainTestRatio, 1.0)
-
- @Deprecated
- final def getTrainTestRatio: Double = $(trainTestRatio)
-
- /**
- * whether caching training data
- */
- final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
- "whether caching training data")
-
- /**
- * whether cleaning checkpoint, always cleaning by default, having this parameter majorly for
- * testing
- */
- final val skipCleanCheckpoint = new BooleanParam(this, "skipCleanCheckpoint",
- "whether cleaning checkpoint data")
-
- /**
- * If non-zero, the training will be stopped after a specified number
- * of consecutive increases in any evaluation metric.
- */
- final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
- "number of rounds of decreasing eval metric to tolerate before " +
- "stopping the training",
- (value: Int) => value == 0 || value > 1)
-
- final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
-
-
- final val maximizeEvaluationMetrics = new BooleanParam(this, "maximizeEvaluationMetrics",
- "define the expected optimization to the evaluation metrics, true to maximize otherwise" +
- " minimize it")
-
- final def getMaximizeEvaluationMetrics: Boolean = $(maximizeEvaluationMetrics)
+ final val seed = new LongParam(this, "seed", "Random number seed.")
-}
+ final def getSeed: Long = $(seed)
-private[spark] object LearningTaskParams {
+  final val seedPerIteration = new BooleanParam(this, "seed_per_iteration", "Seed PRNG " +
+    "deterministically via iteration number.")
+
+ final def getSeedPerIteration: Boolean = $(seedPerIteration)
+
+ // Parameters for Tweedie Regression (objective=reg:tweedie)
+ final val tweedieVariancePower = new DoubleParam(this, "tweedie_variance_power", "Parameter " +
+ "that controls the variance of the Tweedie distribution var(y) ~ E(y)^tweedie_variance_power.",
+ ParamValidators.inRange(1, 2, false, false))
+
+ final def getTweedieVariancePower: Double = $(tweedieVariancePower)
+
+ // Parameter for using Pseudo-Huber (reg:pseudohubererror)
+ final val huberSlope = new DoubleParam(this, "huber_slope", "A parameter used for Pseudo-Huber " +
+ "loss to define the (delta) term.")
+
+ final def getHuberSlope: Double = $(huberSlope)
+
+ // Parameter for using Quantile Loss (reg:quantileerror) TODO
+
+ // Parameter for using AFT Survival Loss (survival:aft) and Negative
+ // Log Likelihood of AFT metric (aft-nloglik)
+  final val aftLossDistribution = new Param[String](this, "aft_loss_distribution", "Probability " +
+    "Density Function used by survival:aft objective and aft-nloglik metric.",
+ ParamValidators.inArray(Array("normal", "logistic", "extreme")))
+
+ final def getAftLossDistribution: String = $(aftLossDistribution)
- val supportedObjectiveType = HashSet("regression", "classification")
+ // Parameters for learning to rank (rank:ndcg, rank:map, rank:pairwise)
+  final val lambdarankPairMethod = new Param[String](this, "lambdarank_pair_method", "How to " +
+    "construct pairs for pair-wise learning, options: {'mean', 'topk'}",
+    ParamValidators.inArray(Array("mean", "topk")))
+
+  final def getLambdarankPairMethod: String = $(lambdarankPairMethod)
+
+ final val lambdarankNumPairPerSample = new IntParam(this, "lambdarank_num_pair_per_sample",
+ "It specifies the number of pairs sampled for each document when pair method is mean, or" +
+ " the truncation level for queries when the pair method is topk. For example, to train " +
+ "with ndcg@6, set lambdarank_num_pair_per_sample to 6 and lambdarank_pair_method to topk",
+ ParamValidators.gtEq(1))
+
+ final def getLambdarankNumPairPerSample: Int = $(lambdarankNumPairPerSample)
+
+  final val lambdarankUnbiased = new BooleanParam(this, "lambdarank_unbiased", "Specify " +
+    "whether we need to debias input click data.")
+
+ final def getLambdarankUnbiased: Boolean = $(lambdarankUnbiased)
+
+ final val lambdarankBiasNorm = new DoubleParam(this, "lambdarank_bias_norm", "Lp " +
+ "normalization for position debiasing, default is L2. Only relevant when " +
+ "lambdarankUnbiased is set to true.")
+
+ final def getLambdarankBiasNorm: Double = $(lambdarankBiasNorm)
+
+ final val ndcgExpGain = new BooleanParam(this, "ndcg_exp_gain", "Whether we should " +
+ "use exponential gain function for NDCG.")
+
+ final def getNdcgExpGain: Boolean = $(ndcgExpGain)
+
+ setDefault(objective -> "reg:squarederror", numClass -> 0, seed -> 0, seedPerIteration -> false,
+ tweedieVariancePower -> 1.5, huberSlope -> 1, lambdarankPairMethod -> "mean",
+ lambdarankUnbiased -> false, lambdarankBiasNorm -> 2, ndcgExpGain -> true)
+}
+
+private[spark] object LearningTaskParams {
+ val SUPPORTED_OBJECTIVES = HashSet("reg:squarederror", "reg:squaredlogerror", "reg:logistic",
+ "reg:pseudohubererror", "reg:absoluteerror", "reg:quantileerror", "binary:logistic",
+ "binary:logitraw", "binary:hinge", "count:poisson", "survival:cox", "survival:aft",
+ "multi:softmax", "multi:softprob", "rank:ndcg", "rank:map", "rank:pairwise", "reg:gamma",
+ "reg:tweedie")
+
+ val BINARY_CLASSIFICATION_OBJS = HashSet("binary:logistic", "binary:hinge", "binary:logitraw")
+ val MULTICLASSIFICATION_OBJS = HashSet("multi:softmax", "multi:softprob")
+ val RANKER_OBJS = HashSet("rank:ndcg", "rank:map", "rank:pairwise")
+ val REGRESSION_OBJS = SUPPORTED_OBJECTIVES -- BINARY_CLASSIFICATION_OBJS --
+ MULTICLASSIFICATION_OBJS -- RANKER_OBJS
+
+ val SUPPORTED_EVAL_METRICS = HashSet("rmse", "rmsle", "mae", "mape", "mphe", "logloss", "error",
+ "error@t", "merror", "mlogloss", "auc", "aucpr", "pre", "ndcg", "map", "ndcg@n", "map@n",
+ "pre@n", "ndcg-", "map-", "ndcg@n-", "map@n-", "poisson-nloglik", "gamma-nloglik",
+ "cox-nloglik", "gamma-deviance", "tweedie-nloglik", "aft-nloglik",
+ "interval-regression-accuracy")
}
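
The derivation of `REGRESSION_OBJS` by set difference is worth a second look. This tiny standalone sketch (with abbreviated objective lists, for illustration only) reproduces the logic:

```scala
import scala.collection.immutable.HashSet

object ObjectiveGroupsDemo extends App {
  // Abbreviated for illustration; the real lists are defined above.
  val supported = HashSet("reg:squarederror", "reg:gamma", "binary:logistic",
    "multi:softmax", "rank:ndcg")

  val binary = HashSet("binary:logistic")
  val multi = HashSet("multi:softmax")
  val rankers = HashSet("rank:ndcg")

  // Regression objectives are whatever remains after removing the
  // classification and ranking groups.
  val regression = supported -- binary -- multi -- rankers
  println(regression) // contains reg:squarederror and reg:gamma
}
```
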
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/NonParamVariables.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/NonParamVariables.scala
deleted file mode 100644
index 276a938e0c8a..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/NonParamVariables.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- Copyright (c) 2014 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import org.apache.spark.sql.DataFrame
-
-trait NonParamVariables {
- protected var evalSetsMap: Map[String, DataFrame] = Map.empty
-
- def setEvalSets(evalSets: Map[String, DataFrame]): this.type = {
- evalSetsMap = evalSets
- this
- }
-
- def getEvalSets(params: Map[String, Any]): Map[String, DataFrame] = {
- if (params.contains("eval_sets")) {
- params("eval_sets").asInstanceOf[Map[String, DataFrame]]
- } else {
- evalSetsMap
- }
- }
-}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala
new file mode 100644
index 000000000000..787cd753ba11
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/ParamMapConversion.scala
@@ -0,0 +1,65 @@
+/*
+ Copyright (c) 2014-2022 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import scala.collection.mutable
+
+import org.apache.spark.ml.param._
+
+private[spark] trait ParamMapConversion extends NonXGBoostParams {
+
+ /**
+ * Convert XGBoost parameters to Spark Parameters
+ *
+ * @param xgboostParams XGBoost style parameters
+ */
+ def xgboost2SparkParams(xgboostParams: Map[String, Any]): Unit = {
+ for ((name, paramValue) <- xgboostParams) {
+ params.find(_.name == name).foreach {
+ case _: DoubleParam =>
+ set(name, paramValue.toString.toDouble)
+ case _: BooleanParam =>
+ set(name, paramValue.toString.toBoolean)
+ case _: IntParam =>
+ set(name, paramValue.toString.toInt)
+ case _: FloatParam =>
+ set(name, paramValue.toString.toFloat)
+ case _: LongParam =>
+ set(name, paramValue.toString.toLong)
+ case _: Param[_] =>
+ set(name, paramValue)
+ }
+ }
+ }
+
+ /**
+ * Convert the user-supplied parameters to the XGBoost parameters.
+ *
+   * Note that the result may still contain jvm-specific parameters that are
+   * not registered as non-XGBoost parameters.
+ */
+ def getXGBoostParams: Map[String, Any] = {
+ val xgboostParams = new mutable.HashMap[String, Any]()
+
+ // Only pass user-supplied parameters to xgboost.
+ for (param <- params) {
+ if (isSet(param) && !nonXGBoostParams.contains(param.name)) {
+ xgboostParams += param.name -> $(param)
+ }
+ }
+ xgboostParams.toMap
+ }
+}
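
Since the values in `xgboostParams` arrive untyped, the conversion hinges on matching each param's runtime type. Below is a self-contained sketch (a hypothetical `DemoConversion` holder, not from this diff) of the same coercion loop:

```scala
import org.apache.spark.ml.param.{DoubleParam, IntParam, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable

// Hypothetical holder exercising the same coercion pattern as
// xgboost2SparkParams above: values arrive untyped and are converted
// through toString according to the target Param's runtime type.
class DemoConversion(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("demoConv"))

  final val maxDepth = new IntParam(this, "max_depth", "maximum tree depth")
  final val eta = new DoubleParam(this, "eta", "learning rate")

  def fromUntyped(xgboostParams: Map[String, Any]): this.type = {
    for ((name, value) <- xgboostParams) {
      params.find(_.name == name).foreach {
        case _: IntParam => set(name, value.toString.toInt)
        case _: DoubleParam => set(name, value.toString.toDouble)
        case _ => set(name, value)
      }
    }
    this
  }

  override def copy(extra: ParamMap): Params = defaultCopy(extra)
}

// new DemoConversion().fromUntyped(Map("max_depth" -> "8", "eta" -> 0.1))
// sets max_depth = 8 (Int) and eta = 0.1 (Double).
```
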
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala
index 27ada633c63d..7a527fb37fc8 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/RabitParams.scala
@@ -18,25 +18,27 @@ package ml.dmlc.xgboost4j.scala.spark.params
import org.apache.spark.ml.param._
-private[spark] trait RabitParams extends Params {
- /**
- * Rabit parameters passed through Rabit.Init into native layer
- * rabit_ring_reduce_threshold - minimal threshold to enable ring based allreduce operation
- * rabit_timeout - wait interval before exit after rabit observed failures set -1 to disable
- * dmlc_worker_connect_retry - number of retrys to tracker
- * dmlc_worker_stop_process_on_error - exit process when rabit see assert/error
- */
- final val rabitRingReduceThreshold = new IntParam(this, "rabitRingReduceThreshold",
- "threshold count to enable allreduce/broadcast with ring based topology",
- ParamValidators.gtEq(1))
- setDefault(rabitRingReduceThreshold, (32 << 10))
-
- final def rabitTimeout: IntParam = new IntParam(this, "rabitTimeout",
- "timeout threshold after rabit observed failures")
- setDefault(rabitTimeout, -1)
-
- final def rabitConnectRetry: IntParam = new IntParam(this, "dmlcWorkerConnectRetry",
- "number of retry worker do before fail", ParamValidators.gtEq(1))
- setDefault(rabitConnectRetry, 5)
+private[spark] trait RabitParams extends Params with NonXGBoostParams {
+  final val rabitTrackerTimeout = new IntParam(this, "rabitTrackerTimeout", "The number of " +
+    "seconds before timing out while waiting for the workers to connect, and for the tracker " +
+    "to shut down.",
+ ParamValidators.gtEq(0))
+
+ final def getRabitTrackerTimeout: Int = $(rabitTrackerTimeout)
+
+ final val rabitTrackerHostIp = new Param[String](this, "rabitTrackerHostIp", "The Rabit " +
+ "Tracker host IP address. This is only needed if the host IP cannot be automatically " +
+ "guessed.")
+
+ final def getRabitTrackerHostIp: String = $(rabitTrackerHostIp)
+
+ final val rabitTrackerPort = new IntParam(this, "rabitTrackerPort", "The port number for the " +
+ "tracker to listen to. Use a system allocated one by default.",
+ ParamValidators.gtEq(0))
+
+ final def getRabitTrackerPort: Int = $(rabitTrackerPort)
+
+ setDefault(rabitTrackerTimeout -> 0, rabitTrackerHostIp -> "", rabitTrackerPort -> 0)
+
+  addNonXGBoostParam(rabitTrackerTimeout, rabitTrackerHostIp, rabitTrackerPort)
}
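
The `NonXGBoostParams` trait used here is defined elsewhere in this PR and not shown in this hunk. A plausible minimal shape, inferred only from the call sites above (`addNonXGBoostParam`, `nonXGBoostParams`), would be:

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.ml.param.{Param, Params}

// Inferred sketch only; the real definition may differ.
trait NonXGBoostParams extends Params {
  private val paramNames: ArrayBuffer[String] = ArrayBuffer.empty

  /** Register params that must not be passed down to native XGBoost. */
  protected def addNonXGBoostParam(ps: Param[_]*): Unit =
    ps.foreach(p => paramNames += p.name)

  /** Queried by ParamMapConversion.getXGBoostParams to filter jvm-only params. */
  protected def nonXGBoostParams: Array[String] = paramNames.toArray
}
```
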
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TreeBoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TreeBoosterParams.scala
new file mode 100644
index 000000000000..7ea5966d459a
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/TreeBoosterParams.scala
@@ -0,0 +1,228 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import scala.collection.immutable.HashSet
+
+import org.apache.spark.ml.param._
+
+/**
+ * TreeBoosterParams defines the XGBoost TreeBooster parameters for Spark
+ *
+ * The details can be found at
+ * https://xgboost.readthedocs.io/en/stable/parameter.html#parameters-for-tree-booster
+ */
+private[spark] trait TreeBoosterParams extends Params {
+
+  final val eta = new DoubleParam(this, "eta", "Step size shrinkage used in update to prevent " +
+ "overfitting. After each boosting step, we can directly get the weights of new features, " +
+ "and eta shrinks the feature weights to make the boosting process more conservative.",
+ ParamValidators.inRange(0, 1, lowerInclusive = true, upperInclusive = true))
+
+ final def getEta: Double = $(eta)
+
+ final val gamma = new DoubleParam(this, "gamma", "Minimum loss reduction required to make a " +
+ "further partition on a leaf node of the tree. The larger gamma is, the more conservative " +
+ "the algorithm will be.",
+ ParamValidators.gtEq(0))
+
+ final def getGamma: Double = $(gamma)
+
+ final val maxDepth = new IntParam(this, "max_depth", "Maximum depth of a tree. Increasing this " +
+ "value will make the model more complex and more likely to overfit. 0 indicates no limit " +
+ "on depth. Beware that XGBoost aggressively consumes memory when training a deep tree. " +
+ "exact tree method requires non-zero value.",
+ ParamValidators.gtEq(0))
+
+ final def getMaxDepth: Int = $(maxDepth)
+
+ final val minChildWeight = new DoubleParam(this, "min_child_weight", "Minimum sum of instance " +
+ "weight (hessian) needed in a child. If the tree partition step results in a leaf node " +
+ "with the sum of instance weight less than min_child_weight, then the building process " +
+ "will give up further partitioning. In linear regression task, this simply corresponds " +
+ "to minimum number of instances needed to be in each node. The larger min_child_weight " +
+ "is, the more conservative the algorithm will be.",
+ ParamValidators.gtEq(0))
+
+ final def getMinChildWeight: Double = $(minChildWeight)
+
+ final val maxDeltaStep = new DoubleParam(this, "max_delta_step", "Maximum delta step we allow " +
+ "each leaf output to be. If the value is set to 0, it means there is no constraint. If it " +
+ "is set to a positive value, it can help making the update step more conservative. Usually " +
+ "this parameter is not needed, but it might help in logistic regression when class is " +
+ "extremely imbalanced. Set it to value of 1-10 might help control the update.",
+ ParamValidators.gtEq(0))
+
+ final def getMaxDeltaStep: Double = $(maxDeltaStep)
+
+ final val subsample = new DoubleParam(this, "subsample", "Subsample ratio of the training " +
+ "instances. Setting it to 0.5 means that XGBoost would randomly sample half of the " +
+ "training data prior to growing trees. and this will prevent overfitting. Subsampling " +
+ "will occur once in every boosting iteration.",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ final def getSubsample: Double = $(subsample)
+
+ final val samplingMethod = new Param[String](this, "sampling_method", "The method to use to " +
+ "sample the training instances. The supported sampling methods" +
+ "uniform: each training instance has an equal probability of being selected. Typically set " +
+ "subsample >= 0.5 for good results.\n" +
+ "gradient_based: the selection probability for each training instance is proportional to " +
+ "the regularized absolute value of gradients. subsample may be set to as low as 0.1 " +
+ "without loss of model accuracy. Note that this sampling method is only supported when " +
+ "tree_method is set to hist and the device is cuda; other tree methods only support " +
+ "uniform sampling.",
+ ParamValidators.inArray(Array("uniform", "gradient_based")))
+
+ final def getSamplingMethod: String = $(samplingMethod)
+
+ final val colsampleBytree = new DoubleParam(this, "colsample_bytree", "Subsample ratio of " +
+ "columns when constructing each tree. Subsampling occurs once for every tree constructed.",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ final def getColsampleBytree: Double = $(colsampleBytree)
+
+ final val colsampleBylevel = new DoubleParam(this, "colsample_bylevel", "Subsample ratio of " +
+ "columns for each level. Subsampling occurs once for every new depth level reached in a " +
+ "tree. Columns are subsampled from the set of columns chosen for the current tree.",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ final def getColsampleBylevel: Double = $(colsampleBylevel)
+
+ final val colsampleBynode = new DoubleParam(this, "colsample_bynode", "Subsample ratio of " +
+ "columns for each node (split). Subsampling occurs once every time a new split is " +
+ "evaluated. Columns are subsampled from the set of columns chosen for the current level.",
+ ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+ final def getColsampleBynode: Double = $(colsampleBynode)
+
+ final val lambda = new DoubleParam(this, "lambda", "L2 regularization term on weights. " +
+ "Increasing this value will make model more conservative.", ParamValidators.gtEq(0))
+
+ final def getLambda: Double = $(lambda)
+
+ final val alpha = new DoubleParam(this, "alpha", "L1 regularization term on weights. " +
+ "Increasing this value will make model more conservative.", ParamValidators.gtEq(0))
+
+ final def getAlpha: Double = $(alpha)
+
+ final val treeMethod = new Param[String](this, "tree_method", "The tree construction " +
+ "algorithm used in XGBoost, options: {'auto', 'exact', 'approx', 'hist', 'gpu_hist'}",
+ ParamValidators.inArray(BoosterParams.supportedTreeMethods.toArray))
+
+ final def getTreeMethod: String = $(treeMethod)
+
+ final val scalePosWeight = new DoubleParam(this, "scale_pos_weight", "Control the balance of " +
+ "positive and negative weights, useful for unbalanced classes. A typical value to consider: " +
+ "sum(negative instances) / sum(positive instances)")
+
+ final def getScalePosWeight: Double = $(scalePosWeight)
+
+ final val updater = new Param[String](this, "updater", "A comma separated string defining the " +
+ "sequence of tree updaters to run, providing a modular way to construct and to modify the " +
+ "trees. This is an advanced parameter that is usually set automatically, depending on some " +
+ "other parameters. However, it could be also set explicitly by a user. " +
+ "The following updaters exist:\n" +
+ "grow_colmaker: non-distributed column-based construction of trees.\n" +
+ "grow_histmaker: distributed tree construction with row-based data splitting based on " +
+ "global proposal of histogram counting.\n" +
+ "grow_quantile_histmaker: Grow tree using quantized histogram.\n" +
+ "grow_gpu_hist: Enabled when tree_method is set to hist along with device=cuda.\n" +
+ "grow_gpu_approx: Enabled when tree_method is set to approx along with device=cuda.\n" +
+ "sync: synchronizes trees in all distributed nodes.\n" +
+ "refresh: refreshes tree's statistics and or leaf values based on the current data. Note " +
+ "that no random subsampling of data rows is performed.\n" +
+ "prune: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth " +
+ "greater than max_depth.",
+ (value: String) => value.split(",").forall(
+ ParamValidators.inArray(BoosterParams.supportedUpdaters.toArray)))
+
+ final def getUpdater: String = $(updater)
+
+  final val refreshLeaf = new BooleanParam(this, "refresh_leaf", "This is a parameter of the " +
+    "refresh updater. When this flag is true, tree leaves as well as tree nodes' stats are " +
+    "updated. When it is false, only node stats are updated.")
+
+ final def getRefreshLeaf: Boolean = $(refreshLeaf)
+
+  // TODO set updater/refreshLeaf default value
+ final val processType = new Param[String](this, "process_type", "A type of boosting process to " +
+ "run. options: {default, update}",
+ ParamValidators.inArray(Array("default", "update")))
+
+ final def getProcessType: String = $(processType)
+
+ final val growPolicy = new Param[String](this, "grow_policy", "Controls a way new nodes are " +
+ "added to the tree. Currently supported only if tree_method is set to hist or approx. " +
+ "Choices: depthwise, lossguide. depthwise: split at nodes closest to the root. " +
+ "lossguide: split at nodes with highest loss change.",
+ ParamValidators.inArray(Array("depthwise", "lossguide")))
+
+ final def getGrowPolicy: String = $(growPolicy)
+
+ final val maxLeaves = new IntParam(this, "max_leaves", "Maximum number of nodes to be added. " +
+ "Not used by exact tree method", ParamValidators.gtEq(0))
+
+ final def getMaxLeaves: Int = $(maxLeaves)
+
+ final val maxBins = new IntParam(this, "max_bin", "Maximum number of discrete bins to bucket " +
+ "continuous features. Increasing this number improves the optimality of splits at the cost " +
+ "of higher computation time. Only used if tree_method is set to hist or approx.",
+ ParamValidators.gt(0))
+
+ final def getMaxBins: Int = $(maxBins)
+
+ final val numParallelTree = new IntParam(this, "num_parallel_tree", "Number of parallel trees " +
+ "constructed during each iteration. This option is used to support boosted random forest.",
+ ParamValidators.gt(0))
+
+ final def getNumParallelTree: Int = $(numParallelTree)
+
+ final val monotoneConstraints = new IntArrayParam(this, "monotone_constraints", "Constraint of " +
+ "variable monotonicity.")
+
+ final def getMonotoneConstraints: Array[Int] = $(monotoneConstraints)
+
+ final val maxCachedHistNode = new IntParam(this, "max_cached_hist_node", "Maximum number of " +
+ "cached nodes for CPU histogram.",
+ ParamValidators.gt(0))
+
+ final def getMaxCachedHistNode: Int = $(maxCachedHistNode)
+
+ setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0,
+ subsample -> 1, samplingMethod -> "uniform", colsampleBytree -> 1, colsampleBylevel -> 1,
+ colsampleBynode -> 1, lambda -> 1, alpha -> 0, treeMethod -> "auto", scalePosWeight -> 1,
+ processType -> "default", growPolicy -> "depthwise", maxLeaves -> 0, maxBins -> 256,
+ numParallelTree -> 1, maxCachedHistNode -> 65536)
+
+}
+
+private[spark] object BoosterParams {
+
+ val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist", "gpu_hist")
+
+ val supportedUpdaters = HashSet("grow_colmaker", "grow_histmaker", "grow_quantile_histmaker",
+ "grow_gpu_hist", "grow_gpu_approx", "sync", "refresh", "prune")
+}
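
The `updater` validator above composes `ParamValidators.inArray` over a comma-separated list, which is easy to verify standalone (illustrative subset of updaters; runnable against spark-mllib):

```scala
import scala.collection.immutable.HashSet

import org.apache.spark.ml.param.ParamValidators

object UpdaterValidatorDemo extends App {
  // Illustrative subset; the full list is in BoosterParams.supportedUpdaters.
  val supportedUpdaters = HashSet("grow_quantile_histmaker", "sync", "refresh", "prune")

  // Same shape as the updater validator above: split on commas and require
  // every element to be a known updater.
  val isValid: String => Boolean =
    value => value.split(",").forall(ParamValidators.inArray(supportedUpdaters.toArray))

  println(isValid("grow_quantile_histmaker,prune")) // true
  println(isValid("grow_quantile_histmaker,bogus")) // false
}
```
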
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala
deleted file mode 100644
index 9581ea0f2c59..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostEstimatorCommon.scala
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.params
-
-import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.ml.param.{Param, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol, HasWeightCol}
-import org.apache.spark.ml.util.XGBoostSchemaUtils
-import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.types.StructType
-
-private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with LearningTaskParams
- with BoosterParams with RabitParams with ParamMapFuncs with NonParamVariables with HasWeightCol
- with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
- with HasLabelCol with HasFeaturesCols with HasHandleInvalid {
-
- def needDeterministicRepartitioning: Boolean = {
- isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
- isDefined(checkpointInterval) && getCheckpointInterval > 0
- }
-
- /**
- * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
- * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
- * output). Column lengths are taken from the size of ML Attribute Group, which can be set using
- * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
- * from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
- * Default: "error"
- * @group param
- */
- override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
- """Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
- |rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
- |in the output). Column lengths are taken from the size of ML Attribute Group, which can be
- |set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
- |be inferred from first rows of the data since it is safe to do so but only in case of 'error'
- |or 'skip'.""".stripMargin.replaceAll("\n", " "),
- ParamValidators.inArray(Array("skip", "error", "keep")))
-
- setDefault(handleInvalid, "error")
-
- /**
- * Specify an array of feature column names which must be numeric types.
- */
- def setFeaturesCol(value: Array[String]): this.type = set(featuresCols, value)
-
- /** Set the handleInvalid for VectorAssembler */
- def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
-
- /**
- * Check if schema has a field named with the value of "featuresCol" param and it's data type
- * must be VectorUDT
- */
- def isFeaturesColSet(schema: StructType): Boolean = {
- schema.fieldNames.contains(getFeaturesCol) &&
- XGBoostSchemaUtils.isVectorUDFType(schema(getFeaturesCol).dataType)
- }
-
- /** check the features columns type */
- def transformSchemaWithFeaturesCols(fit: Boolean, schema: StructType): StructType = {
- if (isFeaturesColsValid) {
- if (fit) {
- XGBoostSchemaUtils.checkNumericType(schema, $(labelCol))
- }
- $(featuresCols).foreach(feature =>
- XGBoostSchemaUtils.checkFeatureColumnType(schema(feature).dataType))
- schema
- } else {
- throw new IllegalArgumentException("featuresCol or featuresCols must be specified")
- }
- }
-
- /**
- * Vectorize the features columns if necessary.
- *
- * @param input the input dataset
- * @return (output dataset and the feature column name)
- */
- def vectorize(input: Dataset[_]): (Dataset[_], String) = {
- val schema = input.schema
- if (isFeaturesColSet(schema)) {
- // Dataset already has vectorized.
- (input, getFeaturesCol)
- } else if (isFeaturesColsValid) {
- val featuresName = if (!schema.fieldNames.contains(getFeaturesCol)) {
- getFeaturesCol
- } else {
- "features_" + uid
- }
- val vectorAssembler = new VectorAssembler()
- .setHandleInvalid($(handleInvalid))
- .setInputCols(getFeaturesCols)
- .setOutputCol(featuresName)
- (vectorAssembler.transform(input).select(featuresName, getLabelCol), featuresName)
- } else {
- // never reach here, since transformSchema will take care of the case
- // that featuresCols is invalid
- (input, getFeaturesCol)
- }
- }
-}
-
-private[scala] trait XGBoostClassifierParams extends XGBoostEstimatorCommon with HasNumClass
-
-private[scala] trait XGBoostRegressorParams extends XGBoostEstimatorCommon with HasGroupCol
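
The removed vectorize() helper above assembled the featuresCols into a single
vector column with VectorAssembler before training. A minimal standalone sketch
of the same idea, with hypothetical column names f1..f3:

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object VectorizeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("VectorizeSketch").master("local[1]").getOrCreate()
    import spark.implicits._

    // Hypothetical raw data: one numeric column per feature, plus a label.
    val df = Seq((1.0, 0.5, 2.0, 1.0), (0.0, 1.5, 3.0, 0.0))
      .toDF("f1", "f2", "f3", "label")

    // Assemble the feature columns into a single vector column, as the removed
    // vectorize() helper did; "keep" is one of the handleInvalid options
    // ('skip'/'error'/'keep') the removed param exposed.
    val assembled = new VectorAssembler()
      .setInputCols(Array("f1", "f2", "f3"))
      .setOutputCol("features")
      .setHandleInvalid("keep")
      .transform(df)
      .select("features", "label")

    assembled.show()
    spark.stop()
  }
}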
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala
new file mode 100644
index 000000000000..8345cab35149
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala
@@ -0,0 +1,356 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark.params
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.sql.types.StructType
+
+import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
+
+trait HasLeafPredictionCol extends Params {
+ /**
+ * Param for leaf prediction column name.
+ *
+ * @group param
+ */
+ final val leafPredictionCol: Param[String] = new Param[String](this, "leafPredictionCol",
+ "name of the predictLeaf results")
+
+ /** @group getParam */
+ final def getLeafPredictionCol: String = $(leafPredictionCol)
+}
+
+trait HasContribPredictionCol extends Params {
+ /**
+ * Param for contribution prediction column name.
+ *
+ * @group param
+ */
+ final val contribPredictionCol: Param[String] = new Param[String](this, "contribPredictionCol",
+ "name of the predictContrib results")
+
+ /** @group getParam */
+ final def getContribPredictionCol: String = $(contribPredictionCol)
+}
+
+trait HasBaseMarginCol extends Params {
+
+ /**
+ * Param for initial prediction (aka base margin) column name.
+ *
+ * @group param
+ */
+ final val baseMarginCol: Param[String] = new Param[String](this, "baseMarginCol",
+ "Initial prediction (aka base margin) column name.")
+
+ /** @group getParam */
+ final def getBaseMarginCol: String = $(baseMarginCol)
+
+}
+
+trait HasGroupCol extends Params {
+
+ final val groupCol: Param[String] = new Param[String](this, "groupCol", "group column name.")
+
+ /** @group getParam */
+ final def getGroupCol: String = $(groupCol)
+}
+
+/**
+ * Trait for shared param featuresCols.
+ */
+trait HasFeaturesCols extends Params {
+ /**
+ * Param for the names of feature columns.
+ *
+ * @group param
+ */
+ final val featuresCols: StringArrayParam = new StringArrayParam(this, "featuresCols",
+ "An array of feature column names.")
+
+ /** @group getParam */
+ final def getFeaturesCols: Array[String] = $(featuresCols)
+
+ /** Check if featuresCols is valid */
+ def isFeaturesColsValid: Boolean = {
+    isDefined(featuresCols) && $(featuresCols).nonEmpty
+ }
+}
+
+/**
+ * A trait that records the names of non-xgboost (Spark-only) parameters
+ */
+trait NonXGBoostParams extends Params {
+ private val paramNames: ArrayBuffer[String] = ArrayBuffer.empty
+
+ protected def addNonXGBoostParam(ps: Param[_]*): Unit = {
+ ps.foreach(p => paramNames.append(p.name))
+ }
+
+ protected lazy val nonXGBoostParams: Array[String] = paramNames.toSet.toArray
+}
+
+/**
+ * XGBoost spark-specific parameters which should not be passed
+ * into the xgboost library
+ *
+ * @tparam T should be an XGBoost estimator or model
+ */
+private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFeaturesCol
+ with HasLabelCol with HasBaseMarginCol with HasWeightCol with HasPredictionCol
+ with HasLeafPredictionCol with HasContribPredictionCol
+ with RabitParams with NonXGBoostParams with SchemaValidationTrait {
+
+ final val numWorkers = new IntParam(this, "numWorkers", "Number of workers used to train xgboost",
+ ParamValidators.gtEq(1))
+
+  final val numRound = new IntParam(this, "numRound", "The number of rounds for boosting",
+    ParamValidators.gtEq(1))
+
+  final def getNumRound: Int = $(numRound)
+
+  final val forceRepartition = new BooleanParam(this, "forceRepartition", "If the number of " +
+    "partitions equals numWorkers, xgboost won't repartition the dataset. Set forceRepartition " +
+    "to true to force a repartition.")
+
+  final def getForceRepartition: Boolean = $(forceRepartition)
+
+  final val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds", "Number of " +
+    "rounds with no improvement in the eval metric to tolerate before stopping training",
+    ParamValidators.gtEq(0))
+
+ final def getNumEarlyStoppingRounds: Int = $(numEarlyStoppingRounds)
+
+ final val inferBatchSize = new IntParam(this, "inferBatchSize", "batch size in rows " +
+ "to be grouped for inference",
+ ParamValidators.gtEq(1))
+
+ /** @group getParam */
+ final def getInferBatchSize: Int = $(inferBatchSize)
+
+  /**
+   * The value treated as missing. Default: Float.NaN
+   */
+ final val missing = new FloatParam(this, "missing", "The value treated as missing")
+
+ final def getMissing: Float = $(missing)
+
+  final val customObj = new CustomObjParam(this, "customObj", "customized objective function " +
+    "provided by the user")
+
+  final def getCustomObj: ObjectiveTrait = $(customObj)
+
+  final val customEval = new CustomEvalParam(this, "customEval",
+    "customized evaluation function provided by the user")
+
+ final def getCustomEval: EvalTrait = $(customEval)
+
+  /** Feature names. They are set on the DMatrix and Booster and stored in the
+   * final native JSON model. In native code, the parameter name is feature_name.
+   */
+ final val featureNames = new StringArrayParam(this, "feature_names",
+ "an array of feature names")
+
+ final def getFeatureNames: Array[String] = $(featureNames)
+
+  /** Feature types: "q" is numeric and "c" is categorical.
+   * In native code, the parameter name is feature_type.
+   */
+ final val featureTypes = new StringArrayParam(this, "feature_types",
+ "an array of feature types")
+
+ final def getFeatureTypes: Array[String] = $(featureTypes)
+
+ setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10),
+ numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
+ featuresCols -> Array.empty, customObj -> null, customEval -> null,
+ featureNames -> Array.empty, featureTypes -> Array.empty)
+
+ addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol,
+ labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,
+ forceRepartition, missing, featuresCols, customEval, customObj, featureTypes, featureNames)
+
+ final def getNumWorkers: Int = $(numWorkers)
+
+ def setNumWorkers(value: Int): T = set(numWorkers, value).asInstanceOf[T]
+
+ def setForceRepartition(value: Boolean): T = set(forceRepartition, value).asInstanceOf[T]
+
+ def setNumRound(value: Int): T = set(numRound, value).asInstanceOf[T]
+
+ def setFeaturesCol(value: Array[String]): T = set(featuresCols, value).asInstanceOf[T]
+
+ def setBaseMarginCol(value: String): T = set(baseMarginCol, value).asInstanceOf[T]
+
+ def setWeightCol(value: String): T = set(weightCol, value).asInstanceOf[T]
+
+ def setLeafPredictionCol(value: String): T = set(leafPredictionCol, value).asInstanceOf[T]
+
+ def setContribPredictionCol(value: String): T = set(contribPredictionCol, value).asInstanceOf[T]
+
+ def setInferBatchSize(value: Int): T = set(inferBatchSize, value).asInstanceOf[T]
+
+ def setMissing(value: Float): T = set(missing, value).asInstanceOf[T]
+
+ def setCustomObj(value: ObjectiveTrait): T = set(customObj, value).asInstanceOf[T]
+
+ def setCustomEval(value: EvalTrait): T = set(customEval, value).asInstanceOf[T]
+
+ def setRabitTrackerTimeout(value: Int): T = set(rabitTrackerTimeout, value).asInstanceOf[T]
+
+ def setRabitTrackerHostIp(value: String): T = set(rabitTrackerHostIp, value).asInstanceOf[T]
+
+ def setRabitTrackerPort(value: Int): T = set(rabitTrackerPort, value).asInstanceOf[T]
+
+ def setFeatureNames(value: Array[String]): T = set(featureNames, value).asInstanceOf[T]
+
+ def setFeatureTypes(value: Array[String]): T = set(featureTypes, value).asInstanceOf[T]
+}
+
+private[spark] trait SchemaValidationTrait {
+
+ def validateAndTransformSchema(schema: StructType,
+ fitting: Boolean): StructType = schema
+}
+
+/**
+ * XGBoost ranking spark-specific parameters
+ *
+ * @tparam T should be XGBoostRanker or XGBoostRankingModel
+ */
+private[spark] trait RankerParams[T <: Params] extends HasGroupCol with NonXGBoostParams {
+ def setGroupCol(value: String): T = set(groupCol, value).asInstanceOf[T]
+
+ addNonXGBoostParam(groupCol)
+}
+
+/**
+ * XGBoost-specific parameters to pass into the xgboost library
+ *
+ * @tparam T should be an XGBoost estimator or model
+ */
+private[spark] trait XGBoostParams[T <: Params] extends TreeBoosterParams
+ with LearningTaskParams with GeneralParams with DartBoosterParams {
+
+ // Setters for TreeBoosterParams
+ def setEta(value: Double): T = set(eta, value).asInstanceOf[T]
+
+ def setGamma(value: Double): T = set(gamma, value).asInstanceOf[T]
+
+ def setMaxDepth(value: Int): T = set(maxDepth, value).asInstanceOf[T]
+
+ def setMinChildWeight(value: Double): T = set(minChildWeight, value).asInstanceOf[T]
+
+ def setMaxDeltaStep(value: Double): T = set(maxDeltaStep, value).asInstanceOf[T]
+
+ def setSubsample(value: Double): T = set(subsample, value).asInstanceOf[T]
+
+ def setSamplingMethod(value: String): T = set(samplingMethod, value).asInstanceOf[T]
+
+ def setColsampleBytree(value: Double): T = set(colsampleBytree, value).asInstanceOf[T]
+
+ def setColsampleBylevel(value: Double): T = set(colsampleBylevel, value).asInstanceOf[T]
+
+ def setColsampleBynode(value: Double): T = set(colsampleBynode, value).asInstanceOf[T]
+
+ def setLambda(value: Double): T = set(lambda, value).asInstanceOf[T]
+
+ def setAlpha(value: Double): T = set(alpha, value).asInstanceOf[T]
+
+ def setTreeMethod(value: String): T = set(treeMethod, value).asInstanceOf[T]
+
+ def setScalePosWeight(value: Double): T = set(scalePosWeight, value).asInstanceOf[T]
+
+ def setUpdater(value: String): T = set(updater, value).asInstanceOf[T]
+
+ def setRefreshLeaf(value: Boolean): T = set(refreshLeaf, value).asInstanceOf[T]
+
+ def setProcessType(value: String): T = set(processType, value).asInstanceOf[T]
+
+ def setGrowPolicy(value: String): T = set(growPolicy, value).asInstanceOf[T]
+
+ def setMaxLeaves(value: Int): T = set(maxLeaves, value).asInstanceOf[T]
+
+ def setMaxBins(value: Int): T = set(maxBins, value).asInstanceOf[T]
+
+ def setNumParallelTree(value: Int): T = set(numParallelTree, value).asInstanceOf[T]
+
+ def setMaxCachedHistNode(value: Int): T = set(maxCachedHistNode, value).asInstanceOf[T]
+
+ // Setters for LearningTaskParams
+
+ def setObjective(value: String): T = set(objective, value).asInstanceOf[T]
+
+ def setNumClass(value: Int): T = set(numClass, value).asInstanceOf[T]
+
+ def setBaseScore(value: Double): T = set(baseScore, value).asInstanceOf[T]
+
+ def setEvalMetric(value: String): T = set(evalMetric, value).asInstanceOf[T]
+
+ def setSeed(value: Long): T = set(seed, value).asInstanceOf[T]
+
+ def setSeedPerIteration(value: Boolean): T = set(seedPerIteration, value).asInstanceOf[T]
+
+ def setTweedieVariancePower(value: Double): T = set(tweedieVariancePower, value).asInstanceOf[T]
+
+ def setHuberSlope(value: Double): T = set(huberSlope, value).asInstanceOf[T]
+
+ def setAftLossDistribution(value: String): T = set(aftLossDistribution, value).asInstanceOf[T]
+
+ def setLambdarankPairMethod(value: String): T = set(lambdarankPairMethod, value).asInstanceOf[T]
+
+ def setLambdarankNumPairPerSample(value: Int): T =
+ set(lambdarankNumPairPerSample, value).asInstanceOf[T]
+
+ def setLambdarankUnbiased(value: Boolean): T = set(lambdarankUnbiased, value).asInstanceOf[T]
+
+ def setLambdarankBiasNorm(value: Double): T = set(lambdarankBiasNorm, value).asInstanceOf[T]
+
+ def setNdcgExpGain(value: Boolean): T = set(ndcgExpGain, value).asInstanceOf[T]
+
+ // Setters for Dart
+ def setSampleType(value: String): T = set(sampleType, value).asInstanceOf[T]
+
+ def setNormalizeType(value: String): T = set(normalizeType, value).asInstanceOf[T]
+
+ def setRateDrop(value: Double): T = set(rateDrop, value).asInstanceOf[T]
+
+ def setOneDrop(value: Boolean): T = set(oneDrop, value).asInstanceOf[T]
+
+ def setSkipDrop(value: Double): T = set(skipDrop, value).asInstanceOf[T]
+
+ // Setters for GeneralParams
+ def setBooster(value: String): T = set(booster, value).asInstanceOf[T]
+
+ def setDevice(value: String): T = set(device, value).asInstanceOf[T]
+
+ def setVerbosity(value: Int): T = set(verbosity, value).asInstanceOf[T]
+
+ def setValidateParameters(value: Boolean): T = set(validateParameters, value).asInstanceOf[T]
+
+ def setNthread(value: Int): T = set(nthread, value).asInstanceOf[T]
+}
+
+private[spark] trait ParamUtils[T <: Params] extends Params {
+
+ def isDefinedNonEmpty(param: Param[String]): Boolean = {
+ isDefined(param) && $(param).nonEmpty
+ }
+}
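
The setters in SparkParams and XGBoostParams above return T through
asInstanceOf[T] so that calls chain fluently on the concrete estimator or
model type. A plain-Scala sketch of the pattern, with hypothetical names:

// Each trait's setters cast `this` to the concrete subtype T, mirroring the
// set(...).asInstanceOf[T] calls in SparkParams.
trait BaseParams[T] {
  protected var rounds: Int = 100
  def setRounds(value: Int): T = { rounds = value; this.asInstanceOf[T] }
}

trait ExtraParams[T] {
  protected var workers: Int = 1
  def setWorkers(value: Int): T = { workers = value; this.asInstanceOf[T] }
}

class Trainer extends BaseParams[Trainer] with ExtraParams[Trainer] {
  def describe(): String = s"rounds=$rounds, workers=$workers"
}

object FluentSetterSketch {
  def main(args: Array[String]): Unit = {
    // Chaining works because every setter returns Trainer, not the trait type.
    println(new Trainer().setRounds(10).setWorkers(4).describe())
  }
}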
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala
deleted file mode 100644
index acc605b1f0a5..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/util/DataUtils.scala
+++ /dev/null
@@ -1,229 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark.util
-
-import scala.collection.mutable
-
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-
-import org.apache.spark.HashPartitioner
-import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
-import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{FloatType, IntegerType}
-import org.apache.spark.sql.{Column, DataFrame, Row}
-
-object DataUtils extends Serializable {
- private[spark] implicit class XGBLabeledPointFeatures(
- val labeledPoint: XGBLabeledPoint
- ) extends AnyVal {
- /** Converts the point to [[MLLabeledPoint]]. */
- private[spark] def asML: MLLabeledPoint = {
- MLLabeledPoint(labeledPoint.label, labeledPoint.features)
- }
-
- /**
- * Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
- */
- def features: Vector = if (labeledPoint.indices == null) {
- Vectors.dense(labeledPoint.values.map(_.toDouble))
- } else {
- Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
- }
- }
-
- private[spark] implicit class MLLabeledPointToXGBLabeledPoint(
- val labeledPoint: MLLabeledPoint
- ) extends AnyVal {
- /** Converts an [[MLLabeledPoint]] to an [[XGBLabeledPoint]]. */
- def asXGB: XGBLabeledPoint = {
- labeledPoint.features.asXGB.copy(label = labeledPoint.label.toFloat)
- }
- }
-
- private[spark] implicit class MLVectorToXGBLabeledPoint(val v: Vector) extends AnyVal {
- /**
- * Converts a [[Vector]] to a data point with a dummy label.
- *
- * This is needed for constructing a [[ml.dmlc.xgboost4j.scala.DMatrix]]
- * for prediction.
- */
- def asXGB: XGBLabeledPoint = v match {
- case v: DenseVector =>
- XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
- case v: SparseVector =>
- XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
- }
- }
-
- private def attachPartitionKey(
- row: Row,
- deterministicPartition: Boolean,
- numWorkers: Int,
- xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
- if (deterministicPartition) {
- (math.abs(row.hashCode() % numWorkers), xgbLp)
- } else {
- (1, xgbLp)
- }
- }
-
- private def repartitionRDDs(
- deterministicPartition: Boolean,
- numWorkers: Int,
- arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
- if (deterministicPartition) {
- arrayOfRDDs.map {rdd => rdd.partitionBy(new HashPartitioner(numWorkers))}.map {
- rdd => rdd.map(_._2)
- }
- } else {
- arrayOfRDDs.map(rdd => {
- if (rdd.getNumPartitions != numWorkers) {
- rdd.map(_._2).repartition(numWorkers)
- } else {
- rdd.map(_._2)
- }
- })
- }
- }
-
- /** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
- private[spark] case class PackedParams(labelCol: Column,
- featuresCol: Column,
- weight: Column,
- baseMargin: Column,
- group: Option[Column],
- numWorkers: Int,
- deterministicPartition: Boolean)
-
- /**
- * convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]
- *
- * First, it serves converting each instance of input into XGBLabeledPoint
- * Second, it repartition the RDD to the number workers.
- *
- */
- private[spark] def convertDataFrameToXGBLabeledPointRDDs(
- packedParams: PackedParams,
- dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {
-
- packedParams match {
- case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
- deterministicPartition) =>
- val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
- featuresCol,
- weight.cast(FloatType),
- groupCol.cast(IntegerType),
- baseMargin.cast(FloatType))).getOrElse(Seq(labelCol.cast(FloatType),
- featuresCol,
- weight.cast(FloatType),
- baseMargin.cast(FloatType)))
- val arrayOfRDDs = dataFrames.toArray.map {
- df => df.select(selectedColumns: _*).rdd.map {
- case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
- baseMargin: Float) =>
- val (size, indices, values) = features match {
- case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
- case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
- }
- val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
- attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
- case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
- val (size, indices, values) = features match {
- case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
- case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
- }
- val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
- baseMargin = baseMargin)
- attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
- }
- }
- repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
-
- case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
- }
-
- }
-
- private[spark] def processMissingValues(
- xgbLabelPoints: Iterator[XGBLabeledPoint],
- missing: Float,
- allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
- if (!missing.isNaN) {
- removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
- missing, (v: Float) => v != missing)
- } else {
- removeMissingValues(verifyMissingSetting(xgbLabelPoints, missing, allowNonZeroMissing),
- missing, (v: Float) => !v.isNaN)
- }
- }
-
- private[spark] def processMissingValuesWithGroup(
- xgbLabelPointGroups: Iterator[Array[XGBLabeledPoint]],
- missing: Float,
- allowNonZeroMissing: Boolean): Iterator[Array[XGBLabeledPoint]] = {
- if (!missing.isNaN) {
- xgbLabelPointGroups.map {
- labeledPoints => processMissingValues(
- labeledPoints.iterator,
- missing,
- allowNonZeroMissing
- ).toArray
- }
- } else {
- xgbLabelPointGroups
- }
- }
-
- private def removeMissingValues(
- xgbLabelPoints: Iterator[XGBLabeledPoint],
- missing: Float,
- keepCondition: Float => Boolean): Iterator[XGBLabeledPoint] = {
- xgbLabelPoints.map { labeledPoint =>
- val indicesBuilder = new mutable.ArrayBuilder.ofInt()
- val valuesBuilder = new mutable.ArrayBuilder.ofFloat()
- for ((value, i) <- labeledPoint.values.zipWithIndex if keepCondition(value)) {
- indicesBuilder += (if (labeledPoint.indices == null) i else labeledPoint.indices(i))
- valuesBuilder += value
- }
- labeledPoint.copy(indices = indicesBuilder.result(), values = valuesBuilder.result())
- }
- }
-
- private def verifyMissingSetting(
- xgbLabelPoints: Iterator[XGBLabeledPoint],
- missing: Float,
- allowNonZeroMissing: Boolean): Iterator[XGBLabeledPoint] = {
- if (missing != 0.0f && !allowNonZeroMissing) {
- xgbLabelPoints.map(labeledPoint => {
- if (labeledPoint.indices != null) {
- throw new RuntimeException(s"you can only specify missing value as 0.0 (the currently" +
- s" set value $missing) when you have SparseVector or Empty vector as your feature" +
- s" format. If you didn't use Spark's VectorAssembler class to build your feature " +
- s"vector but instead did so in a way that preserves zeros in your feature vector " +
- s"you can avoid this check by using the 'allow_non_zero_for_missing parameter'" +
- s" (only use if you know what you are doing)")
- }
- labeledPoint
- })
- } else {
- xgbLabelPoints
- }
- }
-
-
-}
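
The removed DataUtils.removeMissingValues filtered entries equal to the
configured missing value (or NaN) out of each labeled point's (indices, values)
pair. A standalone sketch of that filtering, with hypothetical names:

import scala.collection.mutable

object MissingFilterSketch {
  // Drop entries matching `missing`; a null `indices` array means the input
  // was dense, so positional indices are materialized for the survivors.
  def filterMissing(indices: Array[Int], values: Array[Float], missing: Float)
      : (Array[Int], Array[Float]) = {
    val keep: Float => Boolean =
      if (missing.isNaN) (v: Float) => !v.isNaN
      else (v: Float) => v != missing
    val idxB = new mutable.ArrayBuilder.ofInt()
    val valB = new mutable.ArrayBuilder.ofFloat()
    for (i <- values.indices if keep(values(i))) {
      idxB += (if (indices == null) i else indices(i))
      valB += values(i)
    }
    (idxB.result(), valB.result())
  }

  def main(args: Array[String]): Unit = {
    val (idx, vs) = filterMissing(null, Array(1.0f, Float.NaN, 3.0f), Float.NaN)
    println(idx.mkString(",") + " -> " + vs.mkString(",")) // 0,2 -> 1.0,3.0
  }
}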
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala
deleted file mode 100644
index ff732b78c08d..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- Copyright (c) 2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package org.apache.spark.ml.util
-
-import ml.dmlc.xgboost4j.java.{Booster => JBooster}
-import ml.dmlc.xgboost4j.scala.spark
-import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.fs.FSDataInputStream
-import org.json4s.DefaultFormats
-import org.json4s.JsonAST.JObject
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, render}
-
-import org.apache.spark.SparkContext
-import org.apache.spark.ml.param.Params
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
-
-abstract class XGBoostWriter extends MLWriter {
- def getModelFormat(): String = {
- optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
- }
-}
-
-object DefaultXGBoostParamsWriter {
-
- val XGBOOST_VERSION_TAG = "xgboostVersion"
-
- /**
- * Saves metadata + Params to: path + "/metadata" using [[DefaultParamsWriter.saveMetadata]]
- */
- def saveMetadata(
- instance: Params,
- path: String,
- sc: SparkContext): Unit = {
- // save xgboost version to distinguish the old model.
- val extraMetadata: JObject = Map(XGBOOST_VERSION_TAG -> ml.dmlc.xgboost4j.scala.spark.VERSION)
- DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
- }
-}
-
-object DefaultXGBoostParamsReader {
-
- private val logger = LogFactory.getLog("XGBoostSpark")
-
- /**
- * Load metadata saved using [[DefaultParamsReader.loadMetadata()]]
- *
- * @param expectedClassName If non empty, this is checked against the loaded metadata.
- * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
- */
- def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
- DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
- }
-
- /**
- * Extract Params from metadata, and set them in the instance.
- * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
- *
- * And it will auto-skip the parameter not defined.
- *
- * This API is mainly copied from DefaultParamsReader
- */
- def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
-
- // XGBoost didn't set the default parameters since the save/load code is copied
- // from spark 2.3.x, which means it just used the default values
- // as the same with XGBoost version instead of them in model.
- // For the compatibility, here we still don't set the default parameters.
- // setParams(instance, metadata, isDefault = true)
-
- setParams(instance, metadata, isDefault = false)
- }
-
- /** This API is only for XGBoostClassificationModel */
- def getNumClass(metadata: Metadata, dataInStream: FSDataInputStream): Int = {
- implicit val format = DefaultFormats
-
- // The xgboostVersion in the meta can specify if the model is the old xgboost in-compatible
- // or the new xgboost compatible.
- val xgbVerOpt = (metadata.metadata \ DefaultXGBoostParamsWriter.XGBOOST_VERSION_TAG)
- .extractOpt[String]
-
- // For binary:logistic, the numClass parameter can't be set to 2 or not be set.
- // For multi:softprob or multi:softmax, the numClass parameter must be set correctly,
- // or else, XGBoost will throw exception.
- // So it's safe to get numClass from meta data.
- xgbVerOpt
- .map { _ => (metadata.params \ "numClass").extractOpt[Int].getOrElse(2) }
- .getOrElse(dataInStream.readInt())
-
- }
-
- private def setParams(
- instance: Params,
- metadata: Metadata,
- isDefault: Boolean): Unit = {
- val paramsToSet = if (isDefault) metadata.defaultParams else metadata.params
- paramsToSet match {
- case JObject(pairs) =>
- pairs.foreach { case (paramName, jsonValue) =>
- val finalName = handleBrokenlyChangedName(paramName)
- // For the deleted parameters, we'd better to remove it instead of throwing an exception.
- // So we need to check if the parameter exists instead of blindly setting it.
- if (instance.hasParam(finalName)) {
- val param = instance.getParam(finalName)
- val value = param.jsonDecode(compact(render(jsonValue)))
- instance.set(param, handleBrokenlyChangedValue(paramName, value))
- } else {
- logger.warn(s"$finalName is no longer used in ${spark.VERSION}")
- }
- }
- case _ =>
- throw new IllegalArgumentException(
- s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
- }
- }
-
- private val paramNameCompatibilityMap: Map[String, String] = Map("silent" -> "verbosity")
-
- /** This is really not good to do this transformation, but it is needed since there're
- * some tests based on 0.82 saved model in which the objective is "reg:linear" */
- private val paramValueCompatibilityMap: Map[String, Map[Any, Any]] =
- Map("objective" -> Map("reg:linear" -> "reg:squarederror"))
-
- private def handleBrokenlyChangedName(paramName: String): String = {
- paramNameCompatibilityMap.getOrElse(paramName, paramName)
- }
-
- private def handleBrokenlyChangedValue[T](paramName: String, value: T): T = {
- paramValueCompatibilityMap.getOrElse(paramName, Map()).getOrElse(value, value).asInstanceOf[T]
- }
-
-}
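
The removed DefaultXGBoostParamsReader translated parameters saved by older
releases on load: the name "silent" maps to "verbosity" and the objective value
"reg:linear" maps to "reg:squarederror". A standalone sketch of that mapping:

object ParamCompatSketch {
  private val nameMap: Map[String, String] = Map("silent" -> "verbosity")
  private val valueMap: Map[String, Map[Any, Any]] =
    Map("objective" -> Map("reg:linear" -> "reg:squarederror"))

  // Translate an old (name, value) pair to its current equivalent; unknown
  // names and values pass through unchanged.
  def translate(name: String, value: Any): (String, Any) = {
    val newName = nameMap.getOrElse(name, name)
    val newValue = valueMap.getOrElse(name, Map.empty[Any, Any]).getOrElse(value, value)
    (newName, newValue)
  }

  def main(args: Array[String]): Unit = {
    println(translate("silent", 0))               // (verbosity,0)
    println(translate("objective", "reg:linear")) // (objective,reg:squarederror)
  }
}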
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostSchemaUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostSchemaUtils.scala
deleted file mode 100644
index c013cfe66994..000000000000
--- a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostSchemaUtils.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- Copyright (c) 2022-2023 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package org.apache.spark.ml.util
-
-import org.apache.spark.sql.types.{BooleanType, DataType, NumericType, StructType}
-import org.apache.spark.ml.linalg.VectorUDT
-
-object XGBoostSchemaUtils {
-
- /** check if the dataType is VectorUDT */
- def isVectorUDFType(dataType: DataType): Boolean = {
- dataType match {
- case _: VectorUDT => true
- case _ => false
- }
- }
-
- /** The feature columns will be vectorized by VectorAssembler first, which only
- * supports Numeric, Boolean and VectorUDT types */
- def checkFeatureColumnType(dataType: DataType): Unit = {
- dataType match {
- case _: NumericType | BooleanType =>
- case _: VectorUDT =>
- case d => throw new UnsupportedOperationException(s"featuresCols only supports Numeric, " +
- s"boolean and VectorUDT types, found: ${d}")
- }
- }
-
- def checkNumericType(
- schema: StructType,
- colName: String,
- msg: String = ""): Unit = {
- SchemaUtils.checkNumericType(schema, colName, msg)
- }
-
-}
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/xgboost/SparkUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/xgboost/SparkUtils.scala
new file mode 100644
index 000000000000..8bc88434a443
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/xgboost/SparkUtils.scala
@@ -0,0 +1,93 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.spark.ml.xgboost
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.classification.ProbabilisticClassifierParams
+import org.apache.spark.ml.linalg.VectorUDT
+import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.util.{DatasetUtils, DefaultParamsReader, DefaultParamsWriter, SchemaUtils}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.json4s.{JObject, JValue}
+
+import ml.dmlc.xgboost4j.scala.spark.params.NonXGBoostParams
+
+/**
+ * XGBoost classification spark-specific parameters which should not be passed
+ * into the xgboost library
+ *
+ * @tparam T should be XGBoostClassifier or XGBoostClassificationModel
+ */
+trait XGBProbabilisticClassifierParams[T <: Params]
+ extends ProbabilisticClassifierParams with NonXGBoostParams {
+
+  /**
+   * XGBoost doesn't reuse Spark's validateAndTransformSchema, since the Spark
+   * version requires the features column to be of vector type.
+   */
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ var outputSchema = SparkUtils.appendColumn(schema, $(predictionCol), DoubleType)
+ outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(rawPredictionCol))
+ outputSchema = SparkUtils.appendVectorUDTColumn(outputSchema, $(probabilityCol))
+ outputSchema
+ }
+
+ addNonXGBoostParam(rawPredictionCol, probabilityCol, thresholds)
+}
+
+/** Utils to access the spark internal functions */
+object SparkUtils {
+
+ def getNumClasses(dataset: Dataset[_], labelCol: String, maxNumClasses: Int = 100): Int = {
+ DatasetUtils.getNumClasses(dataset, labelCol, maxNumClasses)
+ }
+
+ def checkNumericType(schema: StructType, colName: String, msg: String = ""): Unit = {
+ SchemaUtils.checkNumericType(schema, colName, msg)
+ }
+
+ def saveMetadata(instance: Params,
+ path: String,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, paramMap)
+ }
+
+ def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
+ DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+ }
+
+ def appendColumn(schema: StructType,
+ colName: String,
+ dataType: DataType,
+ nullable: Boolean = false): StructType = {
+ SchemaUtils.appendColumn(schema, colName, dataType, nullable)
+ }
+
+ def appendVectorUDTColumn(schema: StructType,
+ colName: String,
+ dataType: DataType = new VectorUDT,
+ nullable: Boolean = false): StructType = {
+ SchemaUtils.appendColumn(schema, colName, dataType, nullable)
+ }
+}
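
The validateAndTransformSchema override above builds its output schema by
appending a double prediction column plus vector-typed raw prediction and
probability columns. A minimal sketch of the same schema manipulation using the
public SQLDataTypes.VectorType handle (column names hypothetical):

import org.apache.spark.ml.linalg.SQLDataTypes
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object SchemaAppendSketch {
  def main(args: Array[String]): Unit = {
    val input = StructType(Seq(
      StructField("features", SQLDataTypes.VectorType, nullable = false),
      StructField("label", DoubleType, nullable = false)))

    // Mirror appendColumn / appendVectorUDTColumn: each add returns a new
    // schema with one more column.
    var out = input.add(StructField("prediction", DoubleType, nullable = false))
    out = out.add(StructField("rawPrediction", SQLDataTypes.VectorType, nullable = false))
    out = out.add(StructField("probability", SQLDataTypes.VectorType, nullable = false))

    out.printTreeString()
  }
}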
diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/data/XGBoostClassificationModel b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/data/XGBoostClassificationModel
deleted file mode 100644
index 5d915d02f5f8..000000000000
Binary files a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/data/XGBoostClassificationModel and /dev/null differ
diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/_SUCCESS b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/_SUCCESS
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/part-00000 b/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/part-00000
deleted file mode 100644
index 7e1a7221ace3..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/resources/model/0.82/model/metadata/part-00000
+++ /dev/null
@@ -1 +0,0 @@
-{"class":"ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel","timestamp":1555350539033,"sparkVersion":"2.3.2-uber-109","uid":"xgbc_5e7bec215a4c","paramMap":{"useExternalMemory":false,"trainTestRatio":1.0,"alpha":0.0,"seed":0,"numWorkers":100,"skipDrop":0.0,"treeLimit":0,"silent":0,"trackerConf":{"workerConnectionTimeout":0,"trackerImpl":"python"},"missing":"NaN","colsampleBylevel":1.0,"probabilityCol":"probability","checkpointPath":"","lambda":1.0,"rawPredictionCol":"rawPrediction","eta":0.3,"numEarlyStoppingRounds":0,"growPolicy":"depthwise","gamma":0.0,"sampleType":"uniform","maxDepth":6,"rateDrop":0.0,"objective":"reg:linear","customObj":null,"lambdaBias":0.0,"baseScore":0.5,"labelCol":"label","minChildWeight":1.0,"customEval":null,"normalizeType":"tree","maxBin":16,"nthread":4,"numRound":20,"colsampleBytree":1.0,"predictionCol":"prediction","subsample":1.0,"timeoutRequestWorkers":1800000,"featuresCol":"features","evalMetric":"error","sketchEps":0.03,"scalePosWeight":1.0,"checkpointInterval":-1,"maxDeltaStep":0.0,"treeMethod":"approx"}}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
index d3f3901ad704..37705d21b61d 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CommunicatorRobustnessSuite.scala
@@ -16,22 +16,12 @@
package ml.dmlc.xgboost4j.scala.spark
-import java.util.concurrent.LinkedBlockingDeque
-
-import scala.util.Random
+import org.scalatest.funsuite.AnyFunSuite
import ml.dmlc.xgboost4j.java.{Communicator, RabitTracker}
-import ml.dmlc.xgboost4j.scala.DMatrix
-import org.scalatest.funsuite.AnyFunSuite
class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
- private def getXGBoostExecutionParams(paramMap: Map[String, Any]): XGBoostExecutionParams = {
- val classifier = new XGBoostClassifier(paramMap)
- val xgbParamsFactory = new XGBoostExecutionParamsFactory(classifier.MLlib2XGBoostParams, sc)
- xgbParamsFactory.buildXGBRuntimeParams
- }
-
test("test Java RabitTracker wrapper's exception handling: it should not hang forever.") {
/*
Deliberately create new instances of SparkContext in each unit test to avoid reusing the
@@ -113,9 +103,11 @@ class CommunicatorRobustnessSuite extends AnyFunSuite with PerTest {
"max_depth" -> "6",
"silent" -> "1",
"objective" -> "binary:logistic")
- val trainingDF = buildDataFrame(Classification.train)
- val model = new XGBoostClassifier(paramMap ++ Array("num_round" -> 10,
- "num_workers" -> numWorkers)).fit(trainingDF)
+ val trainingDF = smallBinaryClassificationVector
+ val model = new XGBoostClassifier(paramMap)
+ .setNumWorkers(numWorkers)
+ .setNumRound(10)
+ .fit(trainingDF)
val prediction = model.transform(trainingDF)
// a partial evaluation of dataframe will cause rabit initialized but not shutdown in some
// threads
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala
index b9a39a14d4f7..49d9d6d2c47b 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/CustomObj.scala
@@ -16,10 +16,12 @@
package ml.dmlc.xgboost4j.scala.spark
+import scala.collection.mutable.ListBuffer
+
+import org.apache.commons.logging.LogFactory
+
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, ObjectiveTrait}
-import org.apache.commons.logging.LogFactory
-import scala.collection.mutable.ListBuffer
/**
@@ -37,7 +39,7 @@ class CustomObj(val customParameter: Int = 0) extends ObjectiveTrait {
* @return List with two float array, correspond to first order grad and second order grad
*/
override def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix)
- : List[Array[Float]] = {
+ : List[Array[Float]] = {
val nrow = predicts.length
val gradients = new ListBuffer[Array[Float]]
var labels: Array[Float] = null
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
deleted file mode 100644
index 8d9723bb62ef..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/DeterministicPartitioningSuite.scala
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import org.apache.spark.ml.linalg.Vectors
-import org.scalatest.funsuite.AnyFunSuite
-import ml.dmlc.xgboost4j.scala.spark.util.DataUtils
-import ml.dmlc.xgboost4j.scala.spark.util.DataUtils.PackedParams
-
-import org.apache.spark.sql.functions._
-
-class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
-
- test("perform deterministic partitioning when checkpointInternal and" +
- " checkpointPath is set (Classifier)") {
- val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
- val paramMap = Map("eta" -> "1", "max_depth" -> 2,
- "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
- "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
- val xgbClassifier = new XGBoostClassifier(paramMap)
- assert(xgbClassifier.needDeterministicRepartitioning)
- }
-
- test("perform deterministic partitioning when checkpointInternal and" +
- " checkpointPath is set (Regressor)") {
- val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
- val paramMap = Map("eta" -> "1", "max_depth" -> 2,
- "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
- "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
- val xgbRegressor = new XGBoostRegressor(paramMap)
- assert(xgbRegressor.needDeterministicRepartitioning)
- }
-
- test("deterministic partitioning takes effect with various parts of data") {
- val trainingDF = buildDataFrame(Classification.train)
- // the test idea is that, we apply a chain of repartitions over trainingDFs but they
- // have to produce the identical RDDs
- val transformedDFs = (1 until 6).map(shuffleCount => {
- var resultDF = trainingDF
- for (i <- 0 until shuffleCount) {
- resultDF = resultDF.repartition(numWorkers)
- }
- resultDF
- })
- val transformedRDDs = transformedDFs.map(df => DataUtils.convertDataFrameToXGBLabeledPointRDDs(
- PackedParams(col("label"),
- col("features"),
- lit(1.0),
- lit(Float.NaN),
- None,
- numWorkers,
- deterministicPartition = true),
- df
- ).head)
- val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
- case (partitionIndex, labelPoints) =>
- Iterator((partitionIndex, labelPoints.toList))
- }.collect().toMap)
- resultsMaps.foldLeft(resultsMaps.head) { case (map1, map2) =>
- assert(map1.keys.toSet === map2.keys.toSet)
- for ((parIdx, labeledPoints) <- map1) {
- val sortedA = labeledPoints.sortBy(_.hashCode())
- val sortedB = map2(parIdx).sortBy(_.hashCode())
- assert(sortedA.length === sortedB.length)
- assert(sortedA.indices.forall(idx =>
- sortedA(idx).values.toSet === sortedB(idx).values.toSet))
- }
- map2
- }
- }
-
- test("deterministic partitioning has a uniform repartition on dataset with missing values") {
- val N = 10000
- val dataset = (0 until N).map{ n =>
- (n, n % 2, Vectors.sparse(3, Array(0, 1, 2), Array(Double.NaN, n, Double.NaN)))
- }
-
- val df = ss.createDataFrame(sc.parallelize(dataset)).toDF("id", "label", "features")
-
- val dfRepartitioned = DataUtils.convertDataFrameToXGBLabeledPointRDDs(
- PackedParams(col("label"),
- col("features"),
- lit(1.0),
- lit(Float.NaN),
- None,
- 10,
- deterministicPartition = true), df
- ).head
-
- val partitionsSizes = dfRepartitioned
- .mapPartitions(iter => Array(iter.size.toDouble).iterator, true)
- .collect()
- val partitionMean = partitionsSizes.sum / partitionsSizes.length
- val squaredDiffSum = partitionsSizes
- .map(partitionSize => Math.pow(partitionSize - partitionMean, 2))
- val standardDeviation = math.sqrt(squaredDiffSum.sum / squaredDiffSum.length)
-
- assert(standardDeviation < math.sqrt(N.toDouble))
- }
-}
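
The deleted suite verified that deterministic repartitioning sends identical
rows to identical partitions across runs by keying each row on a hash of its
content modulo the worker count. A standalone sketch of the idea, with
hypothetical data:

import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession

object DeterministicPartitionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("DeterministicPartitionSketch").master("local[2]").getOrCreate()
    val numWorkers = 2
    val rdd = spark.sparkContext.parallelize(Seq("a", "b", "c", "d"))
    // Key by content hash so the row-to-partition mapping is reproducible,
    // then strip the key after partitioning.
    val partitioned = rdd
      .map(row => (math.abs(row.hashCode % numWorkers), row))
      .partitionBy(new HashPartitioner(numWorkers))
      .map(_._2)
    println(partitioned.glom().collect().map(_.mkString("[", ",", "]")).mkString(" "))
    spark.stop()
  }
}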
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala
index 91a840911a32..04900f3d9b8c 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/EvalError.scala
@@ -16,9 +16,10 @@
package ml.dmlc.xgboost4j.scala.spark
+import org.apache.commons.logging.LogFactory
+
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.scala.{DMatrix, EvalTrait}
-import org.apache.commons.logging.LogFactory
class EvalError extends EvalTrait {
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
deleted file mode 100755
index 729bd9c77d1a..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ExternalCheckpointManagerSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- Copyright (c) 2014-2023 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import java.io.File
-
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, ExternalCheckpointManager, XGBoost => SXGBoost}
-import org.scalatest.funsuite.AnyFunSuite
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
-
- private def produceParamMap(checkpointPath: String, checkpointInterval: Int):
- Map[String, Any] = {
- Map("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
- "objective" -> "binary:logistic", "num_workers" -> sc.defaultParallelism,
- "checkpoint_path" -> checkpointPath, "checkpoint_interval" -> checkpointInterval)
- }
-
- private def createNewModels():
- (String, XGBoostClassificationModel, XGBoostClassificationModel) = {
- val tmpPath = createTmpFolder("test").toAbsolutePath.toString
- val (model2, model4) = {
- val training = buildDataFrame(Classification.train)
- val paramMap = produceParamMap(tmpPath, 2)
- (new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
- new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
- }
- (tmpPath, model2, model4)
- }
-
- test("test update/load models") {
- val (tmpPath, model2, model4) = createNewModels()
- val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
-
- manager.updateCheckpoint(model2._booster.booster)
- var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
- assert(files.length == 1)
- assert(files.head.getPath.getName == "1.ubj")
- assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)
-
- manager.updateCheckpoint(model4._booster)
- files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
- assert(files.length == 1)
- assert(files.head.getPath.getName == "3.ubj")
- assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
- }
-
- test("test cleanUpHigherVersions") {
- val (tmpPath, model2, model4) = createNewModels()
-
- val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
- manager.updateCheckpoint(model4._booster)
- manager.cleanUpHigherVersions(3)
- assert(new File(s"$tmpPath/3.ubj").exists())
-
- manager.cleanUpHigherVersions(2)
- assert(!new File(s"$tmpPath/3.ubj").exists())
- }
-
- test("test checkpoint rounds") {
- import scala.collection.JavaConverters._
- val (tmpPath, model2, model4) = createNewModels()
- val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
- assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
- assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
- assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
- }
-
-
- private def trainingWithCheckpoint(cacheData: Boolean, skipCleanCheckpoint: Boolean): Unit = {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
-
- val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-
- val paramMap = produceParamMap(tmpPath, 2)
-
- val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
- val skipCleanCheckpointMap =
- if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
-
- val finalParamMap = paramMap ++ cacheDataMap ++ skipCleanCheckpointMap
-
- val prevModel = new XGBoostClassifier(finalParamMap ++ Seq("num_round" -> 5)).fit(training)
-
- def error(model: Booster): Float = eval.eval(model.predict(testDM, outPutMargin = true), testDM)
-
- if (skipCleanCheckpoint) {
- // Check only one model is kept after training
- val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
- assert(files.length == 1)
- assert(files.head.getPath.getName == "4.ubj")
- val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.ubj")
- // Train next model based on prev model
- val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
- assert(error(tmpModel) >= error(prevModel._booster))
- assert(error(prevModel._booster) > error(nextModel._booster))
- assert(error(nextModel._booster) < 0.1)
- } else {
- assert(!FileSystem.get(sc.hadoopConfiguration).exists(new Path(tmpPath)))
- }
- }
-
- test("training with checkpoint boosters") {
- trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = true)
- }
-
- test("training with checkpoint boosters with cached training dataset") {
- trainingWithCheckpoint(cacheData = true, skipCleanCheckpoint = true)
- }
-
- test("the checkpoint file should be cleaned after a successful training") {
- trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = false)
- }
-}
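
The deleted suite pinned down the checkpoint-round schedule: every
checkpoint_interval-th round plus the final round, and only the final round
when the interval is 0. A pure-Scala sketch reproducing the asserted behavior
(function name hypothetical):

object CheckpointRoundsSketch {
  // Rounds at which a checkpoint is written for `numRound` total rounds.
  def checkpointRounds(start: Int, interval: Int, numRound: Int): Seq[Int] = {
    val base = if (interval > 0) (start until numRound by interval).toSeq else Seq.empty[Int]
    if (base.lastOption.contains(numRound - 1)) base else base :+ (numRound - 1)
  }

  def main(args: Array[String]): Unit = {
    println(checkpointRounds(0, 0, 3)) // List(2), matching getCheckpointRounds(0, 0, 3)
    println(checkpointRounds(0, 2, 7)) // List(0, 2, 4, 6)
    println(checkpointRounds(0, 2, 8)) // List(0, 2, 4, 6, 7)
  }
}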
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
deleted file mode 100644
index 789fd162bcbb..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import org.apache.spark.Partitioner
-import org.apache.spark.ml.feature.VectorAssembler
-import org.scalatest.funsuite.AnyFunSuite
-import org.apache.spark.sql.functions._
-
-import scala.util.Random
-
-class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
-
- test("transform throwing exception if feature size of dataset is greater than model's") {
- val modelPath = getClass.getResource("/model/0.82/model").getPath
- val model = XGBoostClassificationModel.read.load(modelPath)
- val r = new Random(0)
- // 0.82/model was trained with 251 features. and transform will throw exception
- // if feature size of data is not equal to 251
- var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
- toDF("feature", "label")
- for (x <- 1 to 252) {
- df = df.withColumn(s"feature_${x}", lit(1))
- }
- val assembler = new VectorAssembler()
- .setInputCols(df.columns.filter(!_.contains("label")))
- .setOutputCol("features")
- val thrown = intercept[Exception] {
- model.transform(assembler.transform(df)).show()
- }
- assert(thrown.getMessage.contains(
- "Number of columns does not match number of features in booster"))
- }
-
- test("train throwing exception if feature size of dataset is different on distributed train") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic",
- "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
- import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
- val sparkSession = ss
- import sparkSession.implicits._
- val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
- .map(lp => (lp.label, lp)).partitionBy(
- new Partitioner {
- override def numPartitions: Int = 2
-
- override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
- }
- ).map(_._2).zipWithIndex().map {
- case (lp, id) =>
- (id, lp.label, lp.features)
- }.toDF("id", "label", "features")
- val xgb = new XGBoostClassifier(paramMap)
- xgb.fit(repartitioned)
- }
-}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
deleted file mode 100644
index 6a7f7129d56a..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/MissingValueHandlingSuite.scala
+++ /dev/null
@@ -1,235 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.sql.DataFrame
-import org.scalatest.funsuite.AnyFunSuite
-import scala.util.Random
-
-import org.apache.spark.SparkException
-
-class MissingValueHandlingSuite extends AnyFunSuite with PerTest {
- test("dense vectors containing missing value") {
- def buildDenseDataFrame(): DataFrame = {
- val numRows = 100
- val numCols = 5
- val data = (0 until numRows).map { x =>
- val label = Random.nextInt(2)
- val values = Array.tabulate[Double](numCols) { c =>
- if (c == numCols - 1) 0 else Random.nextDouble
- }
- (label, Vectors.dense(values))
- }
- ss.createDataFrame(sc.parallelize(data.toList)).toDF("label", "features")
- }
- val denseDF = buildDenseDataFrame().repartition(4)
- val paramMap = List("eta" -> "1", "max_depth" -> "2",
- "objective" -> "binary:logistic", "missing" -> 0, "num_workers" -> numWorkers).toMap
- val model = new XGBoostClassifier(paramMap).fit(denseDF)
- model.transform(denseDF).collect()
- }
-
- test("handle Float.NaN as missing value correctly") {
- val spark = ss
- import spark.implicits._
- val testDF = Seq(
- (1.0f, 0.0f, Float.NaN, 1.0),
- (1.0f, 0.0f, 1.0f, 1.0),
- (0.0f, 1.0f, 0.0f, 0.0),
- (1.0f, 0.0f, 1.0f, 1.0),
- (1.0f, Float.NaN, 0.0f, 0.0),
- (0.0f, 1.0f, 0.0f, 1.0),
- (Float.NaN, 0.0f, 0.0f, 1.0)
- ).toDF("col1", "col2", "col3", "label")
- val vectorAssembler = new VectorAssembler()
- .setInputCols(Array("col1", "col2", "col3"))
- .setOutputCol("features")
- .setHandleInvalid("keep")
-
- val inputDF = vectorAssembler.transform(testDF).select("features", "label")
- val paramMap = List("eta" -> "1", "max_depth" -> "2",
- "objective" -> "binary:logistic", "missing" -> Float.NaN, "num_workers" -> 1).toMap
- val model = new XGBoostClassifier(paramMap).fit(inputDF)
- model.transform(inputDF).collect()
- }
-
- test("specify a non-zero missing value but with dense vector does not stop" +
- " application") {
- val spark = ss
- import spark.implicits._
- // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
- // vector,
- val testDF = Seq(
- (1.0f, 0.0f, -1.0f, 1.0),
- (1.0f, 0.0f, 1.0f, 1.0),
- (0.0f, 1.0f, 0.0f, 0.0),
- (1.0f, 0.0f, 1.0f, 1.0),
- (1.0f, -1.0f, 0.0f, 0.0),
- (0.0f, 1.0f, 0.0f, 1.0),
- (-1.0f, 0.0f, 0.0f, 1.0)
- ).toDF("col1", "col2", "col3", "label")
- val vectorAssembler = new VectorAssembler()
- .setInputCols(Array("col1", "col2", "col3"))
- .setOutputCol("features")
- val inputDF = vectorAssembler.transform(testDF).select("features", "label")
- val paramMap = List("eta" -> "1", "max_depth" -> "2",
- "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
- val model = new XGBoostClassifier(paramMap).fit(inputDF)
- model.transform(inputDF).collect()
- }
-
- test("specify a non-zero missing value and meet an empty vector we should" +
- " stop the application") {
- val spark = ss
- import spark.implicits._
- val testDF = Seq(
- (1.0f, 0.0f, -1.0f, 1.0),
- (1.0f, 0.0f, 1.0f, 1.0),
- (0.0f, 1.0f, 0.0f, 0.0),
- (1.0f, 0.0f, 1.0f, 1.0),
- (1.0f, -1.0f, 0.0f, 0.0),
- (0.0f, 0.0f, 0.0f, 1.0),// empty vector
- (-1.0f, 0.0f, 0.0f, 1.0)
- ).toDF("col1", "col2", "col3", "label")
- val vectorAssembler = new VectorAssembler()
- .setInputCols(Array("col1", "col2", "col3"))
- .setOutputCol("features")
- val inputDF = vectorAssembler.transform(testDF).select("features", "label")
- val paramMap = List("eta" -> "1", "max_depth" -> "2",
- "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
- intercept[SparkException] {
- new XGBoostClassifier(paramMap).fit(inputDF)
- }
- }
-
- test("specify a non-zero missing value and meet a Sparse vector we should" +
- " stop the application") {
- val spark = ss
- import spark.implicits._
- // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
- // vector,
- val testDF = Seq(
- (1.0f, 0.0f, -1.0f, 1.0f, 1.0),
- (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
- (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
- (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
- (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
- (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
- (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
- ).toDF("col1", "col2", "col3", "col4", "label")
- val vectorAssembler = new VectorAssembler()
- .setInputCols(Array("col1", "col2", "col3", "col4"))
- .setOutputCol("features")
- val inputDF = vectorAssembler.transform(testDF).select("features", "label")
- inputDF.show()
- val paramMap = List("eta" -> "1", "max_depth" -> "2",
- "objective" -> "binary:logistic", "missing" -> -1.0f, "num_workers" -> 1).toMap
- intercept[SparkException] {
- new XGBoostClassifier(paramMap).fit(inputDF)
- }
- }
-
- test("specify a non-zero missing value but set allow_non_zero_for_missing " +
- "does not stop application") {
- val spark = ss
- import spark.implicits._
- // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
- // vector,
- val testDF = Seq(
- (7.0f, 0.0f, -1.0f, 1.0f, 1.0),
- (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
- (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
- (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
- (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
- (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
- (-1.0f, 0.0f, 0.0f, 1.0f, 1.0)
- ).toDF("col1", "col2", "col3", "col4", "label")
- val vectorAssembler = new VectorAssembler()
- .setInputCols(Array("col1", "col2", "col3", "col4"))
- .setOutputCol("features")
- val inputDF = vectorAssembler.transform(testDF).select("features", "label")
- inputDF.show()
- val paramMap = List("eta" -> "1", "max_depth" -> "2",
- "objective" -> "binary:logistic", "missing" -> -1.0f,
- "num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
- val model = new XGBoostClassifier(paramMap).fit(inputDF)
- model.transform(inputDF).collect()
- }
-
- // https://github.com/dmlc/xgboost/pull/5929
- test("handle the empty last row correctly with a missing value as 0") {
- val spark = ss
- import spark.implicits._
- // spark uses 1.5 * (nnz + 1.0) < size as the condition to decide whether using sparse or dense
- // vector,
- val testDF = Seq(
- (7.0f, 0.0f, -1.0f, 1.0f, 1.0),
- (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
- (0.0f, 1.0f, 0.0f, 1.0f, 0.0),
- (1.0f, 0.0f, 1.0f, 1.0f, 1.0),
- (1.0f, -1.0f, 0.0f, 1.0f, 0.0),
- (0.0f, 0.0f, 0.0f, 1.0f, 1.0),
- (0.0f, 0.0f, 0.0f, 0.0f, 0.0)
- ).toDF("col1", "col2", "col3", "col4", "label")
- val vectorAssembler = new VectorAssembler()
- .setInputCols(Array("col1", "col2", "col3", "col4"))
- .setOutputCol("features")
- val inputDF = vectorAssembler.transform(testDF).select("features", "label")
- inputDF.show()
- val paramMap = List("eta" -> "1", "max_depth" -> "2",
- "objective" -> "binary:logistic", "missing" -> 0.0f,
- "num_workers" -> 1, "allow_non_zero_for_missing" -> "true").toMap
- val model = new XGBoostClassifier(paramMap).fit(inputDF)
- model.transform(inputDF).collect()
- }
-
- test("Getter and setter for AllowNonZeroForMissingValue works") {
- {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
- val training = buildDataFrame(Classification.train)
- val classifier = new XGBoostClassifier(paramMap)
- classifier.setAllowNonZeroForMissing(true)
- assert(classifier.getAllowNonZeroForMissingValue)
- classifier.setAllowNonZeroForMissing(false)
- assert(!classifier.getAllowNonZeroForMissingValue)
- val model = classifier.fit(training)
- model.setAllowNonZeroForMissing(true)
- assert(model.getAllowNonZeroForMissingValue)
- model.setAllowNonZeroForMissing(false)
- assert(!model.getAllowNonZeroForMissingValue)
- }
-
- {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
- val training = buildDataFrame(Regression.train)
- val regressor = new XGBoostRegressor(paramMap)
- regressor.setAllowNonZeroForMissing(true)
- assert(regressor.getAllowNonZeroForMissingValue)
- regressor.setAllowNonZeroForMissing(false)
- assert(!regressor.getAllowNonZeroForMissingValue)
- val model = regressor.fit(training)
- model.setAllowNonZeroForMissing(true)
- assert(model.getAllowNonZeroForMissingValue)
- model.setAllowNonZeroForMissing(false)
- assert(!model.getAllowNonZeroForMissingValue)
- }
- }
-}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
deleted file mode 100644
index 20a95f2a23e4..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import org.scalatest.BeforeAndAfterAll
-import org.scalatest.funsuite.AnyFunSuite
-
-import org.apache.spark.SparkException
-import org.apache.spark.ml.param.ParamMap
-
-class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
- test("XGBoost and Spark parameters synchronize correctly") {
- val xgbParamMap = Map("eta" -> "1", "objective" -> "binary:logistic",
- "objective_type" -> "classification")
- // from xgboost params to spark params
- val xgb = new XGBoostClassifier(xgbParamMap)
- assert(xgb.getEta === 1.0)
- assert(xgb.getObjective === "binary:logistic")
- assert(xgb.getObjectiveType === "classification")
- // from spark to xgboost params
- val xgbCopy = xgb.copy(ParamMap.empty)
- assert(xgbCopy.MLlib2XGBoostParams("eta").toString.toDouble === 1.0)
- assert(xgbCopy.MLlib2XGBoostParams("objective").toString === "binary:logistic")
- assert(xgbCopy.MLlib2XGBoostParams("objective_type").toString === "classification")
- val xgbCopy2 = xgb.copy(ParamMap.empty.put(xgb.evalMetric, "logloss"))
- assert(xgbCopy2.MLlib2XGBoostParams("eval_metric").toString === "logloss")
- }
-
- test("fail training elegantly with unsupported objective function") {
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "wrong_objective_function", "num_class" -> "6", "num_round" -> 5,
- "num_workers" -> numWorkers)
- val trainingDF = buildDataFrame(MultiClassification.train)
- val xgb = new XGBoostClassifier(paramMap)
- intercept[SparkException] {
- xgb.fit(trainingDF)
- }
- }
-
- test("fail training elegantly with unsupported eval metrics") {
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
- "num_workers" -> numWorkers, "eval_metric" -> "wrong_eval_metrics")
- val trainingDF = buildDataFrame(MultiClassification.train)
- val xgb = new XGBoostClassifier(paramMap)
- intercept[SparkException] {
- xgb.fit(trainingDF)
- }
- }
-
- test("custom_eval does not support early stopping") {
- val paramMap = Map("eta" -> "0.1", "custom_eval" -> new EvalError, "silent" -> "1",
- "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
- "num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
- val trainingDF = buildDataFrame(MultiClassification.train)
-
- val thrown = intercept[IllegalArgumentException] {
- new XGBoostClassifier(paramMap).fit(trainingDF)
- }
-
- assert(thrown.getMessage.contains("custom_eval does not support early stopping"))
- }
-
- test("early stopping should work without custom_eval setting") {
- val paramMap = Map("eta" -> "0.1", "silent" -> "1",
- "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
- "num_workers" -> numWorkers, "num_early_stopping_rounds" -> 2)
- val trainingDF = buildDataFrame(MultiClassification.train)
-
- new XGBoostClassifier(paramMap).fit(trainingDF)
- }
-
- test("Default parameters") {
- val classifier = new XGBoostClassifier()
- intercept[NoSuchElementException] {
- classifier.getBaseScore
- }
- }
-
- test("approx can't be used for gpu train") {
- val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
- val trainingDF = buildDataFrame(MultiClassification.train)
- val xgb = new XGBoostClassifier(paramMap)
- val thrown = intercept[IllegalArgumentException] {
- xgb.fit(trainingDF)
- }
- assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
- "for Spark GPU cluster"))
- }
-}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
index 24bc00e1824e..49b50fcc469f 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014-2022 by Contributors
+ Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,37 +18,39 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.{File, FileInputStream}
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-
+import org.apache.commons.io.IOUtils
import org.apache.spark.SparkContext
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
-import scala.math.min
-import scala.util.Random
-import org.apache.commons.io.IOUtils
+import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+import ml.dmlc.xgboost4j.scala.spark.Utils.{withResource, XGBLabeledPointFeatures}
-trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
+trait PerTest extends BeforeAndAfterEach {
+ self: AnyFunSuite =>
- protected val numWorkers: Int = min(Runtime.getRuntime.availableProcessors(), 4)
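+ // a fixed worker count keeps the suite independent of the host's core count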
+ protected val numWorkers: Int = 4
@transient private var currentSession: SparkSession = _
def ss: SparkSession = getOrCreateSession
+
implicit def sc: SparkContext = ss.sparkContext
protected def sparkSessionBuilder: SparkSession.Builder = SparkSession.builder()
- .master(s"local[${numWorkers}]")
- .appName("XGBoostSuite")
- .config("spark.ui.enabled", false)
- .config("spark.driver.memory", "512m")
- .config("spark.barrier.sync.timeout", 10)
- .config("spark.task.cpus", 1)
+ .master(s"local[${numWorkers}]")
+ .appName("XGBoostSuite")
+ .config("spark.ui.enabled", false)
+ .config("spark.driver.memory", "512m")
+ .config("spark.barrier.sync.timeout", 10)
+ .config("spark.task.cpus", 1)
+ .config("spark.stage.maxConsecutiveAttempts", 1)
override def beforeEach(): Unit = getOrCreateSession
- override def afterEach() {
+ override def afterEach(): Unit = {
if (currentSession != null) {
currentSession.stop()
cleanExternalCache(currentSession.sparkContext.appName)
@@ -74,42 +76,25 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
protected def buildDataFrame(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
- import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
- (id, labeledPoint.label, labeledPoint.features)
+ (id, labeledPoint.label, labeledPoint.features, labeledPoint.weight)
}
-
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
- .toDF("id", "label", "features")
- }
-
- protected def buildDataFrameWithRandSort(
- labeledPoints: Seq[XGBLabeledPoint],
- numPartitions: Int = numWorkers): DataFrame = {
- val df = buildDataFrame(labeledPoints, numPartitions)
- val rndSortedRDD = df.rdd.mapPartitions { iter =>
- iter.map(_ -> Random.nextDouble()).toList
- .sortBy(_._2)
- .map(_._1).iterator
- }
- ss.createDataFrame(rndSortedRDD, df.schema)
+ .toDF("id", "label", "features", "weight")
}
protected def buildDataFrameWithGroup(
labeledPoints: Seq[XGBLabeledPoint],
numPartitions: Int = numWorkers): DataFrame = {
- import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
val it = labeledPoints.iterator.zipWithIndex
.map { case (labeledPoint: XGBLabeledPoint, id: Int) =>
- (id, labeledPoint.label, labeledPoint.features, labeledPoint.group)
+ (id, labeledPoint.label, labeledPoint.features, labeledPoint.group, labeledPoint.weight)
}
-
ss.createDataFrame(sc.parallelize(it.toList, numPartitions))
- .toDF("id", "label", "features", "group")
+ .toDF("id", "label", "features", "group", "weight")
}
-
protected def compareTwoFiles(lhs: String, rhs: String): Boolean = {
withResource(new FileInputStream(lhs)) { lfis =>
withResource(new FileInputStream(rhs)) { rfis =>
@@ -118,12 +103,32 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
}
}
- /** Executes the provided code block and then closes the resource */
- protected def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
- try {
- block(r)
- } finally {
- r.close()
- }
- }
+ def smallBinaryClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
+ (0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
+ (0.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
+ (1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
+ (0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
+ (1.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
+ ))).toDF("label", "margin", "weight", "features")
+
+ def smallMultiClassificationVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0)),
+ (0.0, 0.4, -3.0, Vectors.dense(0.0, 0.0, 0.0)),
+ (2.0, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
+ (1.0, 1.2, 0.2, Vectors.dense(2.0, 0.0, 4.0)),
+ (0.0, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0)),
+ (2.0, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7))
+ ))).toDF("label", "margin", "weight", "features")
+
+ def smallGroupVector: DataFrame = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0, 0.5, 2.0, Vectors.dense(1.0, 2.0, 3.0)),
+ (0.0, 1, 0.4, 1.0, Vectors.dense(0.0, 0.0, 0.0)),
+ (0.0, 1, 0.3, 1.0, Vectors.dense(0.0, 3.0, 0.0)),
+ (1.0, 0, 1.2, 2.0, Vectors.dense(2.0, 0.0, 4.0)),
+ (1.0, 2, -0.5, 3.0, Vectors.dense(0.2, 1.2, 2.0)),
+ (0.0, 2, -0.4, 3.0, Vectors.dense(0.5, 2.2, 1.7))
+ ))).toDF("label", "group", "margin", "weight", "features")
}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
deleted file mode 100755
index 5425b8647b09..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import java.io.File
-import java.util.Arrays
-
-import ml.dmlc.xgboost4j.scala.DMatrix
-
-import scala.util.Random
-import org.apache.spark.ml.feature._
-import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.sql.functions._
-import org.scalatest.funsuite.AnyFunSuite
-
-class PersistenceSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
-
- test("test persistence of XGBoostClassifier and XGBoostClassificationModel") {
- val eval = new EvalError()
- val trainingDF = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
-
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
- val xgbc = new XGBoostClassifier(paramMap)
- val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
- xgbc.write.overwrite().save(xgbcPath)
- val xgbc2 = XGBoostClassifier.load(xgbcPath)
- val paramMap2 = xgbc2.MLlib2XGBoostParams
- paramMap.foreach {
- case (k, v) => assert(v.toString == paramMap2(k).toString)
- }
-
- val model = xgbc.fit(trainingDF)
- val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
- assert(evalResults < 0.1)
- val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
- model.write.overwrite.save(xgbcModelPath)
- val model2 = XGBoostClassificationModel.load(xgbcModelPath)
- assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
-
- assert(model.getEta === model2.getEta)
- assert(model.getNumRound === model2.getNumRound)
- assert(model.getRawPredictionCol === model2.getRawPredictionCol)
- val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
- assert(evalResults === evalResults2)
- }
-
- test("test persistence of XGBoostRegressor and XGBoostRegressionModel") {
- val eval = new EvalError()
- val trainingDF = buildDataFrame(Regression.train)
- val testDM = new DMatrix(Regression.test.iterator)
-
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> "10", "num_workers" -> numWorkers)
- val xgbr = new XGBoostRegressor(paramMap)
- val xgbrPath = new File(tempDir.toFile, "xgbr").getPath
- xgbr.write.overwrite().save(xgbrPath)
- val xgbr2 = XGBoostRegressor.load(xgbrPath)
- val paramMap2 = xgbr2.MLlib2XGBoostParams
- paramMap.foreach {
- case (k, v) => assert(v.toString == paramMap2(k).toString)
- }
-
- val model = xgbr.fit(trainingDF)
- val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
- assert(evalResults < 0.1)
- val xgbrModelPath = new File(tempDir.toFile, "xgbrModel").getPath
- model.write.overwrite.save(xgbrModelPath)
- val model2 = XGBoostRegressionModel.load(xgbrModelPath)
- assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
-
- assert(model.getEta === model2.getEta)
- assert(model.getNumRound === model2.getNumRound)
- assert(model.getPredictionCol === model2.getPredictionCol)
- val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
- assert(evalResults === evalResults2)
- }
-
- test("test persistence of MLlib pipeline with XGBoostClassificationModel") {
- val r = new Random(0)
- // maybe move to shared context, but requires session to import implicits
- val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
- toDF("feature", "label")
-
- val assembler = new VectorAssembler()
- .setInputCols(df.columns.filter(!_.contains("label")))
- .setOutputCol("features")
-
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "num_round" -> "10", "num_workers" -> numWorkers)
- val xgb = new XGBoostClassifier(paramMap)
-
- // Construct MLlib pipeline, save and load
- val pipeline = new Pipeline().setStages(Array(assembler, xgb))
- val pipePath = new File(tempDir.toFile, "pipeline").getPath
- pipeline.write.overwrite().save(pipePath)
- val pipeline2 = Pipeline.read.load(pipePath)
- val xgb2 = pipeline2.getStages(1).asInstanceOf[XGBoostClassifier]
- val paramMap2 = xgb2.MLlib2XGBoostParams
- paramMap.foreach {
- case (k, v) => assert(v.toString == paramMap2(k).toString)
- }
-
- // Model training, save and load
- val pipeModel = pipeline.fit(df)
- val pipeModelPath = new File(tempDir.toFile, "pipelineModel").getPath
- pipeModel.write.overwrite.save(pipeModelPath)
- val pipeModel2 = PipelineModel.load(pipeModelPath)
-
- val xgbModel = pipeModel.stages(1).asInstanceOf[XGBoostClassificationModel]
- val xgbModel2 = pipeModel2.stages(1).asInstanceOf[XGBoostClassificationModel]
-
- assert(Arrays.equals(xgbModel._booster.toByteArray, xgbModel2._booster.toByteArray))
-
- assert(xgbModel.getEta === xgbModel2.getEta)
- assert(xgbModel.getNumRound === xgbModel2.getNumRound)
- assert(xgbModel.getRawPredictionCol === xgbModel2.getRawPredictionCol)
- }
-
- test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
- "using custom Eval and Obj") {
- val trainingDF = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "custom_eval" -> new EvalError, "custom_obj" -> new CustomObj(1),
- "num_round" -> "10", "num_workers" -> numWorkers, "objective" -> "binary:logistic")
-
- val xgbc = new XGBoostClassifier(paramMap)
- val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
- xgbc.write.overwrite().save(xgbcPath)
- val xgbc2 = XGBoostClassifier.load(xgbcPath)
- val paramMap2 = xgbc2.MLlib2XGBoostParams
- paramMap.foreach {
- case ("custom_eval", v) => assert(v.isInstanceOf[EvalError])
- case ("custom_obj", v) =>
- assert(v.isInstanceOf[CustomObj])
- assert(v.asInstanceOf[CustomObj].customParameter ==
- paramMap2("custom_obj").asInstanceOf[CustomObj].customParameter)
- case (_, _) =>
- }
-
- val eval = new EvalError()
-
- val model = xgbc.fit(trainingDF)
- val evalResults = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
- assert(evalResults < 0.1)
- val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
- model.write.overwrite.save(xgbcModelPath)
- val model2 = XGBoostClassificationModel.load(xgbcModelPath)
- assert(Arrays.equals(model._booster.toByteArray, model2._booster.toByteArray))
-
- assert(model.getEta === model2.getEta)
- assert(model.getNumRound === model2.getNumRound)
- assert(model.getRawPredictionCol === model2.getRawPredictionCol)
- val evalResults2 = eval.eval(model2._booster.predict(testDM, outPutMargin = true), testDM)
- assert(evalResults === evalResults2)
- }
-
- test("cross-version model loading (0.82)") {
- val modelPath = getClass.getResource("/model/0.82/model").getPath
- val model = XGBoostClassificationModel.read.load(modelPath)
- val r = new Random(0)
- var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
- toDF("feature", "label")
- // 0.82/model was trained with 251 features. and transform will throw exception
- // if feature size of data is not equal to 251
- for (x <- 1 to 250) {
- df = df.withColumn(s"feature_${x}", lit(1))
- }
- val assembler = new VectorAssembler()
- .setInputCols(df.columns.filter(!_.contains("label")))
- .setOutputCol("features")
- df = assembler.transform(df)
- for (x <- 1 to 250) {
- df = df.drop(s"feature_${x}")
- }
- model.transform(df).show()
- }
-}
-
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala
index fae241d8b990..b93bba9ef133 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,8 +16,9 @@
package ml.dmlc.xgboost4j.scala.spark
-import scala.collection.mutable
import scala.io.Source
+import scala.util.Random
+
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
trait TrainTestData {
@@ -31,8 +32,8 @@ trait TrainTestData {
Source.fromInputStream(is).getLines()
}
- protected def getLabeledPoints(resource: String, featureSize: Int, zeroBased: Boolean):
- Seq[XGBLabeledPoint] = {
+ protected def getLabeledPoints(resource: String, featureSize: Int,
+ zeroBased: Boolean): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
val labelAndFeatures = line.split(" ")
val label = labelAndFeatures.head.toFloat
@@ -65,10 +66,32 @@ trait TrainTestData {
object Classification extends TrainTestData {
val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", 126, zeroBased = false)
val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", 126, zeroBased = false)
+
+ Random.setSeed(10)
+ val randomWeights = Array.fill(train.length)(Random.nextFloat())
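+ // XGBLabeledPoint's fifth argument below is the instance weight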
+ val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
+ XGBLabeledPoint(v.label, v.size, v.indices, v.values,
+ randomWeights(index), v.group, v.baseMargin)
+ }
}
object MultiClassification extends TrainTestData {
- val train: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")
+
+ private def split(): (Seq[XGBLabeledPoint], Seq[XGBLabeledPoint]) = {
+ val tmp: Seq[XGBLabeledPoint] = getLabeledPoints("/dermatology.data")
+ Random.setSeed(100)
+ val randomizedTmp = Random.shuffle(tmp)
+ val splitIndex = (randomizedTmp.length * 0.8).toInt
+ (randomizedTmp.take(splitIndex), randomizedTmp.drop(splitIndex))
+ }
+
+ val (train, test) = split()
+ Random.setSeed(10)
+ val randomWeights = Array.fill(train.length)(Random.nextFloat())
+ val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
+ XGBLabeledPoint(v.label, v.size, v.indices, v.values,
+ randomWeights(index), v.group, v.baseMargin)
+ }
private def getLabeledPoints(resource: String): Seq[XGBLabeledPoint] = {
getResourceLines(resource).map { line =>
@@ -76,7 +99,7 @@ object MultiClassification extends TrainTestData {
val label = featuresAndLabel.last.toFloat - 1
val values = new Array[Float](featuresAndLabel.length - 1)
values(values.length - 1) =
- if (featuresAndLabel(featuresAndLabel.length - 2) == "?") 1 else 0
+ if (featuresAndLabel(featuresAndLabel.length - 2) == "?") 1 else 0
for (i <- 0 until values.length - 2) {
values(i) = featuresAndLabel(i).toFloat
}
@@ -92,31 +115,25 @@ object Regression extends TrainTestData {
"/machine.txt.train", MACHINE_COL_NUM, zeroBased = true)
val test: Seq[XGBLabeledPoint] = getLabeledPoints(
"/machine.txt.test", MACHINE_COL_NUM, zeroBased = true)
-}
-object Ranking extends TrainTestData {
- val RANK_COL_NUM = 3
- val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
- val test: Seq[XGBLabeledPoint] = getLabeledPoints(
- "/rank.test.txt", RANK_COL_NUM, zeroBased = false)
+ Random.setSeed(10)
+ val randomWeights = Array.fill(train.length)(Random.nextFloat())
+ val trainWithWeight = train.zipWithIndex.map { case (v, index) =>
+ XGBLabeledPoint(v.label, v.size, v.indices, v.values,
+ randomWeights(index), v.group, v.baseMargin)
+ }
- private def getGroups(resource: String): Seq[Int] = {
- getResourceLines(resource).map(_.toInt).toList
+ object Ranking extends TrainTestData {
+ val RANK_COL_NUM = 3
+ val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
+ // use the group as the weight
+ val trainWithWeight = train.map { labelPoint =>
+ XGBLabeledPoint(labelPoint.label, labelPoint.size, labelPoint.indices, labelPoint.values,
+ labelPoint.group, labelPoint.group, labelPoint.baseMargin)
+ }
+ val trainGroups = train.map(_.group)
+ val test: Seq[XGBLabeledPoint] = getLabeledPoints(
+ "/rank.test.txt", RANK_COL_NUM, zeroBased = false)
}
-}
-object Synthetic extends {
- val TRAIN_COL_NUM = 3
- val TRAIN_WRONG_COL_NUM = 2
- val train: Seq[XGBLabeledPoint] = Seq(
- XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
- XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
- XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
- XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f))
- )
-
- val trainWithDiffFeatureSize: Seq[XGBLabeledPoint] = Seq(
- XGBLabeledPoint(1.0f, TRAIN_WRONG_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
- XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))
- )
}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index 48e7dae52b2e..dcd22009514e 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -16,465 +16,286 @@
package ml.dmlc.xgboost4j.scala.spark
-import java.io.{File, FileInputStream}
+import java.io.File
-import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
-
-import org.apache.spark.ml.linalg._
-import org.apache.spark.sql._
+import org.apache.spark.ml.linalg.DenseVector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.DataFrame
import org.scalatest.funsuite.AnyFunSuite
-import org.apache.commons.io.IOUtils
-import org.apache.spark.Partitioner
-import org.apache.spark.ml.feature.VectorAssembler
-import org.json4s.{DefaultFormats, Formats}
-import org.json4s.jackson.parseJson
+import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.{BINARY_CLASSIFICATION_OBJS, MULTICLASSIFICATION_OBJS}
+import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
- protected val treeMethod: String = "auto"
+ test("XGBoostClassifier copy") {
+ val classifier = new XGBoostClassifier().setNthread(2).setNumWorkers(10)
+ val classifierCopied = classifier.copy(ParamMap.empty)
- test("Set params in XGBoost and MLlib way should produce same model") {
- val trainingDF = buildDataFrame(Classification.train)
- val testDF = buildDataFrame(Classification.test)
- val round = 5
+ assert(classifier.uid === classifierCopied.uid)
+ assert(classifier.getNthread === classifierCopied.getNthread)
+ assert(classifier.getNumWorkers === classifierCopied.getNumWorkers)
+ }
- val paramMap = Map(
- "eta" -> "1",
- "max_depth" -> "6",
- "silent" -> "1",
- "objective" -> "binary:logistic",
- "num_round" -> round,
- "tree_method" -> treeMethod,
- "num_workers" -> numWorkers)
-
- // Set params in XGBoost way
- val model1 = new XGBoostClassifier(paramMap).fit(trainingDF)
- // Set params in MLlib way
- val model2 = new XGBoostClassifier()
- .setEta(1)
- .setMaxDepth(6)
- .setSilent(1)
- .setObjective("binary:logistic")
- .setNumRound(round)
- .setNumWorkers(numWorkers)
- .fit(trainingDF)
+ test("XGBoostClassification copy") {
+ val model = new XGBoostClassificationModel("hello").setNthread(2).setNumWorkers(10)
+ val modelCopied = model.copy(ParamMap.empty)
+ assert(model.uid === modelCopied.uid)
+ assert(model.getNthread === modelCopied.getNthread)
+ assert(model.getNumWorkers === modelCopied.getNumWorkers)
+ }
- val prediction1 = model1.transform(testDF).select("prediction").collect()
- val prediction2 = model2.transform(testDF).select("prediction").collect()
+ test("read/write") {
+ val trainDf = smallBinaryClassificationVector
+ val xgbParams: Map[String, Any] = Map(
+ "max_depth" -> 5,
+ "eta" -> 0.2,
+ "objective" -> "binary:logistic"
+ )
- prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
- assert(p1 === p2)
+ def check(xgboostParams: XGBoostParams[_]): Unit = {
+ assert(xgboostParams.getMaxDepth === 5)
+ assert(xgboostParams.getEta === 0.2)
+ assert(xgboostParams.getObjective === "binary:logistic")
}
- }
- test("test schema of XGBoostClassificationModel") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
- val trainingDF = buildDataFrame(Classification.train)
- val testDF = buildDataFrame(Classification.test)
+ val classifierPath = new File(tempDir.toFile, "classifier").getPath
+ val classifier = new XGBoostClassifier(xgbParams).setNumRound(2)
+ check(classifier)
- val model = new XGBoostClassifier(paramMap).fit(trainingDF)
-
- model.setRawPredictionCol("raw_prediction")
- .setProbabilityCol("probability_prediction")
- .setPredictionCol("final_prediction")
- var predictionDF = model.transform(testDF)
- assert(predictionDF.columns.contains("id"))
- assert(predictionDF.columns.contains("features"))
- assert(predictionDF.columns.contains("label"))
- assert(predictionDF.columns.contains("raw_prediction"))
- assert(predictionDF.columns.contains("probability_prediction"))
- assert(predictionDF.columns.contains("final_prediction"))
- model.setRawPredictionCol("").setPredictionCol("final_prediction")
- predictionDF = model.transform(testDF)
- assert(predictionDF.columns.contains("raw_prediction") === false)
- assert(predictionDF.columns.contains("final_prediction"))
- model.setRawPredictionCol("raw_prediction").setPredictionCol("")
- predictionDF = model.transform(testDF)
- assert(predictionDF.columns.contains("raw_prediction"))
- assert(predictionDF.columns.contains("final_prediction") === false)
-
- assert(model.summary.trainObjectiveHistory.length === 5)
- assert(model.summary.validationObjectiveHistory.isEmpty)
- }
+ classifier.write.overwrite().save(classifierPath)
+ val loadedClassifier = XGBoostClassifier.load(classifierPath)
+ check(loadedClassifier)
- test("multi class classification") {
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "multi:softmax", "num_class" -> "6", "num_round" -> 5,
- "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val trainingDF = buildDataFrame(MultiClassification.train)
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(trainingDF)
- assert(model.getEta == 0.1)
- assert(model.getMaxDepth == 6)
- assert(model.numClasses == 6)
- val transformedDf = model.transform(trainingDF)
- assert(!transformedDf.columns.contains("probability"))
- }
+ val model = loadedClassifier.fit(trainDf)
+ check(model)
+ assert(model.numClasses === 2)
- test("objective will be set if not specifying it") {
- val training = buildDataFrame(Classification.train)
- val paramMap = Map("eta" -> "1", "max_depth" -> "6",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val xgb = new XGBoostClassifier(paramMap)
- assert(!xgb.isDefined(xgb.objective))
- xgb.fit(training)
- assert(xgb.getObjective == "binary:logistic")
-
- val trainingDF = buildDataFrame(MultiClassification.train)
- val paramMap1 = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "num_class" -> "6", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
- val xgb1 = new XGBoostClassifier(paramMap1)
- assert(!xgb1.isDefined(xgb1.objective))
- xgb1.fit(trainingDF)
- assert(xgb1.getObjective == "multi:softprob")
-
- // shouldn't change user's objective setting
- val paramMap2 = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "num_class" -> "6", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod, "objective" -> "multi:softmax")
- val xgb2 = new XGBoostClassifier(paramMap2)
- assert(xgb2.getObjective == "multi:softmax")
- xgb2.fit(trainingDF)
- assert(xgb2.getObjective == "multi:softmax")
+ val modelPath = new File(tempDir.toFile, "model").getPath
+ model.write.overwrite().save(modelPath)
+ val modelLoaded = XGBoostClassificationModel.load(modelPath)
+ assert(modelLoaded.numClasses === 2)
+ check(modelLoaded)
}
- test("use base margin") {
- val training1 = buildDataFrame(Classification.train)
- val training2 = training1.withColumn("margin", functions.rand())
- val test = buildDataFrame(Classification.test)
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "train_test_ratio" -> "1.0",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
-
- val xgb = new XGBoostClassifier(paramMap)
- val model1 = xgb.fit(training1)
- val model2 = xgb.setBaseMarginCol("margin").fit(training2)
- val prediction1 = model1.transform(test).select(model1.getProbabilityCol)
- .collect().map(row => row.getAs[Vector](0))
- val prediction2 = model2.transform(test).select(model2.getProbabilityCol)
- .collect().map(row => row.getAs[Vector](0))
- var count = 0
- for ((r1, r2) <- prediction1.zip(prediction2)) {
- if (!r1.equals(r2)) count = count + 1
+ test("XGBoostClassificationModel transformed schema") {
+ val trainDf = smallBinaryClassificationVector
+ val classifier = new XGBoostClassifier().setNumRound(1)
+ val model = classifier.fit(trainDf)
+ var out = model.transform(trainDf)
+
+ // Transform should not discard the other columns of the input dataframe
+ Seq("label", "margin", "weight", "features").foreach { v =>
+ assert(out.schema.names.contains(v))
}
- assert(count != 0)
- }
- test("test predictionLeaf") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val training = buildDataFrame(Classification.train)
- val test = buildDataFrame(Classification.test)
- val groundTruth = test.count()
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(training)
- model.setLeafPredictionCol("predictLeaf")
- val resultDF = model.transform(test)
- assert(resultDF.count == groundTruth)
- assert(resultDF.columns.contains("predictLeaf"))
- }
+ // Transform needs to add extra columns
+ Seq("rawPrediction", "probability", "prediction").foreach { v =>
+ assert(out.schema.names.contains(v))
+ }
+
+ assert(out.schema.names.length === 7)
+
+ model.setRawPredictionCol("").setProbabilityCol("")
+ out = model.transform(trainDf)
- test("test predictionLeaf with empty column name") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val training = buildDataFrame(Classification.train)
- val test = buildDataFrame(Classification.test)
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(training)
- model.setLeafPredictionCol("")
- val resultDF = model.transform(test)
- assert(!resultDF.columns.contains("predictLeaf"))
+ // rawPrediction="", probability=""
+ Seq("rawPrediction", "probability").foreach { v =>
+ assert(!out.schema.names.contains(v))
+ }
+
+ assert(out.schema.names.contains("prediction"))
+
+ model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
+ out = model.transform(trainDf)
+
+ assert(out.schema.names.contains("leaf"))
+ assert(out.schema.names.contains("contrib"))
+
+ val out1 = classifier.setLeafPredictionCol("leaf1")
+ .setContribPredictionCol("contrib1")
+ .train(trainDf).transform(trainDf)
+
+ assert(out1.schema.names.contains("leaf1"))
+ assert(out1.schema.names.contains("contrib1"))
}
- test("test predictionContrib") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val training = buildDataFrame(Classification.train)
- val test = buildDataFrame(Classification.test)
- val groundTruth = test.count()
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(training)
- model.setContribPredictionCol("predictContrib")
- val resultDF = model.transform(buildDataFrame(Classification.test))
- assert(resultDF.count == groundTruth)
- assert(resultDF.columns.contains("predictContrib"))
+ test("Supported objectives") {
+ val classifier = new XGBoostClassifier()
+ val df = smallMultiClassificationVector
+ (BINARY_CLASSIFICATION_OBJS.toSeq ++ MULTICLASSIFICATION_OBJS.toSeq).foreach { obj =>
+ classifier.setObjective(obj)
+ classifier.validate(df)
+ }
+
+ classifier.setObjective("reg:squaredlogerror")
+ intercept[IllegalArgumentException](
+ classifier.validate(df)
+ )
}
- test("test predictionContrib with empty column name") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val training = buildDataFrame(Classification.train)
- val test = buildDataFrame(Classification.test)
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(training)
- model.setContribPredictionCol("")
- val resultDF = model.transform(test)
- assert(!resultDF.columns.contains("predictContrib"))
+ test("Binaryclassification infer objective and num_class") {
+ val trainDf = smallBinaryClassificationVector
+ var classifier = new XGBoostClassifier()
+ assert(classifier.getObjective === "reg:squarederror")
+ assert(classifier.getNumClass === 0)
+ classifier.validate(trainDf)
+ assert(classifier.getObjective === "binary:logistic")
+ assert(!classifier.isSet(classifier.numClass))
+
+ // Explicitly setting num_class on binary data should fail validation
+ classifier = new XGBoostClassifier()
+ classifier.setNumClass(2)
+ intercept[IllegalArgumentException](
+ classifier.validate(trainDf)
+ )
+
+ // An explicit binary objective is kept and num_class stays unset
+ classifier = new XGBoostClassifier()
+ classifier.setObjective("binary:logistic")
+ classifier.validate(trainDf)
+ assert(classifier.getObjective === "binary:logistic")
+ assert(!classifier.isSet(classifier.numClass))
}
- test("test predictionLeaf and predictionContrib") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val training = buildDataFrame(Classification.train)
- val test = buildDataFrame(Classification.test)
- val groundTruth = test.count()
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(training)
- model.setLeafPredictionCol("predictLeaf")
- model.setContribPredictionCol("predictContrib")
- val resultDF = model.transform(buildDataFrame(Classification.test))
- assert(resultDF.count == groundTruth)
- assert(resultDF.columns.contains("predictLeaf"))
- assert(resultDF.columns.contains("predictContrib"))
+ test("MultiClassification infer objective and num_class") {
+ val trainDf = smallMultiClassificationVector
+ var classifier = new XGBoostClassifier()
+ assert(classifier.getObjective === "reg:squarederror")
+ assert(classifier.getNumClass === 0)
+ classifier.validate(trainDf)
+ assert(classifier.getObjective === "multi:softprob")
+ assert(classifier.getNumClass === 3)
+
+ // Infer the objective from num_class
+ classifier = new XGBoostClassifier()
+ classifier.setNumClass(3)
+ classifier.validate(trainDf)
+ assert(classifier.getObjective === "multi:softprob")
+ assert(classifier.getNumClass === 3)
+
+ // Infer num_class from the objective
+ classifier = new XGBoostClassifier()
+ classifier.setObjective("multi:softmax")
+ classifier.validate(trainDf)
+ assert(classifier.getObjective === "multi:softmax")
+ assert(classifier.getNumClass === 3)
}
- test("XGBoost-Spark XGBoostClassifier output should match XGBoost4j") {
+ test("XGBoost-Spark binary classification output should match XGBoost4j") {
val trainingDM = new DMatrix(Classification.train.iterator)
val testDM = new DMatrix(Classification.test.iterator)
val trainingDF = buildDataFrame(Classification.train)
val testDF = buildDataFrame(Classification.test)
- checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+ val paramMap = Map("objective" -> "binary:logistic")
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
}
- test("XGBoostClassifier should make correct predictions after upstream random sort") {
- val trainingDM = new DMatrix(Classification.train.iterator)
+ test("XGBoost-Spark binary classification output with weight should match XGBoost4j") {
+ val trainingDM = new DMatrix(Classification.trainWithWeight.iterator)
+ trainingDM.setWeight(Classification.randomWeights)
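+ // mirror the DataFrame's "weight" column on the native DMatrix so both sides train on the same weights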
val testDM = new DMatrix(Classification.test.iterator)
- val trainingDF = buildDataFrameWithRandSort(Classification.train)
- val testDF = buildDataFrameWithRandSort(Classification.test)
- checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+ val trainingDF = buildDataFrame(Classification.trainWithWeight)
+ val testDF = buildDataFrame(Classification.test)
+ val paramMap = Map("objective" -> "binary:logistic")
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
+ 5, paramMap, Some("weight"))
+ }
+
+ Seq("multi:softprob", "multi:softmax").foreach { objective =>
+ test(s"XGBoost-Spark multi classification with $objective output should match XGBoost4j") {
+ val trainingDM = new DMatrix(MultiClassification.train.iterator)
+ val testDM = new DMatrix(MultiClassification.test.iterator)
+ val trainingDF = buildDataFrame(MultiClassification.train)
+ val testDF = buildDataFrame(MultiClassification.test)
+ val paramMap = Map("objective" -> "multi:softprob", "num_class" -> 6)
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
+ }
+ }
+
+ test("XGBoost-Spark multi classification output with weight should match XGBoost4j") {
+ val trainingDM = new DMatrix(MultiClassification.trainWithWeight.iterator)
+ trainingDM.setWeight(MultiClassification.randomWeights)
+ val testDM = new DMatrix(MultiClassification.test.iterator)
+ val trainingDF = buildDataFrame(MultiClassification.trainWithWeight)
+ val testDF = buildDataFrame(MultiClassification.test)
+ val paramMap = Map("objective" -> "multi:softprob", "num_class" -> 6)
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap, Some("weight"))
}
private def checkResultsWithXGBoost4j(
- trainingDM: DMatrix,
- testDM: DMatrix,
- trainingDF: DataFrame,
- testDF: DataFrame,
- round: Int = 5): Unit = {
+ trainingDM: DMatrix,
+ testDM: DMatrix,
+ trainingDF: DataFrame,
+ testDF: DataFrame,
+ round: Int = 5,
+ xgbParams: Map[String, Any] = Map.empty,
+ weightCol: Option[String] = None): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
- "silent" -> "1",
"base_score" -> 0.5,
- "objective" -> "binary:logistic",
- "tree_method" -> treeMethod,
- "max_bin" -> 16)
- val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
- val prediction1 = model1.predict(testDM)
-
- val model2 = new XGBoostClassifier(paramMap ++ Array("num_round" -> round,
- "num_workers" -> numWorkers)).fit(trainingDF)
-
- val prediction2 = model2.transform(testDF).
- collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("probability"))).toMap
-
- assert(testDF.count() === prediction2.size)
- // the vector length in probability column is 2 since we have to fit to the evaluator in Spark
- for (i <- prediction1.indices) {
- assert(prediction1(i).length === prediction2(i).values.length - 1)
- for (j <- prediction1(i).indices) {
- assert(prediction1(i)(j) === prediction2(i)(j + 1))
- }
- }
-
- val prediction3 = model1.predict(testDM, outPutMargin = true)
- val prediction4 = model2.transform(testDF).
- collect().map(row => (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction"))).toMap
+ "max_bin" -> 16) ++ xgbParams
+ val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
- assert(testDF.count() === prediction4.size)
- // the vector length in rawPrediction column is 2 since we have to fit to the evaluator in Spark
- for (i <- prediction3.indices) {
- assert(prediction3(i).length === prediction4(i).values.length - 1)
- for (j <- prediction3(i).indices) {
- assert(prediction3(i)(j) === prediction4(i)(j + 1))
+ val classifier = new XGBoostClassifier(paramMap)
+ .setNumRound(round)
+ .setNumWorkers(numWorkers)
+ .setLeafPredictionCol("leaf")
+ .setContribPredictionCol("contrib")
+ weightCol.foreach(weight => classifier.setWeightCol(weight))
+
+ def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
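+ // right is keyed by the "id" column, so Spark's unordered output rows can be
+ // matched against XGBoost4j's positional predictions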
+ assert(left.size === right.size)
+ left.zipWithIndex.foreach { case (leftValue, index) =>
+ assert(leftValue.sameElements(right(index)))
}
}
- // check the equality of single instance prediction
- val firstOfDM = testDM.slice(Array(0))
- val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
- .head()
- .getAs[Vector]("features")
- val prediction5 = math.round(model1.predict(firstOfDM)(0)(0))
- val prediction6 = model2.predict(firstOfDF)
- assert(prediction5 === prediction6)
- }
-
- test("infrequent features") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic",
- "num_round" -> 5, "num_workers" -> 2, "missing" -> 0)
- import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
- val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
- new Partitioner {
- override def numPartitions: Int = 2
-
- override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
- }
- ).map(_._2).zipWithIndex().map {
- case (lp, id) =>
- (id, lp.label, lp.features)
- }.toDF("id", "label", "features")
- val xgb = new XGBoostClassifier(paramMap)
- xgb.fit(repartitioned)
- }
-
- test("infrequent features (use_external_memory)") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic",
- "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
- import ml.dmlc.xgboost4j.scala.spark.util.DataUtils._
- val sparkSession = SparkSession.builder().getOrCreate()
- import sparkSession.implicits._
- val repartitioned = sc.parallelize(Synthetic.train, 3).map(lp => (lp.label, lp)).partitionBy(
- new Partitioner {
- override def numPartitions: Int = 2
-
- override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
+ val xgbSparkModel = classifier.fit(trainingDF)
+ val rows = xgbSparkModel.transform(testDF).collect()
+
+ // Check Leaf
+ val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
+ val xgbSparkLeaf = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
+ checkEqual(xgb4jLeaf, xgbSparkLeaf)
+
+ // Check contrib
+ val xgb4jContrib = xgb4jModel.predictContrib(testDM)
+ val xgbSparkContrib = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
+ checkEqual(xgb4jContrib, xgbSparkContrib)
+
+ def checkEqualForBinary(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
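+ // for binary objectives Spark emits a two-element vector (negative and positive
+ // class) while XGBoost4j returns only the positive-class score, hence the offset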
+ assert(left.size === right.size)
+ left.zipWithIndex.foreach { case (leftValue, index) =>
+ assert(leftValue.length === 1)
+ assert(leftValue.length === right(index).length - 1)
+ assert(leftValue(0) === right(index)(1))
}
- ).map(_._2).zipWithIndex().map {
- case (lp, id) =>
- (id, lp.label, lp.features)
- }.toDF("id", "label", "features")
- val xgb = new XGBoostClassifier(paramMap)
- xgb.fit(repartitioned)
- }
-
- test("featuresCols with features column can work") {
- val spark = ss
- import spark.implicits._
- val xgbInput = Seq(
- (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
- (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
- .toDF("f1", "f2", "f3", "features", "label")
-
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
-
- val featuresName = Array("f1", "f2", "f3", "features")
- val xgbClassifier = new XGBoostClassifier(paramMap)
- .setFeaturesCol(featuresName)
- .setLabelCol("label")
-
- val model = xgbClassifier.fit(xgbInput)
- assert(model.getFeaturesCols.sameElements(featuresName))
-
- val df = model.transform(xgbInput)
- assert(df.schema.fieldNames.contains("features_" + model.uid))
- df.show()
-
- val newFeatureName = "features_new"
- // transform also can work for vectorized dataset
- val vectorizedInput = new VectorAssembler()
- .setInputCols(featuresName)
- .setOutputCol(newFeatureName)
- .transform(xgbInput)
- .select(newFeatureName, "label")
-
- val df1 = model
- .setFeaturesCol(newFeatureName)
- .transform(vectorizedInput)
- assert(df1.schema.fieldNames.contains(newFeatureName))
- df1.show()
- }
+ }
- test("featuresCols without features column can work") {
- val spark = ss
- import spark.implicits._
- val xgbInput = Seq(
- (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
- (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
- .toDF("f1", "f2", "f3", "f4", "label")
-
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> 1)
-
- val featuresName = Array("f1", "f2", "f3", "f4")
- val xgbClassifier = new XGBoostClassifier(paramMap)
- .setFeaturesCol(featuresName)
- .setLabelCol("label")
- .setEvalSets(Map("eval" -> xgbInput))
-
- val model = xgbClassifier.fit(xgbInput)
- assert(model.getFeaturesCols.sameElements(featuresName))
-
- // transform should work for the dataset which includes the feature column names.
- val df = model.transform(xgbInput)
- assert(df.schema.fieldNames.contains("features"))
- df.show()
-
- // transform also can work for vectorized dataset
- val vectorizedInput = new VectorAssembler()
- .setInputCols(featuresName)
- .setOutputCol("features")
- .transform(xgbInput)
- .select("features", "label")
-
- val df1 = model.transform(vectorizedInput)
- df1.show()
- }
+ // Check probability
+ val xgb4jProb = xgb4jModel.predict(testDM)
+ val xgbSparkProb = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("probability").toArray.map(_.toFloat))).toMap
+ if (BINARY_CLASSIFICATION_OBJS.contains(classifier.getObjective)) {
+ checkEqualForBinary(xgb4jProb, xgbSparkProb)
+ } else {
+ checkEqual(xgb4jProb, xgbSparkProb)
+ }
- test("XGBoostClassificationModel should be compatible") {
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
- "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val trainingDF = buildDataFrame(MultiClassification.train)
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(trainingDF)
-
- // test json
- val modelPath = new File(tempDir.toFile, "xgbc").getPath
- model.write.option("format", "json").save(modelPath)
- val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
- model.nativeBooster.saveModel(nativeJsonModelPath)
- assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
- nativeJsonModelPath))
-
- // test ubj
- val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
- model.write.save(modelUbjPath)
- val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
- model.nativeBooster.saveModel(nativeUbjModelPath)
- assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
- nativeUbjModelPath))
-
- // json file should be indifferent with ubj file
- val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
- model.write.option("format", "json").save(modelJsonPath)
- val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
- model.nativeBooster.saveModel(nativeUbjModelPath1)
- assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
- nativeUbjModelPath1))
+ // Check rawPrediction
+ val xgb4jRawPred = xgb4jModel.predict(testDM, outPutMargin = true)
+ val xgbSparkRawPred = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("rawPrediction").toArray.map(_.toFloat))).toMap
+ if (BINARY_CLASSIFICATION_OBJS.contains(classifier.getObjective)) {
+ checkEqualForBinary(xgb4jRawPred, xgbSparkRawPred)
+ } else {
+ checkEqual(xgb4jRawPred, xgbSparkRawPred)
+ }
}
- test("native json model file should store feature_name and feature_type") {
- val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
- val featureTypes = (1 to 33).map(idx => "q").toArray
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
- "num_workers" -> numWorkers, "tree_method" -> treeMethod
- )
- val trainingDF = buildDataFrame(MultiClassification.train)
- val xgb = new XGBoostClassifier(paramMap)
- .setFeatureNames(featureNames)
- .setFeatureTypes(featureTypes)
- val model = xgb.fit(trainingDF)
- val modelStr = new String(model._booster.toByteArray("json"))
- val jsonModel = parseJson(modelStr)
- implicit val formats: Formats = DefaultFormats
- val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
- val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
- assert(featureNamesInModel.length == 33)
- assert(featureTypesInModel.length == 33)
- }
}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala
deleted file mode 100644
index 136d39e8bc0f..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostCommunicatorRegressionSuite.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.java.Communicator
-import ml.dmlc.xgboost4j.scala.Booster
-import scala.collection.JavaConverters._
-
-import org.apache.spark.sql._
-import org.scalatest.funsuite.AnyFunSuite
-
-import org.apache.spark.SparkException
-
-class XGBoostCommunicatorRegressionSuite extends AnyFunSuite with PerTest {
- val predictionErrorMin = 0.00001f
- val maxFailure = 2;
-
- override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
- .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .config("spark.kryo.classesToRegister", classOf[Booster].getName)
- .master(s"local[${numWorkers},${maxFailure}]")
-
- test("test classification prediction parity w/o ring reduce") {
- val training = buildDataFrame(Classification.train)
- val testDF = buildDataFrame(Classification.test)
-
- val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
-
- val model1 = new XGBoostClassifier(xgbSettings).fit(training)
- val prediction1 = model1.transform(testDF).select("prediction").collect()
-
- val model2 = new XGBoostClassifier(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1))
- .fit(training)
-
- val prediction2 = model2.transform(testDF).select("prediction").collect()
- // check parity w/o rabit cache
- prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
- assert(p1 == p2)
- }
- }
-
- test("test regression prediction parity w/o ring reduce") {
- val training = buildDataFrame(Regression.train)
- val testDF = buildDataFrame(Regression.test)
- val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers)
- val model1 = new XGBoostRegressor(xgbSettings).fit(training)
-
- val prediction1 = model1.transform(testDF).select("prediction").collect()
-
- val model2 = new XGBoostRegressor(xgbSettings ++ Map("rabit_ring_reduce_threshold" -> 1)
- ).fit(training)
- // check the equality of single instance prediction
- val prediction2 = model2.transform(testDF).select("prediction").collect()
- // check parity w/o rabit cache
- prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
- assert(math.abs(p1 - p2) < predictionErrorMin)
- }
- }
-}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
deleted file mode 100644
index 086fda2d7a1f..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostConfigureSuite.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import ml.dmlc.xgboost4j.scala.{Booster, DMatrix}
-
-import org.apache.spark.sql._
-import org.scalatest.funsuite.AnyFunSuite
-
-class XGBoostConfigureSuite extends AnyFunSuite with PerTest {
-
- override def sparkSessionBuilder: SparkSession.Builder = super.sparkSessionBuilder
- .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .config("spark.kryo.classesToRegister", classOf[Booster].getName)
-
- test("nthread configuration must be no larger than spark.task.cpus") {
- val training = buildDataFrame(Classification.train)
- val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
- "objective" -> "binary:logistic", "num_workers" -> numWorkers,
- "nthread" -> (sc.getConf.getInt("spark.task.cpus", 1) + 1))
- intercept[IllegalArgumentException] {
- new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training)
- }
- }
-
- test("kryoSerializer test") {
- // TODO write an isolated test for Booster.
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator, null)
- val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers)
-
- val model = new XGBoostClassifier(paramMap).fit(training)
- val eval = new EvalError()
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("Check for Spark encryption over-the-wire") {
- val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
- ss.conf.set("spark.ssl.enabled", true)
-
- val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
- "objective" -> "binary:logistic", "num_round" -> 2, "num_workers" -> numWorkers)
- val training = buildDataFrame(Classification.train)
-
- withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
- "xgboost.spark.ignoreSsl != true") {
- val thrown = intercept[Exception] {
- new XGBoostClassifier(paramMap).fit(training)
- }
- assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
- thrown.getMessage.contains("spark.ssl.enabled"))
- }
-
- // Confirm that this check can be overridden.
- ss.conf.set("xgboost.spark.ignoreSsl", true)
- new XGBoostClassifier(paramMap).fit(training)
-
- originalSslConfOpt match {
- case None =>
- ss.conf.unset("spark.ssl.enabled")
- case Some(originalSslConf) =>
- ss.conf.set("spark.ssl.enabled", originalSslConf)
- }
- ss.conf.unset("xgboost.spark.ignoreSsl")
- }
-}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
new file mode 100644
index 000000000000..614e93c8e8cf
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala
@@ -0,0 +1,453 @@
+/*
+ Copyright (c) 2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+import java.util.Arrays
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.ml.linalg.Vectors
+import org.json4s.{DefaultFormats, Formats}
+import org.json4s.jackson.parseJson
+import org.scalatest.funsuite.AnyFunSuite
+
+import ml.dmlc.xgboost4j.scala.DMatrix
+import ml.dmlc.xgboost4j.scala.spark.Utils.TRAIN_NAME
+
+class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
+
+ test("params") {
+ val df = smallBinaryClassificationVector
+ val xgbParams: Map[String, Any] = Map(
+ "max_depth" -> 5,
+ "eta" -> 0.2,
+ "objective" -> "binary:logistic"
+ )
+ val estimator = new XGBoostClassifier(xgbParams)
+ .setFeaturesCol("features")
+ .setMissing(0.2f)
+ .setAlpha(0.97)
+ .setLeafPredictionCol("leaf")
+ .setContribPredictionCol("contrib")
+ .setNumRound(1)
+
+ assert(estimator.getMaxDepth === 5)
+ assert(estimator.getEta === 0.2)
+ assert(estimator.getObjective === "binary:logistic")
+ assert(estimator.getFeaturesCol === "features")
+ assert(estimator.getMissing === 0.2f)
+ assert(estimator.getAlpha === 0.97)
+
+ estimator.setEta(0.66).setMaxDepth(7)
+ assert(estimator.getMaxDepth === 7)
+ assert(estimator.getEta === 0.66)
+
+ val model = estimator.train(df)
+ assert(model.getMaxDepth === 7)
+ assert(model.getEta === 0.66)
+ assert(model.getObjective === "binary:logistic")
+ assert(model.getFeaturesCol === "features")
+ assert(model.getMissing === 0.2f)
+ assert(model.getAlpha === 0.97)
+ assert(model.getLeafPredictionCol === "leaf")
+ assert(model.getContribPredictionCol === "contrib")
+ }
+
+ test("nthread") {
+ val classifier = new XGBoostClassifier().setNthread(100)
+
+ intercept[IllegalArgumentException](
+ classifier.validate(smallBinaryClassificationVector)
+ )
+ }
+
+ test("RuntimeParameter") {
+ var runtimeParams = new XGBoostClassifier(
+ Map("device" -> "cpu"))
+ .getRuntimeParameters(true)
+ assert(!runtimeParams.runOnGpu)
+
+ runtimeParams = new XGBoostClassifier(
+ Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
+ .getRuntimeParameters(true)
+ assert(runtimeParams.runOnGpu)
+
+ runtimeParams = new XGBoostClassifier(
+ Map("device" -> "cpu", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
+ .getRuntimeParameters(true)
+ assert(runtimeParams.runOnGpu)
+
+ runtimeParams = new XGBoostClassifier(
+ Map("device" -> "cuda", "tree_method" -> "gpu_hist")).setNumWorkers(1).setNumRound(1)
+ .getRuntimeParameters(true)
+ assert(runtimeParams.runOnGpu)
+ }
+
+ test("test persistence of XGBoostClassifier and XGBoostClassificationModel " +
+ "using custom Eval and Obj") {
+ val trainingDF = buildDataFrame(Classification.train)
+ val testDM = new DMatrix(Classification.test.iterator)
+
+ val paramMap = Map("eta" -> "0.1", "max_depth" -> "6",
+ "verbosity" -> "1", "objective" -> "binary:logistic")
+
+ val xgbc = new XGBoostClassifier(paramMap)
+ .setCustomObj(new CustomObj(1))
+ .setCustomEval(new EvalError)
+ .setNumRound(10)
+ .setNumWorkers(numWorkers)
+
+ val xgbcPath = new File(tempDir.toFile, "xgbc").getPath
+ xgbc.write.overwrite().save(xgbcPath)
+ val xgbc2 = XGBoostClassifier.load(xgbcPath)
+
+ assert(xgbc.getCustomObj.asInstanceOf[CustomObj].customParameter === 1)
+ assert(xgbc2.getCustomObj.asInstanceOf[CustomObj].customParameter === 1)
+
+ val eval = new EvalError()
+
+ val model = xgbc.fit(trainingDF)
+ val evalResults = eval.eval(model.nativeBooster.predict(testDM, outPutMargin = true), testDM)
+ assert(evalResults < 0.1)
+ val xgbcModelPath = new File(tempDir.toFile, "xgbcModel").getPath
+    model.write.overwrite().save(xgbcModelPath)
+ val model2 = XGBoostClassificationModel.load(xgbcModelPath)
+ assert(Arrays.equals(model.nativeBooster.toByteArray, model2.nativeBooster.toByteArray))
+
+ assert(model.getEta === model2.getEta)
+ assert(model.getNumRound === model2.getNumRound)
+ assert(model.getRawPredictionCol === model2.getRawPredictionCol)
+ val evalResults2 = eval.eval(model2.nativeBooster.predict(testDM, outPutMargin = true), testDM)
+ assert(evalResults === evalResults2)
+ }
+
+ test("Check for Spark encryption over-the-wire") {
+ val originalSslConfOpt = ss.conf.getOption("spark.ssl.enabled")
+ ss.conf.set("spark.ssl.enabled", true)
+
+ val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
+ "objective" -> "binary:logistic")
+ val training = smallBinaryClassificationVector
+
+ withClue("xgboost-spark should throw an exception when spark.ssl.enabled = true but " +
+ "xgboost.spark.ignoreSsl != true") {
+ val thrown = intercept[Exception] {
+ new XGBoostClassifier(paramMap).setNumRound(2).setNumWorkers(numWorkers).fit(training)
+ }
+ assert(thrown.getMessage.contains("xgboost.spark.ignoreSsl") &&
+ thrown.getMessage.contains("spark.ssl.enabled"))
+ }
+
+ // Confirm that this check can be overridden.
+ ss.conf.set("xgboost.spark.ignoreSsl", true)
+ new XGBoostClassifier(paramMap).setNumRound(2).setNumWorkers(numWorkers).fit(training)
+
+ originalSslConfOpt match {
+ case None =>
+ ss.conf.unset("spark.ssl.enabled")
+ case Some(originalSslConf) =>
+ ss.conf.set("spark.ssl.enabled", originalSslConf)
+ }
+ ss.conf.unset("xgboost.spark.ignoreSsl")
+ }
+
+ test("nthread configuration must be no larger than spark.task.cpus") {
+ val training = smallBinaryClassificationVector
+ val paramMap = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1",
+ "objective" -> "binary:logistic")
+ intercept[IllegalArgumentException] {
+ new XGBoostClassifier(paramMap)
+ .setNumWorkers(numWorkers)
+ .setNumRound(2)
+ .setNthread(sc.getConf.getInt("spark.task.cpus", 1) + 1)
+ .fit(training)
+ }
+ }
+
+ test("preprocess dataset") {
+ val dataset = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0, 0.5, 1.0, Vectors.dense(1.0, 2.0, 3.0), "a"),
+ (0.0, 2, -0.5, 0.0, Vectors.dense(0.2, 1.2, 2.0), "b"),
+ (2.0, 2, -0.4, -2.1, Vectors.dense(0.5, 2.2, 1.7), "c")
+ ))).toDF("label", "group", "margin", "weight", "features", "other")
+
+ val classifier = new XGBoostClassifier()
+ .setLabelCol("label")
+ .setFeaturesCol("features")
+ .setBaseMarginCol("margin")
+ .setWeightCol("weight")
+
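+    // preprocess should keep only the columns referenced by the estimator and
+    // record the schema index of each one.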
+ val (df, indices) = classifier.preprocess(dataset)
+ var schema = df.schema
+ assert(!schema.names.contains("group") && !schema.names.contains("other"))
+ assert(indices.labelId == schema.fieldIndex("label") &&
+ indices.groupId.isEmpty &&
+ indices.marginId.get == schema.fieldIndex("margin") &&
+ indices.weightId.get == schema.fieldIndex("weight") &&
+ indices.featureId.get == schema.fieldIndex("features") &&
+ indices.featureIds.isEmpty)
+
+ classifier.setWeightCol("")
+ val (df1, indices1) = classifier.preprocess(dataset)
+ schema = df1.schema
+ Seq("weight", "group", "other").foreach(v => assert(!schema.names.contains(v)))
+ assert(indices1.labelId == schema.fieldIndex("label") &&
+ indices1.groupId.isEmpty &&
+ indices1.marginId.get == schema.fieldIndex("margin") &&
+ indices1.weightId.isEmpty &&
+ indices1.featureId.get == schema.fieldIndex("features") &&
+ indices1.featureIds.isEmpty)
+ }
+
+ test("to XGBoostLabeledPoint RDD") {
+ val data = Array(
+ Array(1.0, 2.0, 3.0, 4.0, 5.0),
+ Array(0.0, 0.0, 0.0, 0.0, 2.0),
+ Array(12.0, 13.0, 14.0, 14.0, 15.0),
+ Array(20.5, 21.2, 0.0, 0.0, 2.0)
+ )
+ val dataset = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
+ (2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
+ (3.0, 2, -0.5, 0.0, Vectors.dense(data(2)), "b"),
+ (4.0, 2, -0.4, -2.1, Vectors.dense(data(3)), "c")
+ ))).toDF("label", "group", "margin", "weight", "features", "other")
+
+ val classifier = new XGBoostClassifier()
+ .setLabelCol("label")
+ .setFeaturesCol("features")
+ .setWeightCol("weight")
+ .setNumWorkers(2)
+
+ val (df, indices) = classifier.preprocess(dataset)
+ val rdd = classifier.toXGBLabeledPoint(df, indices)
+ val result = rdd.collect().sortBy(x => x.label)
+
+ assert(result.length == data.length)
+
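+    // Rebuild a dense float array from a labeled point: sparse points carry
+    // non-null indices, dense points store their values directly.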
+ def toArray(index: Int): Array[Float] = {
+ val labelPoint = result(index)
+ if (labelPoint.indices != null) {
+ Vectors.sparse(labelPoint.size,
+ labelPoint.indices,
+ labelPoint.values.map(_.toDouble)).toArray.map(_.toFloat)
+ } else {
+ labelPoint.values
+ }
+ }
+
+ assert(result(0).label === 1.0f && result(0).baseMargin.isNaN &&
+ result(0).weight === 1.0f && toArray(0) === data(0).map(_.toFloat))
+    assert(result(1).label === 2.0f && result(1).baseMargin.isNaN &&
+ result(1).weight === 0.0f && toArray(1) === data(1).map(_.toFloat))
+ assert(result(2).label === 3.0f && result(2).baseMargin.isNaN &&
+      result(2).weight === 0.0f && toArray(2) === data(2).map(_.toFloat))
+ assert(result(3).label === 4.0f && result(3).baseMargin.isNaN &&
+ result(3).weight === -2.1f && toArray(3) === data(3).map(_.toFloat))
+ }
+
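+  // Each tuple is (missing value, expected number of entries treated as missing);
+  // the two NaNs in the data below always count as missing, hence the "+ 2" terms.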
+ Seq((Float.NaN, 2), (0.0f, 7 + 2), (15.0f, 1 + 2), (10101011.0f, 0 + 2)).foreach {
+ case (missing, expectedMissingValue) =>
+ test(s"to RDD watches with missing $missing") {
+ val data = Array(
+ Array(1.0, 2.0, 3.0, 4.0, 5.0),
+ Array(1.0, Float.NaN, 0.0, 0.0, 2.0),
+ Array(12.0, 13.0, Float.NaN, 14.0, 15.0),
+ Array(0.0, 0.0, 0.0, 0.0, 0.0)
+ )
+ val dataset = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0, 0.5, 1.0, Vectors.dense(data(0)), "a"),
+ (2.0, 2, -0.5, 0.0, Vectors.dense(data(1)).toSparse, "b"),
+ (3.0, 3, -0.5, 0.0, Vectors.dense(data(2)), "b"),
+ (4.0, 4, -0.4, -2.1, Vectors.dense(data(3)), "c")
+ ))).toDF("label", "group", "margin", "weight", "features", "other")
+
+ val classifier = new XGBoostClassifier()
+ .setLabelCol("label")
+ .setFeaturesCol("features")
+ .setWeightCol("weight")
+ .setBaseMarginCol("margin")
+ .setMissing(missing)
+ .setNumWorkers(2)
+
+ val (df, indices) = classifier.preprocess(dataset)
+ val rdd = classifier.toRdd(df, indices)
+ val result = rdd.mapPartitions { iter =>
+ if (iter.hasNext) {
+ val watches = iter.next()
+ val size = watches.size
+ val trainDM = watches.toMap(TRAIN_NAME)
+ val rowNum = trainDM.rowNum
+ val labels = trainDM.getLabel
+ val weight = trainDM.getWeight
+ val margins = trainDM.getBaseMargin
+ val nonMissing = trainDM.nonMissingNum
+ watches.delete()
+ Iterator.single((size, rowNum, labels, weight, margins, nonMissing))
+ } else {
+ Iterator.empty
+ }
+ }.collect()
+
+ val labels: ArrayBuffer[Float] = ArrayBuffer.empty
+ val weight: ArrayBuffer[Float] = ArrayBuffer.empty
+ val margins: ArrayBuffer[Float] = ArrayBuffer.empty
+ var nonMissingValues = 0L
+ var totalRows = 0L
+
+ for (row <- result) {
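+        // watches.size is 1: only the training DMatrix is built when no eval dataset is set.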
+ assert(row._1 === 1)
+ totalRows = totalRows + row._2
+ labels.append(row._3: _*)
+ weight.append(row._4: _*)
+ margins.append(row._5: _*)
+ nonMissingValues = nonMissingValues + row._6
+ }
+ assert(totalRows === 4)
+ assert(nonMissingValues === data.size * data(0).length - expectedMissingValue)
+ assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
+ assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
+ assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
+ }
+ }
+
+ test("to RDD watches with eval") {
+ val trainData = Array(
+ Array(-1.0, -2.0, -3.0, -4.0, -5.0),
+ Array(2.0, 2.0, 2.0, 3.0, -2.0),
+ Array(-12.0, -13.0, -14.0, -14.0, -15.0),
+ Array(-20.5, -21.2, 0.0, 0.0, 2.0)
+ )
+ val trainDataset = ss.createDataFrame(sc.parallelize(Seq(
+ (11.0, 0, 0.15, 11.0, Vectors.dense(trainData(0)), "a"),
+ (12.0, 12, -0.15, 10.0, Vectors.dense(trainData(1)).toSparse, "b"),
+ (13.0, 12, -0.15, 10.0, Vectors.dense(trainData(2)), "b"),
+ (14.0, 12, -0.14, -12.1, Vectors.dense(trainData(3)), "c")
+ ))).toDF("label", "group", "margin", "weight", "features", "other")
+ val evalData = Array(
+ Array(1.0, 2.0, 3.0, 4.0, 5.0),
+ Array(0.0, 0.0, 0.0, 0.0, 2.0),
+ Array(12.0, 13.0, 14.0, 14.0, 15.0),
+ Array(20.5, 21.2, 0.0, 0.0, 2.0)
+ )
+ val evalDataset = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0, 0.5, 1.0, Vectors.dense(evalData(0)), "a"),
+ (2.0, 2, -0.5, 0.0, Vectors.dense(evalData(1)).toSparse, "b"),
+ (3.0, 2, -0.5, 0.0, Vectors.dense(evalData(2)), "b"),
+ (4.0, 2, -0.4, -2.1, Vectors.dense(evalData(3)), "c")
+ ))).toDF("label", "group", "margin", "weight", "features", "other")
+
+ val classifier = new XGBoostClassifier()
+ .setLabelCol("label")
+ .setFeaturesCol("features")
+ .setWeightCol("weight")
+ .setBaseMarginCol("margin")
+ .setEvalDataset(evalDataset)
+ .setNumWorkers(2)
+
+ val (df, indices) = classifier.preprocess(trainDataset)
+ val rdd = classifier.toRdd(df, indices)
+ val result = rdd.mapPartitions { iter =>
+ if (iter.hasNext) {
+ val watches = iter.next()
+ val size = watches.size
+ val evalDM = watches.toMap(Utils.VALIDATION_NAME)
+ val rowNum = evalDM.rowNum
+ val labels = evalDM.getLabel
+ val weight = evalDM.getWeight
+ val margins = evalDM.getBaseMargin
+ watches.delete()
+ Iterator.single((size, rowNum, labels, weight, margins))
+ } else {
+ Iterator.empty
+ }
+ }.collect()
+
+ val labels: ArrayBuffer[Float] = ArrayBuffer.empty
+ val weight: ArrayBuffer[Float] = ArrayBuffer.empty
+ val margins: ArrayBuffer[Float] = ArrayBuffer.empty
+
+ var totalRows = 0L
+ for (row <- result) {
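+      // watches.size is 2: one DMatrix for training plus one for the eval dataset.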
+ assert(row._1 === 2)
+ totalRows = totalRows + row._2
+ labels.append(row._3: _*)
+ weight.append(row._4: _*)
+ margins.append(row._5: _*)
+ }
+ assert(totalRows === 4)
+ assert(labels.toArray.sorted === Array(1.0f, 2.0f, 3.0f, 4.0f).sorted)
+ assert(weight.toArray.sorted === Array(0.0f, 0.0f, 1.0f, -2.1f).sorted)
+ assert(margins.toArray.sorted === Array(-0.5f, -0.5f, -0.4f, 0.5f).sorted)
+ }
+
+ test("XGBoost-Spark model format should match xgboost4j") {
+ val trainingDF = buildDataFrame(MultiClassification.train)
+
+ Seq(new XGBoostClassifier()).foreach { est =>
+ est.setNumRound(5)
+ val model = est.fit(trainingDF)
+
+ // test json
+ val modelPath = new File(tempDir.toFile, "xgbc").getPath
+ model.write.overwrite().option("format", "json").save(modelPath)
+ val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
+ model.nativeBooster.saveModel(nativeJsonModelPath)
+ assert(compareTwoFiles(new File(modelPath, "data/model").getPath,
+ nativeJsonModelPath))
+
+ // test ubj
+ val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
+ model.write.overwrite().save(modelUbjPath)
+ val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
+ model.nativeBooster.saveModel(nativeUbjModelPath)
+ assert(compareTwoFiles(new File(modelUbjPath, "data/model").getPath,
+ nativeUbjModelPath))
+
+      // the JSON model file should not be byte-identical to the UBJ model file
+ val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
+ model.write.overwrite().option("format", "json").save(modelJsonPath)
+ val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
+ model.nativeBooster.saveModel(nativeUbjModelPath1)
+ assert(!compareTwoFiles(new File(modelJsonPath, "data/model").getPath,
+ nativeUbjModelPath1))
+ }
+ }
+
+ test("native json model file should store feature_name and feature_type") {
+ val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
+    val featureTypes = (1 to 33).map(_ => "q").toArray
+ val trainingDF = buildDataFrame(MultiClassification.train)
+ val xgb = new XGBoostClassifier()
+ .setNumWorkers(numWorkers)
+ .setFeatureNames(featureNames)
+ .setFeatureTypes(featureTypes)
+ .setNumRound(2)
+ val model = xgb.fit(trainingDF)
+ val modelStr = new String(model.nativeBooster.toByteArray("json"))
+ val jsonModel = parseJson(modelStr)
+ implicit val formats: Formats = DefaultFormats
+ val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
+ val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
+ assert(featureNamesInModel.length == 33)
+ assert(featureTypesInModel.length == 33)
+ assert(featureNames sameElements featureNamesInModel)
+ assert(featureTypes sameElements featureTypesInModel)
+ }
+
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
deleted file mode 100755
index d93b182e043e..000000000000
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ /dev/null
@@ -1,376 +0,0 @@
-/*
- Copyright (c) 2014-2022 by Contributors
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
-
-package ml.dmlc.xgboost4j.scala.spark
-
-import scala.util.Random
-
-import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
-import ml.dmlc.xgboost4j.scala.DMatrix
-
-import org.apache.spark.{SparkException, TaskContext}
-import org.scalatest.funsuite.AnyFunSuite
-
-import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.functions.lit
-
-class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {
-
- test("distributed training with the specified worker number") {
- val trainingRDD = sc.parallelize(Classification.train)
- val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD)
- val (booster, metrics) = XGBoost.trainDistributed(
- sc,
- buildTrainingRDD,
- List("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
- "custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
- "missing" -> Float.NaN).toMap)
- assert(booster != null)
- }
-
- test("training with external memory cache") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic", "num_round" -> 5, "num_workers" -> numWorkers,
- "use_external_memory" -> true)
- val model = new XGBoostClassifier(paramMap).fit(training)
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("test with quantile hist with monotone_constraints (lossguide)") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1",
- "max_depth" -> "6",
- "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
- "num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
- val model = new XGBoostClassifier(paramMap).fit(training)
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("test with quantile hist with interaction_constraints (lossguide)") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1",
- "max_depth" -> "6",
- "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
- "num_round" -> 5, "num_workers" -> numWorkers, "interaction_constraints" -> "[[1,2],[2,3,4]]")
- val model = new XGBoostClassifier(paramMap).fit(training)
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("test with quantile hist with monotone_constraints (depthwise)") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1",
- "max_depth" -> "6",
- "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
- "num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
- val model = new XGBoostClassifier(paramMap).fit(training)
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("test with quantile hist with interaction_constraints (depthwise)") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1",
- "max_depth" -> "6",
- "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
- "num_round" -> 5, "num_workers" -> numWorkers, "interaction_constraints" -> "[[1,2],[2,3,4]]")
- val model = new XGBoostClassifier(paramMap).fit(training)
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("test with quantile hist depthwise") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1",
- "max_depth" -> "6",
- "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
- "num_round" -> 5, "num_workers" -> numWorkers)
- val model = new XGBoostClassifier(paramMap).fit(training)
- assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
- }
-
- test("test with quantile hist lossguide") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0",
- "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
- "max_leaves" -> "8", "num_round" -> 5,
- "num_workers" -> numWorkers)
- val model = new XGBoostClassifier(paramMap).fit(training)
- val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
- assert(x < 0.1)
- }
-
- test("test with quantile hist lossguide with max bin") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0",
- "objective" -> "binary:logistic", "tree_method" -> "hist",
- "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
- "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> numWorkers)
- val model = new XGBoostClassifier(paramMap).fit(training)
- val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
- assert(x < 0.1)
- }
-
- test("test with quantile hist depthwidth with max depth") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6",
- "objective" -> "binary:logistic", "tree_method" -> "hist",
- "grow_policy" -> "depthwise", "max_depth" -> "2",
- "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
- val model = new XGBoostClassifier(paramMap).fit(training)
- val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
- assert(x < 0.1)
- }
-
- test("test with quantile hist depthwidth with max depth and max bin") {
- val eval = new EvalError()
- val training = buildDataFrame(Classification.train)
- val testDM = new DMatrix(Classification.test.iterator)
- val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6",
- "objective" -> "binary:logistic", "tree_method" -> "hist",
- "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
- "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
- val model = new XGBoostClassifier(paramMap).fit(training)
- val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
- assert(x < 0.1)
- }
-
- test("repartitionForTrainingGroup with group data") {
- // test different splits to cover the corner cases.
- for (split <- 1 to 20) {
- val trainingRDD = sc.parallelize(Ranking.train, split)
- val traingGroupsRDD = PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
- val trainingGroups: Array[Array[XGBLabeledPoint]] = traingGroupsRDD.collect()
- // check the the order of the groups with group id.
- // Ranking.train has 20 groups
- assert(trainingGroups.length == 20)
-
- // compare all points
- val allPoints = trainingGroups.sortBy(_(0).group).flatten
- assert(allPoints.length == Ranking.train.size)
- for (i <- 0 to Ranking.train.size - 1) {
- assert(allPoints(i).group == Ranking.train(i).group)
- assert(allPoints(i).label == Ranking.train(i).label)
- assert(allPoints(i).values.sameElements(Ranking.train(i).values))
- }
- }
- }
-
- test("repartitionForTrainingGroup with group data which has empty partition") {
- val trainingRDD = sc.parallelize(Ranking.train, 5).mapPartitions(it => {
- // make one partition empty for testing
- it.filter(_ => TaskContext.getPartitionId() != 3)
- })
- PreXGBoost.repartitionForTrainingGroup(trainingRDD, 4)
- }
-
- test("distributed training with group data") {
- val trainingRDD = sc.parallelize(Ranking.train, 5)
- val buildTrainingRDD = PreXGBoost.buildRDDLabeledPointToRDDWatches(trainingRDD, hasGroup = true)
- val (booster, _) = XGBoost.trainDistributed(
- sc,
- buildTrainingRDD,
- List("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers,
- "custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
- "missing" -> Float.NaN).toMap)
-
- assert(booster != null)
- }
-
- test("training summary") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic", "num_round" -> 5, "nWorkers" -> numWorkers)
-
- val trainingDF = buildDataFrame(Classification.train)
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(trainingDF)
-
- assert(model.summary.trainObjectiveHistory.length === 5)
- assert(model.summary.validationObjectiveHistory.isEmpty)
- }
-
- test("train/test split") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic", "train_test_ratio" -> "0.5",
- "num_round" -> 5, "num_workers" -> numWorkers)
- val training = buildDataFrame(Classification.train)
-
- val xgb = new XGBoostClassifier(paramMap)
- val model = xgb.fit(training)
- assert(model.summary.validationObjectiveHistory.length === 1)
- assert(model.summary.validationObjectiveHistory(0)._1 === "test")
- assert(model.summary.validationObjectiveHistory(0)._2.length === 5)
- assert(model.summary.trainObjectiveHistory !== model.summary.validationObjectiveHistory(0))
- }
-
- test("train with multiple validation datasets (non-ranking)") {
- val training = buildDataFrame(Classification.train)
- val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2))
- val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic",
- "num_round" -> 5, "num_workers" -> numWorkers)
-
- val xgb1 = new XGBoostClassifier(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
- val model1 = xgb1.fit(train)
- assert(model1.summary.validationObjectiveHistory.length === 2)
- assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
- assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
- assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
- assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
- assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
-
- val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "binary:logistic",
- "num_round" -> 5, "num_workers" -> numWorkers,
- "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
- val xgb2 = new XGBoostClassifier(paramMap2)
- val model2 = xgb2.fit(train)
- assert(model2.summary.validationObjectiveHistory.length === 2)
- assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
- assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
- assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
- assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
- assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
- }
-
- test("train with multiple validation datasets (ranking)") {
- val training = buildDataFrameWithGroup(Ranking.train, 5)
- val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2), 0)
- val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:ndcg",
- "num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
- val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
- val model1 = xgb1.fit(train)
- assert(model1 != null)
- assert(model1.summary.validationObjectiveHistory.length === 2)
- assert(model1.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
- assert(model1.summary.validationObjectiveHistory(0)._2.length === 5)
- assert(model1.summary.validationObjectiveHistory(1)._2.length === 5)
- assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(0))
- assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
-
- val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:ndcg",
- "num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group",
- "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
- val xgb2 = new XGBoostRegressor(paramMap2)
- val model2 = xgb2.fit(train)
- assert(model2 != null)
- assert(model2.summary.validationObjectiveHistory.length === 2)
- assert(model2.summary.validationObjectiveHistory.map(_._1).toSet === Set("eval1", "eval2"))
- assert(model2.summary.validationObjectiveHistory(0)._2.length === 5)
- assert(model2.summary.validationObjectiveHistory(1)._2.length === 5)
- assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(0))
- assert(model2.summary.trainObjectiveHistory !== model2.summary.validationObjectiveHistory(1))
- }
-
- test("infer with different batch sizes") {
- val regModel = new XGBoostRegressor(Map(
- "eta" -> "1",
- "max_depth" -> "6",
- "silent" -> "1",
- "objective" -> "reg:squarederror",
- "num_round" -> 5,
- "num_workers" -> numWorkers))
- .fit(buildDataFrame(Regression.train))
- val regDF = buildDataFrame(Regression.test)
-
- val regRet1 = regModel.transform(regDF).collect()
- val regRet2 = regModel.setInferBatchSize(1).transform(regDF).collect()
- val regRet3 = regModel.setInferBatchSize(10).transform(regDF).collect()
- val regRet4 = regModel.setInferBatchSize(32 << 15).transform(regDF).collect()
- assert(regRet1 sameElements regRet2)
- assert(regRet1 sameElements regRet3)
- assert(regRet1 sameElements regRet4)
-
- val clsModel = new XGBoostClassifier(Map(
- "eta" -> "1",
- "max_depth" -> "6",
- "silent" -> "1",
- "objective" -> "binary:logistic",
- "num_round" -> 5,
- "num_workers" -> numWorkers))
- .fit(buildDataFrame(Classification.train))
- val clsDF = buildDataFrame(Classification.test)
-
- val clsRet1 = clsModel.transform(clsDF).collect()
- val clsRet2 = clsModel.setInferBatchSize(1).transform(clsDF).collect()
- val clsRet3 = clsModel.setInferBatchSize(10).transform(clsDF).collect()
- val clsRet4 = clsModel.setInferBatchSize(32 << 15).transform(clsDF).collect()
- assert(clsRet1 sameElements clsRet2)
- assert(clsRet1 sameElements clsRet3)
- assert(clsRet1 sameElements clsRet4)
- }
-
- test("chaining the prediction") {
- val modelPath = getClass.getResource("/model/0.82/model").getPath
- val model = XGBoostClassificationModel.read.load(modelPath)
- val r = new Random(0)
- var df = ss.createDataFrame(Seq.fill(100000)(1).map(i => (i, i))).
- toDF("feature", "label").repartition(5)
- // 0.82/model was trained with 251 features. and transform will throw exception
- // if feature size of data is not equal to 251
- for (x <- 1 to 250) {
- df = df.withColumn(s"feature_${x}", lit(1))
- }
- val assembler = new VectorAssembler()
- .setInputCols(df.columns.filter(!_.contains("label")))
- .setOutputCol("features")
- df = assembler.transform(df)
- for (x <- 1 to 250) {
- df = df.drop(s"feature_${x}")
- }
- val df1 = model.transform(df).withColumnRenamed(
- "prediction", "prediction1").withColumnRenamed(
- "rawPrediction", "rawPrediction1").withColumnRenamed(
- "probability", "probability1")
- val df2 = model.transform(df1)
- df1.collect()
- df2.collect()
- }
-
- test("throw exception for empty partition in trainingset") {
- val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "binary:logistic", "num_class" -> "2", "num_round" -> 5,
- "num_workers" -> numWorkers, "tree_method" -> "auto", "allow_non_zero_for_missing" -> true)
- // The Dmatrix will be empty
- val trainingDF = buildDataFrame(Seq(XGBLabeledPoint(1.0f, 4,
- Array(0, 1, 2, 3), Array(0, 1, 2, 3))))
- val xgb = new XGBoostClassifier(paramMap)
- intercept[SparkException] {
- xgb.fit(trainingDF)
- }
- }
-
-}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
new file mode 100644
index 000000000000..035d2e7db815
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
@@ -0,0 +1,289 @@
+/*
+ Copyright (c) 2014-2024 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.ml.linalg.{DenseVector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.scalatest.funsuite.AnyFunSuite
+
+import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+import ml.dmlc.xgboost4j.scala.spark.Regression.Ranking
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.RANKER_OBJS
+import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
+
+class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
+
+ test("XGBoostRanker copy") {
+ val ranker = new XGBoostRanker().setNthread(2).setNumWorkers(10)
+    val rankerCopied = ranker.copy(ParamMap.empty)
+
+    assert(ranker.uid === rankerCopied.uid)
+    assert(ranker.getNthread === rankerCopied.getNthread)
+    assert(ranker.getNumWorkers === rankerCopied.getNumWorkers)
+ }
+
+ test("XGBoostRankerModel copy") {
+ val model = new XGBoostRankerModel("hello").setNthread(2).setNumWorkers(10)
+ val modelCopied = model.copy(ParamMap.empty)
+ assert(model.uid === modelCopied.uid)
+ assert(model.getNthread === modelCopied.getNthread)
+ assert(model.getNumWorkers === modelCopied.getNumWorkers)
+ }
+
+ test("read/write") {
+ val trainDf = smallGroupVector
+ val xgbParams: Map[String, Any] = Map(
+ "max_depth" -> 5,
+ "eta" -> 0.2,
+ "objective" -> "rank:ndcg"
+ )
+
+ def check(xgboostParams: XGBoostParams[_]): Unit = {
+ assert(xgboostParams.getMaxDepth === 5)
+ assert(xgboostParams.getEta === 0.2)
+ assert(xgboostParams.getObjective === "rank:ndcg")
+ }
+
+ val rankerPath = new File(tempDir.toFile, "ranker").getPath
+ val ranker = new XGBoostRanker(xgbParams).setNumRound(1).setGroupCol("group")
+ check(ranker)
+ assert(ranker.getGroupCol === "group")
+
+ ranker.write.overwrite().save(rankerPath)
+ val loadedRanker = XGBoostRanker.load(rankerPath)
+ check(loadedRanker)
+ assert(loadedRanker.getGroupCol === "group")
+
+ val model = loadedRanker.fit(trainDf)
+ check(model)
+ assert(model.getGroupCol === "group")
+
+ val modelPath = new File(tempDir.toFile, "model").getPath
+ model.write.overwrite().save(modelPath)
+ val modelLoaded = XGBoostRankerModel.load(modelPath)
+ check(modelLoaded)
+ assert(modelLoaded.getGroupCol === "group")
+ }
+
+ test("validate") {
+ val trainDf = smallGroupVector
+ val ranker = new XGBoostRanker()
+ // must define group column
+ intercept[IllegalArgumentException](
+ ranker.validate(trainDf)
+ )
+ val ranker1 = new XGBoostRanker().setGroupCol("group")
+ ranker1.validate(trainDf)
+ assert(ranker1.getObjective === "rank:ndcg")
+ }
+
+ test("XGBoostRankerModel transformed schema") {
+ val trainDf = smallGroupVector
+ val ranker = new XGBoostRanker().setGroupCol("group").setNumRound(1)
+ val model = ranker.fit(trainDf)
+ var out = model.transform(trainDf)
+    // Transform should not discard the other columns of the input dataframe
+ Seq("label", "group", "margin", "weight", "features").foreach { v =>
+ assert(out.schema.names.contains(v))
+ }
+ // Ranker does not have extra columns
+ Seq("rawPrediction", "probability").foreach { v =>
+ assert(!out.schema.names.contains(v))
+ }
+ assert(out.schema.names.contains("prediction"))
+ assert(out.schema.names.length === 6)
+ model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
+ out = model.transform(trainDf)
+ assert(out.schema.names.contains("leaf"))
+ assert(out.schema.names.contains("contrib"))
+ }
+
+ test("Supported objectives") {
+ val ranker = new XGBoostRanker().setGroupCol("group")
+ val df = smallGroupVector
+ RANKER_OBJS.foreach { obj =>
+ ranker.setObjective(obj)
+ ranker.validate(df)
+ }
+
+ ranker.setObjective("binary:logistic")
+ intercept[IllegalArgumentException](
+ ranker.validate(df)
+ )
+ }
+
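+  // Encodes consecutive group ids as boundary offsets (run starts plus the total
+  // length), matching the group pointer layout returned by DMatrix.getGroup.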
+ private def runLengthEncode(input: Seq[Int]): Seq[Int] = {
+ if (input.isEmpty) return Seq(0)
+
+ input.indices
+ .filter(i => i == 0 || input(i) != input(i - 1)) :+ input.length
+ }
+
+ private def runRanker(ranker: XGBoostRanker, dataset: Dataset[_]): (Array[Float], Array[Int]) = {
+ val (df, indices) = ranker.preprocess(dataset)
+ val rdd = ranker.toRdd(df, indices)
+ val result = rdd.mapPartitions { iter =>
+ if (iter.hasNext) {
+ val watches = iter.next()
+ val dm = watches.toMap(Utils.TRAIN_NAME)
+ val weight = dm.getWeight
+ val group = dm.getGroup
+ watches.delete()
+ Iterator.single((weight, group))
+ } else {
+ Iterator.empty
+ }
+ }.collect()
+
+ val weight: ArrayBuffer[Float] = ArrayBuffer.empty
+ val group: ArrayBuffer[Int] = ArrayBuffer.empty
+
+ for (row <- result) {
+ weight.append(row._1: _*)
+ group.append(row._2: _*)
+ }
+ (weight.toArray, group.toArray)
+ }
+
+ Seq(None, Some("weight")).foreach { weightCol => {
+ val msg = weightCol.map(_ => "with weight").getOrElse("without weight")
+ test(s"to RDD watches with group $msg") {
+ // One instance without setting weight
+ var df = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 0, 10, Vectors.dense(Array(1.0, 2.0, 3.0)))
+ ))).toDF("label", "group", "weight", "features")
+
+ val ranker = new XGBoostRanker()
+ .setLabelCol("label")
+ .setFeaturesCol("features")
+ .setGroupCol("group")
+ .setNumWorkers(1)
+
+ weightCol.foreach(ranker.setWeightCol)
+
+ val (weights, groupSize) = runRanker(ranker, df)
+ val expectedWeight = weightCol.map(_ => Array(10.0f)).getOrElse(Array(1.0f))
+ assert(weights === expectedWeight)
+ assert(groupSize === runLengthEncode(Seq(0)))
+
+ df = ss.createDataFrame(sc.parallelize(Seq(
+ (1.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
+ (2.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
+ (1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))),
+ (0.0, 1, 2, Vectors.dense(Array(1.0, 2.0, 3.0))),
+ (1.0, 0, 5, Vectors.dense(Array(1.0, 2.0, 3.0))),
+ (2.0, 2, 7, Vectors.dense(Array(1.0, 2.0, 3.0)))
+ ))).toDF("label", "group", "weight", "features")
+
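+      // One weight is expected per group: with a weight column each group contributes
+      // its own weight, otherwise every group defaults to 1.0f.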
+ val groups = Array(1, 1, 0, 1, 0, 2).sorted
+ val (weights1, groupSize1) = runRanker(ranker, df)
+ val expectedWeight1 = weightCol.map(_ => Array(5.0f, 2.0f, 7.0f))
+ .getOrElse(groups.distinct.map(_ => 1.0f))
+
+ assert(groupSize1 === runLengthEncode(groups))
+ assert(weights1 === expectedWeight1)
+ }
+ }
+ }
+
+ test("XGBoost-Spark output should match XGBoost4j") {
+ val trainingDM = new DMatrix(Ranking.train.iterator)
+ val weights = Ranking.trainGroups.distinct.map(_ => 1.0f).toArray
+ trainingDM.setQueryId(Ranking.trainGroups.toArray)
+ trainingDM.setWeight(weights)
+
+ val testDM = new DMatrix(Ranking.test.iterator)
+ val trainingDF = buildDataFrameWithGroup(Ranking.train)
+ val testDF = buildDataFrameWithGroup(Ranking.test)
+ val paramMap = Map("objective" -> "rank:ndcg")
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
+ }
+
+ test("XGBoost-Spark output with weight should match XGBoost4j") {
+ val trainingDM = new DMatrix(Ranking.trainWithWeight.iterator)
+ trainingDM.setQueryId(Ranking.trainGroups.toArray)
+ trainingDM.setWeight(Ranking.trainGroups.distinct.map(_.toFloat).toArray)
+
+ val testDM = new DMatrix(Ranking.test.iterator)
+ val trainingDF = buildDataFrameWithGroup(Ranking.trainWithWeight)
+ val testDF = buildDataFrameWithGroup(Ranking.test)
+ val paramMap = Map("objective" -> "rank:ndcg")
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
+ 5, paramMap, Some("weight"))
+ }
+
+ private def checkResultsWithXGBoost4j(
+ trainingDM: DMatrix,
+ testDM: DMatrix,
+ trainingDF: DataFrame,
+ testDF: DataFrame,
+ round: Int = 5,
+ xgbParams: Map[String, Any] = Map.empty,
+ weightCol: Option[String] = None): Unit = {
+ val paramMap = Map(
+ "eta" -> "1",
+ "max_depth" -> "6",
+ "base_score" -> 0.5,
+ "max_bin" -> 16) ++ xgbParams
+ val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
+
+ val ranker = new XGBoostRanker(paramMap)
+ .setNumRound(round)
+      // Training a ranker with multiple workers can produce different results, so use one worker
+ .setNumWorkers(1)
+ .setLeafPredictionCol("leaf")
+ .setContribPredictionCol("contrib")
+ .setGroupCol("group")
+ weightCol.foreach(weight => ranker.setWeightCol(weight))
+
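+    // Compare xgboost4j predictions (indexed by row order) against Spark rows keyed by "id".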
+ def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
+ assert(left.size === right.size)
+ left.zipWithIndex.foreach { case (leftValue, index) =>
+ assert(leftValue.sameElements(right(index)))
+ }
+ }
+
+ val xgbSparkModel = ranker.fit(trainingDF)
+ val rows = xgbSparkModel.transform(testDF).collect()
+
+ // Check Leaf
+ val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
+ val xgbSparkLeaf = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
+ checkEqual(xgb4jLeaf, xgbSparkLeaf)
+
+ // Check contrib
+ val xgb4jContrib = xgb4jModel.predictContrib(testDM)
+ val xgbSparkContrib = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
+ checkEqual(xgb4jContrib, xgbSparkContrib)
+
+ // Check prediction
+ val xgb4jPred = xgb4jModel.predict(testDM)
+ val xgbSparkPred = rows.map(row => {
+ val pred = row.getAs[Double]("prediction").toFloat
+ (row.getAs[Int]("id"), Array(pred))
+ }).toMap
+ checkEqual(xgb4jPred, xgbSparkPred)
+ }
+
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index 0698541c7e89..43209f1aff13 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -18,339 +18,168 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
-import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
-
-import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.ml.linalg.DenseVector
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.DataFrame
import org.scalatest.funsuite.AnyFunSuite
-import org.apache.spark.ml.feature.VectorAssembler
+import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
+import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams.REGRESSION_OBJS
+import ml.dmlc.xgboost4j.scala.spark.params.XGBoostParams
class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
- protected val treeMethod: String = "auto"
+ test("XGBoostRegressor copy") {
+ val regressor = new XGBoostRegressor().setNthread(2).setNumWorkers(10)
+    val regressorCopied = regressor.copy(ParamMap.empty)
- test("XGBoost-Spark XGBoostRegressor output should match XGBoost4j") {
- val trainingDM = new DMatrix(Regression.train.iterator)
- val testDM = new DMatrix(Regression.test.iterator)
- val trainingDF = buildDataFrame(Regression.train)
- val testDF = buildDataFrame(Regression.test)
- checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+    assert(regressor.uid === regressorCopied.uid)
+    assert(regressor.getNthread === regressorCopied.getNthread)
+    assert(regressor.getNumWorkers === regressorCopied.getNumWorkers)
}
- test("XGBoostRegressor should make correct predictions after upstream random sort") {
- val trainingDM = new DMatrix(Regression.train.iterator)
- val testDM = new DMatrix(Regression.test.iterator)
- val trainingDF = buildDataFrameWithRandSort(Regression.train)
- val testDF = buildDataFrameWithRandSort(Regression.test)
- checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF)
+ test("XGBoostRegressionModel copy") {
+ val model = new XGBoostRegressionModel("hello").setNthread(2).setNumWorkers(10)
+ val modelCopied = model.copy(ParamMap.empty)
+ assert(model.uid === modelCopied.uid)
+ assert(model.getNthread === modelCopied.getNthread)
+ assert(model.getNumWorkers === modelCopied.getNumWorkers)
}
- private def checkResultsWithXGBoost4j(
- trainingDM: DMatrix,
- testDM: DMatrix,
- trainingDF: DataFrame,
- testDF: DataFrame,
- round: Int = 5): Unit = {
- val paramMap = Map(
- "eta" -> "1",
- "max_depth" -> "6",
- "silent" -> "1",
- "objective" -> "reg:squarederror",
- "max_bin" -> 64,
- "tree_method" -> treeMethod)
-
- val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
- val prediction1 = model1.predict(testDM)
-
- val model2 = new XGBoostRegressor(paramMap ++ Array("num_round" -> round,
- "num_workers" -> numWorkers)).fit(trainingDF)
+ test("read/write") {
+ val trainDf = smallBinaryClassificationVector
+ val xgbParams: Map[String, Any] = Map(
+ "max_depth" -> 5,
+ "eta" -> 0.2
+ )
+
+ def check(xgboostParams: XGBoostParams[_]): Unit = {
+ assert(xgboostParams.getMaxDepth === 5)
+ assert(xgboostParams.getEta === 0.2)
+ assert(xgboostParams.getObjective === "reg:squarederror")
+ }
- val prediction2 = model2.transform(testDF).
- collect().map(row => (row.getAs[Int]("id"), row.getAs[Double]("prediction"))).toMap
+ val regressorPath = new File(tempDir.toFile, "regressor").getPath
+ val regressor = new XGBoostRegressor(xgbParams).setNumRound(1)
+ check(regressor)
- assert(prediction1.indices.count { i =>
- math.abs(prediction1(i)(0) - prediction2(i)) > 0.01
- } < prediction1.length * 0.1)
+ regressor.write.overwrite().save(regressorPath)
+ val loadedRegressor = XGBoostRegressor.load(regressorPath)
+ check(loadedRegressor)
+ val model = loadedRegressor.fit(trainDf)
+ check(model)
- // check the equality of single instance prediction
- val firstOfDM = testDM.slice(Array(0))
- val firstOfDF = testDF.filter(_.getAs[Int]("id") == 0)
- .head()
- .getAs[Vector]("features")
- val prediction3 = model1.predict(firstOfDM)(0)(0)
- val prediction4 = model2.predict(firstOfDF)
- assert(math.abs(prediction3 - prediction4) <= 0.01f)
+ val modelPath = new File(tempDir.toFile, "model").getPath
+ model.write.overwrite().save(modelPath)
+ val modelLoaded = XGBoostRegressionModel.load(modelPath)
+ check(modelLoaded)
}
- test("Set params in XGBoost and MLlib way should produce same model") {
- val trainingDF = buildDataFrame(Regression.train)
- val testDF = buildDataFrame(Regression.test)
- val round = 5
-
- val paramMap = Map(
- "eta" -> "1",
- "max_depth" -> "6",
- "silent" -> "1",
- "objective" -> "reg:squarederror",
- "num_round" -> round,
- "tree_method" -> treeMethod,
- "num_workers" -> numWorkers)
-
- // Set params in XGBoost way
- val model1 = new XGBoostRegressor(paramMap).fit(trainingDF)
- // Set params in MLlib way
- val model2 = new XGBoostRegressor()
- .setEta(1)
- .setMaxDepth(6)
- .setSilent(1)
- .setObjective("reg:squarederror")
- .setNumRound(round)
- .setTreeMethod(treeMethod)
- .setNumWorkers(numWorkers)
- .fit(trainingDF)
-
- val prediction1 = model1.transform(testDF).select("prediction").collect()
- val prediction2 = model2.transform(testDF).select("prediction").collect()
-
- prediction1.zip(prediction2).foreach { case (Row(p1: Double), Row(p2: Double)) =>
- assert(math.abs(p1 - p2) <= 0.01f)
+ test("XGBoostRegressionModel transformed schema") {
+ val trainDf = smallBinaryClassificationVector
+ val regressor = new XGBoostRegressor().setNumRound(1)
+ val model = regressor.fit(trainDf)
+ var out = model.transform(trainDf)
+ // Transform should not discard the other columns of the input dataframe
+ Seq("label", "margin", "weight", "features").foreach { v =>
+ assert(out.schema.names.contains(v))
}
+ // Regressor does not have extra columns
+ Seq("rawPrediction", "probability").foreach { v =>
+ assert(!out.schema.names.contains(v))
+ }
+ assert(out.schema.names.contains("prediction"))
+ assert(out.schema.names.length === 5)
+ model.setLeafPredictionCol("leaf").setContribPredictionCol("contrib")
+ out = model.transform(trainDf)
+ assert(out.schema.names.contains("leaf"))
+ assert(out.schema.names.contains("contrib"))
}
- test("ranking: use group data") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5,
- "group_col" -> "group", "tree_method" -> treeMethod)
-
- val trainingDF = buildDataFrameWithGroup(Ranking.train)
- val testDF = buildDataFrame(Ranking.test)
- val model = new XGBoostRegressor(paramMap).fit(trainingDF)
+ test("Supported objectives") {
+ val regressor = new XGBoostRegressor()
+ val df = smallMultiClassificationVector
+ REGRESSION_OBJS.foreach { obj =>
+ regressor.setObjective(obj)
+ regressor.validate(df)
+ }
- val prediction = model.transform(testDF).collect()
- assert(testDF.count() === prediction.length)
+ regressor.setObjective("binary:logistic")
+ intercept[IllegalArgumentException](
+ regressor.validate(df)
+ )
}
- test("use weight") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
-
- val getWeightFromId = udf({id: Int => if (id == 0) 1.0f else 0.001f})
+ test("XGBoost-Spark output should match XGBoost4j") {
+ val trainingDM = new DMatrix(Regression.train.iterator)
+ val testDM = new DMatrix(Regression.test.iterator)
val trainingDF = buildDataFrame(Regression.train)
- .withColumn("weight", getWeightFromId(col("id")))
- val testDF = buildDataFrame(Regression.test)
-
- val model = new XGBoostRegressor(paramMap).setWeightCol("weight").fit(trainingDF)
- val prediction = model.transform(testDF).collect()
- val first = prediction.head.getAs[Double]("prediction")
- prediction.foreach(x => assert(math.abs(x.getAs[Double]("prediction") - first) <= 0.01f))
- }
-
- test("objective will be set if not specifying it") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod)
- val training = buildDataFrame(Regression.train)
- val xgb = new XGBoostRegressor(paramMap)
- assert(!xgb.isDefined(xgb.objective))
- xgb.fit(training)
- assert(xgb.getObjective == "reg:squarederror")
-
- val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "num_round" -> 5, "num_workers" -> numWorkers, "tree_method" -> treeMethod,
- "objective" -> "reg:squaredlogerror")
- val xgb1 = new XGBoostRegressor(paramMap1)
- assert(xgb1.getObjective == "reg:squaredlogerror")
- xgb1.fit(training)
- assert(xgb1.getObjective == "reg:squaredlogerror")
- }
-
- test("test predictionLeaf") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
- val training = buildDataFrame(Regression.train)
- val testDF = buildDataFrame(Regression.test)
- val groundTruth = testDF.count()
- val xgb = new XGBoostRegressor(paramMap)
- val model = xgb.fit(training)
- model.setLeafPredictionCol("predictLeaf")
- val resultDF = model.transform(testDF)
- assert(resultDF.count === groundTruth)
- assert(resultDF.columns.contains("predictLeaf"))
- }
-
- test("test predictionLeaf with empty column name") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
- val training = buildDataFrame(Regression.train)
val testDF = buildDataFrame(Regression.test)
- val xgb = new XGBoostRegressor(paramMap)
- val model = xgb.fit(training)
- model.setLeafPredictionCol("")
- val resultDF = model.transform(testDF)
- assert(!resultDF.columns.contains("predictLeaf"))
+ val paramMap = Map("objective" -> "reg:squarederror")
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF, 5, paramMap)
}
- test("test predictionContrib") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
- val training = buildDataFrame(Regression.train)
- val testDF = buildDataFrame(Regression.test)
- val groundTruth = testDF.count()
- val xgb = new XGBoostRegressor(paramMap)
- val model = xgb.fit(training)
- model.setContribPredictionCol("predictContrib")
- val resultDF = model.transform(testDF)
- assert(resultDF.count === groundTruth)
- assert(resultDF.columns.contains("predictContrib"))
- }
-
- test("test predictionContrib with empty column name") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
- val training = buildDataFrame(Regression.train)
- val testDF = buildDataFrame(Regression.test)
- val xgb = new XGBoostRegressor(paramMap)
- val model = xgb.fit(training)
- model.setContribPredictionCol("")
- val resultDF = model.transform(testDF)
- assert(!resultDF.columns.contains("predictContrib"))
- }
-
- test("test predictionLeaf and predictionContrib") {
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
- "tree_method" -> treeMethod)
- val training = buildDataFrame(Regression.train)
+ test("XGBoost-Spark output with weight should match XGBoost4j") {
+ val trainingDM = new DMatrix(Regression.trainWithWeight.iterator)
+ trainingDM.setWeight(Regression.randomWeights)
+ val testDM = new DMatrix(Regression.test.iterator)
+ val trainingDF = buildDataFrame(Regression.trainWithWeight)
val testDF = buildDataFrame(Regression.test)
- val groundTruth = testDF.count()
- val xgb = new XGBoostRegressor(paramMap)
- val model = xgb.fit(training)
- model.setLeafPredictionCol("predictLeaf")
- model.setContribPredictionCol("predictContrib")
- val resultDF = model.transform(testDF)
- assert(resultDF.count === groundTruth)
- assert(resultDF.columns.contains("predictLeaf"))
- assert(resultDF.columns.contains("predictContrib"))
+ val paramMap = Map("objective" -> "reg:squarederror")
+ checkResultsWithXGBoost4j(trainingDM, testDM, trainingDF, testDF,
+ 5, paramMap, Some("weight"))
}
- test("featuresCols with features column can work") {
- val spark = ss
- import spark.implicits._
- val xgbInput = Seq(
- (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
- (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
- .toDF("f1", "f2", "f3", "features", "label")
-
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
-
- val featuresName = Array("f1", "f2", "f3", "features")
- val xgbClassifier = new XGBoostRegressor(paramMap)
- .setFeaturesCol(featuresName)
- .setLabelCol("label")
-
- val model = xgbClassifier.fit(xgbInput)
- assert(model.getFeaturesCols.sameElements(featuresName))
-
- val df = model.transform(xgbInput)
- assert(df.schema.fieldNames.contains("features_" + model.uid))
- df.show()
-
- val newFeatureName = "features_new"
- // transform also can work for vectorized dataset
- val vectorizedInput = new VectorAssembler()
- .setInputCols(featuresName)
- .setOutputCol(newFeatureName)
- .transform(xgbInput)
- .select(newFeatureName, "label")
-
- val df1 = model
- .setFeaturesCol(newFeatureName)
- .transform(vectorizedInput)
- assert(df1.schema.fieldNames.contains(newFeatureName))
- df1.show()
- }
-
- test("featuresCols without features column can work") {
- val spark = ss
- import spark.implicits._
- val xgbInput = Seq(
- (Vectors.dense(1.0, 7.0), true, 10.1, 100.2, 0),
- (Vectors.dense(2.0, 20.0), false, 2.1, 2.2, 1))
- .toDF("f1", "f2", "f3", "f4", "label")
-
- val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> 1)
-
- val featuresName = Array("f1", "f2", "f3", "f4")
- val xgbClassifier = new XGBoostRegressor(paramMap)
- .setFeaturesCol(featuresName)
- .setLabelCol("label")
- .setEvalSets(Map("eval" -> xgbInput))
-
- val model = xgbClassifier.fit(xgbInput)
- assert(model.getFeaturesCols.sameElements(featuresName))
-
- // transform should work for the dataset which includes the feature column names.
- val df = model.transform(xgbInput)
- assert(df.schema.fieldNames.contains("features"))
- df.show()
-
- // transform also can work for vectorized dataset
- val vectorizedInput = new VectorAssembler()
- .setInputCols(featuresName)
- .setOutputCol("features")
- .transform(xgbInput)
- .select("features", "label")
-
- val df1 = model.transform(vectorizedInput)
- df1.show()
- }
-
- test("XGBoostRegressionModel should be compatible") {
- val trainingDF = buildDataFrame(Regression.train)
+ private def checkResultsWithXGBoost4j(
+ trainingDM: DMatrix,
+ testDM: DMatrix,
+ trainingDF: DataFrame,
+ testDF: DataFrame,
+ round: Int = 5,
+ xgbParams: Map[String, Any] = Map.empty,
+ weightCol: Option[String] = None): Unit = {
val paramMap = Map(
"eta" -> "1",
"max_depth" -> "6",
- "silent" -> "1",
- "objective" -> "reg:squarederror",
- "num_round" -> 5,
- "tree_method" -> treeMethod,
- "num_workers" -> numWorkers)
+ "base_score" -> 0.5,
+ "max_bin" -> 16) ++ xgbParams
+ val xgb4jModel = ScalaXGBoost.train(trainingDM, paramMap, round)
- val model = new XGBoostRegressor(paramMap).fit(trainingDF)
-
- val modelPath = new File(tempDir.toFile, "xgbc").getPath
- model.write.option("format", "json").save(modelPath)
- val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
- model.nativeBooster.saveModel(nativeJsonModelPath)
- assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
- nativeJsonModelPath))
-
- // test default "ubj"
- val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
- model.write.save(modelUbjPath)
-
- val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
- model.nativeBooster.saveModel(nativeUbjModelPath)
-
- assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
- nativeUbjModelPath))
-
- // test the deprecated format
- val modelDeprecatedPath = new File(tempDir.toFile, "modelDeprecated").getPath
- model.write.option("format", "deprecated").save(modelDeprecatedPath)
-
- val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel.deprecated").getPath
- model.nativeBooster.saveModel(nativeDeprecatedModelPath)
+ val regressor = new XGBoostRegressor(paramMap)
+ .setNumRound(round)
+ .setNumWorkers(numWorkers)
+ .setLeafPredictionCol("leaf")
+ .setContribPredictionCol("contrib")
+ weightCol.foreach(weight => regressor.setWeightCol(weight))
+
+ def checkEqual(left: Array[Array[Float]], right: Map[Int, Array[Float]]) = {
+ assert(left.size === right.size)
+ left.zipWithIndex.foreach { case (leftValue, index) =>
+ assert(leftValue.sameElements(right(index)))
+ }
+ }
- assert(compareTwoFiles(new File(modelDeprecatedPath, "data/XGBoostRegressionModel").getPath,
- nativeDeprecatedModelPath))
+ val xgbSparkModel = regressor.fit(trainingDF)
+ val rows = xgbSparkModel.transform(testDF).collect()
+
+ // Check Leaf
+ val xgb4jLeaf = xgb4jModel.predictLeaf(testDM)
+ val xgbSparkLeaf = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("leaf").toArray.map(_.toFloat))).toMap
+ checkEqual(xgb4jLeaf, xgbSparkLeaf)
+
+ // Check contrib
+ val xgb4jContrib = xgb4jModel.predictContrib(testDM)
+ val xgbSparkContrib = rows.map(row =>
+ (row.getAs[Int]("id"), row.getAs[DenseVector]("contrib").toArray.map(_.toFloat))).toMap
+ checkEqual(xgb4jContrib, xgbSparkContrib)
+
+ // Check prediction
+ val xgb4jPred = xgb4jModel.predict(testDM)
+ val xgbSparkPred = rows.map(row => {
+ val pred = row.getAs[Double]("prediction").toFloat
+ (row.getAs[Int]("id"), Array(pred))}).toMap
+ checkEqual(xgb4jPred, xgbSparkPred)
}
+
}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
index 9622c9b2d44a..3a45cf4448c0 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostSuite.scala
@@ -1,5 +1,5 @@
/*
- Copyright (c) 2023 by Contributors
+ Copyright (c) 2023-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -16,40 +16,18 @@
package ml.dmlc.xgboost4j.scala.spark
-import ml.dmlc.xgboost4j.scala.Booster
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.scalatest.funsuite.AnyFunSuite
+import ml.dmlc.xgboost4j.scala.Booster
+
class XGBoostSuite extends AnyFunSuite with PerTest {
// Do not create spark context
override def beforeEach(): Unit = {}
- test("XGBoost execution parameters") {
- var xgbExecutionParams = new XGBoostExecutionParamsFactory(
- Map("device" -> "cpu", "num_workers" -> 1, "num_round" -> 1), sc)
- .buildXGBRuntimeParams
- assert(!xgbExecutionParams.runOnGpu)
-
- xgbExecutionParams = new XGBoostExecutionParamsFactory(
- Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
- .buildXGBRuntimeParams
- assert(xgbExecutionParams.runOnGpu)
-
- xgbExecutionParams = new XGBoostExecutionParamsFactory(
- Map("device" -> "cpu", "tree_method" -> "gpu_hist", "num_workers" -> 1, "num_round" -> 1), sc)
- .buildXGBRuntimeParams
- assert(xgbExecutionParams.runOnGpu)
-
- xgbExecutionParams = new XGBoostExecutionParamsFactory(
- Map("device" -> "cuda", "tree_method" -> "gpu_hist",
- "num_workers" -> 1, "num_round" -> 1), sc)
- .buildXGBRuntimeParams
- assert(xgbExecutionParams.runOnGpu)
- }
-
test("skip stage-level scheduling") {
val conf = new SparkConf()
.setMaster("spark://foo")
@@ -101,13 +79,13 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
}
- object FakedXGBoost extends XGBoostStageLevel {
+ object FakedXGBoost extends StageLevelScheduling {
// Do not skip stage-level scheduling for testing purposes.
override private[spark] def skipStageLevelScheduling(
- sparkVersion: String,
- runOnGpu: Boolean,
- conf: SparkConf) = false
+ sparkVersion: String,
+ runOnGpu: Boolean,
+ conf: SparkConf) = false
}
test("try stage-level scheduling without spark-rapids") {
@@ -129,12 +107,12 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
val df = ss.range(1, 10)
val rdd = df.rdd
- val xgbExecutionParams = new XGBoostExecutionParamsFactory(
- Map("device" -> "cuda", "num_workers" -> 1, "num_round" -> 1), sc)
- .buildXGBRuntimeParams
- assert(xgbExecutionParams.runOnGpu)
+ val runtimeParams = new XGBoostClassifier(
+ Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
+ .getRuntimeParameters(true)
+ assert(runtimeParams.runOnGpu)
- val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, xgbExecutionParams,
+ val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, runtimeParams,
rdd.asInstanceOf[RDD[(Booster, Map[String, Array[Float]])]])
val taskResources = finalRDD.getResourceProfile().taskResources
diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml
index 345098327f5c..aa5b838fcca3 100644
--- a/jvm-packages/xgboost4j/pom.xml
+++ b/jvm-packages/xgboost4j/pom.xml
@@ -2,131 +2,132 @@
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
+ * The margin must have the same number of elements as the number of
+ * rows in this matrix.
+ */
+ public void setBaseMargin(float[] baseMargin) throws XGBoostError {
+ if (baseMargin.length != rowNum()) {
+ throw new IllegalArgumentException(String.format(
+ "base margin must have exactly %s elements, got %s",
+ rowNum(), baseMargin.length));
+ }
+
+ XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
+ }
+
+ /**
+ * Set base margin (initial prediction).
+ */
+ public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
+ setBaseMargin(flatten(baseMargin));
+ }
+
/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
*
@@ -448,22 +500,6 @@ public long getHandle() {
return handle;
}
- /**
- * flatten a mat to array
- */
- private static float[] flatten(float[][] mat) {
- int size = 0;
- for (float[] array : mat) size += array.length;
- float[] result = new float[size];
- int pos = 0;
- for (float[] ar : mat) {
- System.arraycopy(ar, 0, result, pos, ar.length);
- pos += ar.length;
- }
-
- return result;
- }
-
@Override
protected void finalize() {
dispose();
@@ -475,4 +511,12 @@ public synchronized void dispose() {
handle = 0;
}
}
+
+ /**
+ * sparse matrix type (CSR or CSC)
+ */
+ public enum SparseType {
+ CSR,
+ CSC
+ }
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
index 48b163a7753b..3fe787be2f7e 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/RabitTracker.java
@@ -89,7 +89,7 @@ public boolean start() throws XGBoostError {
this.trackerDaemon = new Thread(() -> {
try {
waitFor(0);
- } catch (XGBoostError ex) {
+ } catch (Exception ex) {
logger.error(ex);
return; // exit the thread
}
diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
index b410d2be1d02..00413636e0f0 100644
--- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
+++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java
@@ -54,7 +54,7 @@ static void checkCall(int ret) throws XGBoostError {
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
final static native int XGDMatrixCreateFromDataIter(java.util.Iterator