Skip to content

Commit

Permalink
return Iterator to replace function deplaying to build quantile dmatr…
Browse files Browse the repository at this point in the history
…ix (#27)
  • Loading branch information
wbo4958 authored Sep 10, 2024
1 parent 44ac4d6 commit 0feb48f
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
*/
override def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M],
dataset: Dataset[_]): RDD[() => Watches] = {
dataset: Dataset[_]): RDD[Watches] = {

validate(estimator, dataset)

Expand Down Expand Up @@ -148,19 +148,25 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
val evalProcessed = preprocess(estimator, evalDs)
ColumnarRdd(train.toDF()).zipPartitions(ColumnarRdd(evalProcessed.toDF())) {
(trainIter, evalIter) =>
Iterator.single(() => {
val trainDM = buildQuantileDMatrix(trainIter)
val evalDM = buildQuantileDMatrix(evalIter, Some(trainDM))
new Watches(Array(trainDM, evalDM),
Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
})
new Iterator[Watches] {
override def hasNext: Boolean = trainIter.hasNext
override def next(): Watches = {
val trainDM = buildQuantileDMatrix(trainIter)
val evalDM = buildQuantileDMatrix(evalIter, Some(trainDM))
new Watches(Array(trainDM, evalDM),
Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
}
}
}
}.getOrElse(
ColumnarRdd(train.toDF()).mapPartitions { iter =>
Iterator.single(() => {
val dm = buildQuantileDMatrix(iter)
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
})
new Iterator[Watches] {
override def hasNext: Boolean = iter.hasNext
override def next(): Watches = {
val dm = buildQuantileDMatrix(iter)
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
}
}
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {

val rdd = classifier.getPlugin.get.buildRddWatches(classifier, df)
val result = rdd.mapPartitions { iter =>
val watches = iter.next()()
val watches = iter.next()
val size = watches.size
val labels = watches.datasets(0).getLabel
val weight = watches.datasets(0).getWeight
Expand Down Expand Up @@ -269,7 +269,7 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {

val rdd = classifier.getPlugin.get.buildRddWatches(classifier, train)
val result = rdd.mapPartitions { iter =>
val watches = iter.next()()
val watches = iter.next()
val size = watches.size
val labels = watches.datasets(1).getLabel
val weight = watches.datasets(1).getWeight
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ private[spark] object XGBoost extends StageLevelScheduling {
* @param xgboostParams the xgboost parameters to pass to xgboost library
* @return the booster and the metrics
*/
def train(input: RDD[() => Watches],
def train(input: RDD[Watches],
runtimeParams: RuntimeParams,
xgboostParams: Map[String, Any]): (Booster, Map[String, Array[Float]]) = {

Expand All @@ -249,7 +249,7 @@ private[spark] object XGBoost extends StageLevelScheduling {
try {
Communicator.init(rabitEnv)
require(iter.hasNext, "Failed to create DMatrix")
val watches = iter.next()()
val watches = iter.next()
try {
val (booster, metrics) = trainBooster(watches, runtimeParams, xgboostParams)
if (partitionId == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ private[spark] trait XGBoostEstimator[
*
* @param dataset
* @param columnsOrder the order of columns including weight/group/base margin ...
* @return RDD
* @return RDD[Watches]
*/
private[spark] def toRdd(dataset: Dataset[_],
columnIndices: ColumnIndices): RDD[() => Watches] = {
columnIndices: ColumnIndices): RDD[Watches] = {
val trainRDD = toXGBLabeledPoint(dataset, columnIndices)

val featureNames = if (getFeatureNames.isEmpty) None else Some(getFeatureNames)
Expand Down Expand Up @@ -309,20 +309,25 @@ private[spark] trait XGBoostEstimator[
val (evalDf, _) = preprocess(eval)
val evalRDD = toXGBLabeledPoint(evalDf, columnIndices)
trainRDD.zipPartitions(evalRDD) { (left, right) =>
Iterator.single(() => {
val trainDMatrix = buildDMatrix(left)
val evalDMatrix = buildDMatrix(right)
new Watches(Array(trainDMatrix, evalDMatrix),
Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
})
new Iterator[Watches] {
override def hasNext: Boolean = left.hasNext
override def next(): Watches = {
val trainDMatrix = buildDMatrix(left)
val evalDMatrix = buildDMatrix(right)
new Watches(Array(trainDMatrix, evalDMatrix),
Array(Utils.TRAIN_NAME, Utils.VALIDATION_NAME), None)
}
}
}
}.getOrElse(
trainRDD.mapPartitions { iter =>

Iterator.single(() => {
val dm = buildDMatrix(iter)
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
})
new Iterator[Watches] {
override def hasNext: Boolean = iter.hasNext
override def next(): Watches = {
val dm = buildDMatrix(iter)
new Watches(Array(dm), Array(Utils.TRAIN_NAME), None)
}
}
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ trait XGBoostPlugin extends Serializable {
*/
def buildRddWatches[T <: XGBoostEstimator[T, M], M <: XGBoostModel[M]](
estimator: XGBoostEstimator[T, M],
dataset: Dataset[_]): RDD[() => Watches]
dataset: Dataset[_]): RDD[Watches]

/**
* Transform the dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
val rdd = classifier.toRdd(df, indices)
val result = rdd.mapPartitions { iter =>
if (iter.hasNext) {
val watches = iter.next()()
val watches = iter.next()
val size = watches.size
val trainDM = watches.toMap(TRAIN_NAME)
val rowNum = trainDM.rowNum
Expand Down Expand Up @@ -410,7 +410,7 @@ class XGBoostEstimatorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
val rdd = classifier.toRdd(df, indices)
val result = rdd.mapPartitions { iter =>
if (iter.hasNext) {
val watches = iter.next()()
val watches = iter.next()
val size = watches.size
val evalDM = watches.toMap(Utils.VALIDATION_NAME)
val rowNum = evalDM.rowNum
Expand Down

0 comments on commit 0feb48f

Please sign in to comment.