diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
index 7d42995e4ff6..81a770bfe327 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.ml.linalg.{DenseVector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.scalatest.funsuite.AnyFunSuite
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
@@ -131,6 +131,26 @@ class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite
     )
   }
 
+  test("The group col should be sorted in each partition") {
+    val trainingDF = buildDataFrameWithGroup(Ranking.train)
+
+    val ranker = new XGBoostRanker()
+      .setNumRound(1)
+      .setNumWorkers(numWorkers)
+      .setGroupCol("group")
+
+    val (df, _) = ranker.preprocess(trainingDF)
+    df.rdd.foreachPartition { iter => {
+      var prevGroup = Int.MinValue
+      while (iter.hasNext) {
+        val curr = iter.next()
+        val group = curr.asInstanceOf[Row].getAs[Int](2)
+        assert(prevGroup <= group)
+        prevGroup = group
+      }
+    }}
+  }
+
   private def runLengthEncode(input: Seq[Int]): Seq[Int] = {
     if (input.isEmpty) return Seq(0)