Skip to content

Commit

Permalink
Classifier.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 4, 2024
1 parent 6ff9452 commit 9061201
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -432,28 +432,29 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)

// test json
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeJsonModelPath))

// test default "deprecated"
// test ubj
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))
nativeUbjModelPath))

// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath1)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
nativeUbjModelPath1))
}

test("native json model file should store feature_name and feature_type") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 9061201

Please sign in to comment.