Skip to content

Commit

Permalink
unify the checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 17, 2023
1 parent d11bfc0 commit c3c21f3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ object GpuPreXGBoost extends PreXGBoostProvider {
estimator match {
case est: XGBoostEstimatorCommon =>
require(
est.isDefined(est.device) && est.getDevice.equals("cuda") ||
est.isDefined(est.device) &&
(est.getDevice.equals("cuda") || est.getDevice.equals("gpu")) ||
est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
s"GPU train requires `device` set to `cuda`"
s"GPU train requires `device` set to `cuda` or `gpu`."
)
val groupName = estimator match {
case regressor: XGBoostRegressor => if (regressor.isDefined(regressor.groupCol)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2021-2022 by Contributors
Copyright (c) 2021-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -150,7 +150,6 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
}

test("Train with eval") {

withGpuSparkSession() { spark =>
import spark.implicits._
val Array(trainingDf, eval1, eval2) = trainingData.toDF(allColumnNames: _*)
Expand Down Expand Up @@ -190,4 +189,24 @@ class GpuXGBoostGeneralSuite extends GpuTestSuite {
}
}

test("device ordinal should not be specified") {
withGpuSparkSession() { spark =>
import spark.implicits._
val trainingDf = trainingData.toDF(allColumnNames: _*)
val params = Map(
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 5,
"num_workers" -> 1
)
val thrown = intercept[IllegalArgumentException] {
new XGBoostClassifier(params)
.setFeaturesCol(featureNames)
.setLabelCol(labelName)
.setDevice("cuda:1")
.fit(trainingDf)
}
assert(thrown.getMessage.contains("`cuda` or `gpu`"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
case None => None
case Some(dev: String) => if (treeMethod == "gpu_hist") Some("cuda") else Some(dev)
}
if (!device.isEmpty) {
require(
!device.contains(":"),
"Please don't specify the device ordinal as GPUs are managed by Spark."
)
}

if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class XGBoostClassifier (

def setTreeMethod(value: String): this.type = set(treeMethod, value)

def setDevice(value: String): this.type = set(device, value)

def setGrowPolicy(value: String): this.type = set(growPolicy, value)

def setMaxBins(value: Int): this.type = set(maxBins, value)
Expand Down

0 comments on commit c3c21f3

Please sign in to comment.