Skip to content

Commit

Permalink
test the group col which should be sorted in each partition
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed Sep 14, 2024
1 parent ece0b9b commit eeca573
Showing 1 changed file with 21 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.ml.linalg.{DenseVector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.scalatest.funsuite.AnyFunSuite

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
Expand Down Expand Up @@ -131,6 +131,26 @@ class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite
)
}

test("The group col should be sorted in each partition") {
val trainingDF = buildDataFrameWithGroup(Ranking.train)

val ranker = new XGBoostRanker()
.setNumRound(1)
.setNumWorkers(numWorkers)
.setGroupCol("group")

val (df, _) = ranker.preprocess(trainingDF)
df.rdd.foreachPartition { iter => {
var prevGroup = Int.MinValue
while (iter.hasNext) {
val curr = iter.next()
val group = curr.asInstanceOf[Row].getAs[Int](2)
assert(prevGroup <= group)
prevGroup = group
}
}}
}

private def runLengthEncode(input: Seq[Int]): Seq[Int] = {
if (input.isEmpty) return Seq(0)

Expand Down

0 comments on commit eeca573

Please sign in to comment.