diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala index cb973ce0217a..275263a34ef5 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala @@ -113,7 +113,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin { */ override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]]( estimator: XGBoostEstimator[T, M], - dataset: Dataset[_]): RDD[() => Watches] = { + dataset: Dataset[_]): RDD[Watches] = { validate(estimator, dataset) @@ -148,19 +148,25 @@ class GpuXGBoostPlugin extends XGBoostPlugin { val evalProcessed = preprocess(estimator, evalDs) ColumnarRdd(train.toDF()).zipPartitions(ColumnarRdd(evalProcessed.toDF())) { (trainIter, evalIter) => - Iterator.single(() => { - val trainDM = buildQuantileDMatrix(trainIter) - val evalDM = buildQuantileDMatrix(evalIter, Some(trainDM)) - new Watches(Array(trainDM, evalDM), - Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None) - }) + new Iterator[Watches] { + override def hasNext: Boolean = trainIter.hasNext + override def next(): Watches = { + val trainDM = buildQuantileDMatrix(trainIter) + val evalDM = buildQuantileDMatrix(evalIter, Some(trainDM)) + new Watches(Array(trainDM, evalDM), + Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None) + } + } } }.getOrElse( ColumnarRdd(train.toDF()).mapPartitions { iter => - Iterator.single(() => { - val dm = buildQuantileDMatrix(iter) - new Watches(Array(dm), Array(Utils.TRAIN_NAME), None) - }) + new Iterator[Watches] { + override def hasNext: Boolean = iter.hasNext + override def next(): Watches = { + val dm = buildQuantileDMatrix(iter) + new Watches(Array(dm), Array(Utils.TRAIN_NAME), None) + } + } } ) } diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala index 068ffd54186c..97f54b601eb3 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala @@ -205,7 +205,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { val rdd = classifier.getPlugin.get.buildRddWatches(classifier, df) val result = rdd.mapPartitions { iter => - val watches = iter.next()() + val watches = iter.next() val size = watches.size val labels = watches.datasets(0).getLabel val weight = watches.datasets(0).getWeight @@ -269,7 +269,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite { val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train) val result = rdd.mapPartitions { iter => - val watches = iter.next()() + val watches = iter.next() val size = watches.size val labels = watches.datasets(1).getLabel val weight = watches.datasets(1).getWeight diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 56ac9c5429f2..b4ef1509ca00 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -226,7 +226,7 @@ private[spark] object XGBoost extends StageLevelScheduling { * @param xgboostParams the xgboost parameters to pass to xgboost library * @return the booster and the metrics */ - def train(input: RDD[() => Watches], + def train(input: RDD[Watches], runtimeParams: RuntimeParams, xgboostParams: Map[String, Any]): (Booster, Map[String, Array[Float]]) = { @@ -249,7 +249,7 @@ private[spark] object XGBoost extends StageLevelScheduling { try { Communicator.init(rabitEnv) require(iter.hasNext, "Failed to create DMatrix") - val watches = iter.next()() + val watches = iter.next() try { val (booster, metrics) = trainBooster(watches, runtimeParams, xgboostParams) if (partitionId == 0) { diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala index 6a520e1220cb..27e8cb0b4aa4 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimator.scala @@ -234,10 +234,10 @@ private[spark] trait XGBoostEstimator[ * * @param dataset * @param columnsOrder the order of columns including weight/group/base margin ... - * @return RDD + * @return RDD[Watches] */ private[spark] def toRdd(dataset: Dataset[_], - columnIndices: ColumnIndices): RDD[() => Watches] = { + columnIndices: ColumnIndices): RDD[Watches] = { val trainRDD = toXGBLabeledPoint(dataset, columnIndices) val featureNames = if (getFeatureNames.isEmpty) None else Some(getFeatureNames) @@ -309,20 +309,25 @@ private[spark] trait XGBoostEstimator[ val (evalDf, _) = preprocess(eval) val evalRDD = toXGBLabeledPoint(evalDf, columnIndices) trainRDD.zipPartitions(evalRDD) { (left, right) => - Iterator.single(() => { - val trainDMatrix = buildDMatrix(left) - val evalDMatrix = buildDMatrix(right) - new Watches(Array(trainDMatrix, evalDMatrix), - Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None) - }) + new Iterator[Watches] { + override def hasNext: Boolean = left.hasNext + override def next(): Watches = { + val trainDMatrix = buildDMatrix(left) + val evalDMatrix = buildDMatrix(right) + new Watches(Array(trainDMatrix, evalDMatrix), + Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None) + } + } } }.getOrElse( trainRDD.mapPartitions { iter => - - Iterator.single(() => { - val dm = buildDMatrix(iter) - new Watches(Array(dm), Array(Utils.TRAIN_NAME), None) - }) + new Iterator[Watches] { + override def hasNext: Boolean = iter.hasNext + override def next(): Watches = { + val dm = buildDMatrix(iter) + new Watches(Array(dm), Array(Utils.TRAIN_NAME), None) + } + } } ) } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala index 93587ab647e5..dda82f97968b 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostPlugin.scala @@ -39,7 +39,7 @@ trait XGBoostPlugin extends Serializable { */ def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]]( estimator: XGBoostEstimator[T, M], - dataset: Dataset[_]): RDD[() => Watches] + dataset: Dataset[_]): RDD[Watches] /** * Transform the dataset diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala index 43550f974678..8895789bac0d 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostEstimatorSuite.scala @@ -334,7 +334,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu val rdd = classifier.toRdd(df, indices) val result = rdd.mapPartitions { iter => if (iter.hasNext) { - val watches = iter.next()() + val watches = iter.next() val size = watches.size val trainDM = watches.toMap(TRAIN_NAME) val rowNum = trainDM.rowNum @@ -410,7 +410,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu val rdd = classifier.toRdd(df, indices) val result = rdd.mapPartitions { iter => if (iter.hasNext) { - val watches = iter.next()() + val watches = iter.next() val size = watches.size val evalDM = watches.toMap(Utils.VALIDATION_NAME) val rowNum = evalDM.rowNum