Skip to content

Commit 38dd91f

Browse files
authored
Save model in ubj as the default. (#9947)
1 parent c03a4d5 commit 38dd91f

File tree

23 files changed

+600
-552
lines changed

23 files changed

+600
-552
lines changed

jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala

-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ import org.apache.spark.ml.param.Params
3030
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
3131

3232
abstract class XGBoostWriter extends MLWriter {
33-
34-
/** Currently it's using the "deprecated" format as
35-
* default, which will be changed into `ubj` in future releases. */
3633
def getModelFormat(): String = {
3734
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
3835
}

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala

+9-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright (c) 2014-2022 by Contributors
2+
Copyright (c) 2014-2024 by Contributors
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -432,28 +432,29 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
432432
val xgb = new XGBoostClassifier(paramMap)
433433
val model = xgb.fit(trainingDF)
434434

435+
// test json
435436
val modelPath = new File(tempDir.toFile, "xgbc").getPath
436437
model.write.option("format", "json").save(modelPath)
437438
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
438439
model.nativeBooster.saveModel(nativeJsonModelPath)
439440
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
440441
nativeJsonModelPath))
441442

442-
// test default "deprecated"
443+
// test ubj
443444
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
444445
model.write.save(modelUbjPath)
445-
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
446-
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
446+
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
447+
model.nativeBooster.saveModel(nativeUbjModelPath)
447448
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
448-
nativeDeprecatedModelPath))
449+
nativeUbjModelPath))
449450

450451
// json file should be different from ubj file
451452
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
452453
model.write.option("format", "json").save(modelJsonPath)
453-
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
454-
model.nativeBooster.saveModel(nativeUbjModelPath)
454+
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
455+
model.nativeBooster.saveModel(nativeUbjModelPath1)
455456
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
456-
nativeUbjModelPath))
457+
nativeUbjModelPath1))
457458
}
458459

459460
test("native json model file should store feature_name and feature_type") {

jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala

+15-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright (c) 2014-2022 by Contributors
2+
Copyright (c) 2014-2024 by Contributors
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -333,21 +333,24 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
333333
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
334334
nativeJsonModelPath))
335335

336-
// test default "deprecated"
336+
// test default "ubj"
337337
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
338338
model.write.save(modelUbjPath)
339-
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
340-
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
341-
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
342-
nativeDeprecatedModelPath))
343339

344-
// json file should be different from ubj file
345-
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
346-
model.write.option("format", "json").save(modelJsonPath)
347-
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
340+
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
348341
model.nativeBooster.saveModel(nativeUbjModelPath)
349-
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,
342+
343+
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
350344
nativeUbjModelPath))
351-
}
352345

346+
// test the deprecated format
347+
val modelDeprecatedPath = new File(tempDir.toFile, "modelDeprecated").getPath
348+
model.write.option("format", "deprecated").save(modelDeprecatedPath)
349+
350+
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel.deprecated").getPath
351+
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
352+
353+
assert(compareTwoFiles(new File(modelDeprecatedPath, "data/XGBoostRegressionModel").getPath,
354+
nativeDeprecatedModelPath))
355+
}
353356
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
* Booster for xgboost, this is a model API that supports interactive building of an XGBoost Model
3535
*/
3636
public class Booster implements Serializable, KryoSerializable {
37-
public static final String DEFAULT_FORMAT = "deprecated";
37+
public static final String DEFAULT_FORMAT = "ubj";
3838
private static final Log logger = LogFactory.getLog(Booster.class);
3939
// handle to the booster.
4040
private long handle = 0;
@@ -788,8 +788,7 @@ private Map<String, Double> getFeatureImportanceFromModel(
788788
}
789789

790790
/**
791-
* Save model into raw byte array. Currently it's using the deprecated format as
792-
* default, which will be changed into `ubj` in future releases.
791+
* Save model into raw byte array in the UBJSON ("ubj") format.
793792
*
794793
* @return the saved byte array
795794
* @throws XGBoostError native error

jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala

+1-2
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
337337
}
338338

339339
/**
340-
* Save model into a raw byte array. Currently it's using the deprecated format as
341-
* default, which will be changed into `ubj` in future releases.
340+
* Save model into a raw byte array in the UBJSON ("ubj") format.
342341
*/
343342
@throws(classOf[XGBoostError])
344343
def toByteArray: Array[Byte] = {

python-package/xgboost/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2613,7 +2613,7 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None:
26132613
else:
26142614
raise TypeError("fname must be a string or os PathLike")
26152615

2616-
def save_raw(self, raw_format: str = "deprecated") -> bytearray:
2616+
def save_raw(self, raw_format: str = "ubj") -> bytearray:
26172617
"""Save the model to an in-memory buffer representation instead of a file.
26182618
26192619
Parameters

python-package/xgboost/testing/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def random_csc(t_id: int) -> sparse.csc_matrix:
630630

631631
def make_datasets_with_margin(
632632
unweighted_strategy: strategies.SearchStrategy,
633-
) -> Callable:
633+
) -> Callable[[], strategies.SearchStrategy[TestDataset]]:
634634
"""Factory function for creating strategies that generate datasets with weight and
635635
base margin.
636636
@@ -668,8 +668,7 @@ def weight_margin(draw: Callable) -> TestDataset:
668668

669669
# A strategy for drawing from a set of example datasets. May add random weights to the
670670
# dataset
671-
@memory.cache
672-
def make_dataset_strategy() -> Callable:
671+
def make_dataset_strategy() -> strategies.SearchStrategy[TestDataset]:
673672
_unweighted_datasets_strategy = strategies.sampled_from(
674673
[
675674
TestDataset(

src/c_api/c_api.cc

+7-9
Original file line numberDiff line numberDiff line change
@@ -1313,10 +1313,8 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
13131313

13141314
namespace {
13151315
void WarnOldModel() {
1316-
if (XGBOOST_VER_MAJOR >= 2) {
1317-
LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or "
1318-
"`ubj`. Model format will default to JSON in XGBoost 2.2 if not specified.";
1319-
}
1316+
LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or "
1317+
"`ubj`. Model format is default to UBJSON in XGBoost 2.1 if not specified.";
13201318
}
13211319
} // anonymous namespace
13221320

@@ -1339,14 +1337,14 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) {
13391337
save_json(std::ios::out);
13401338
} else if (common::FileExtension(fname) == "ubj") {
13411339
save_json(std::ios::binary);
1342-
} else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) {
1343-
LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or "
1344-
"`deprecated` to choose between formats.";
1345-
save_json(std::ios::out);
1346-
} else {
1340+
} else if (common::FileExtension(fname) == "deprecated") {
13471341
WarnOldModel();
13481342
auto *bst = static_cast<Learner *>(handle);
13491343
bst->SaveModel(fo.get());
1344+
} else {
1345+
LOG(WARNING) << "Saving model in the UBJSON format as default. You can use file extension:"
1346+
" `json`, `ubj` or `deprecated` to choose between formats.";
1347+
save_json(std::ios::binary);
13501348
}
13511349
API_END();
13521350
}

tests/ci_build/lint_python.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class LintersPaths:
2727
"tests/python/test_quantile_dmatrix.py",
2828
"tests/python/test_tree_regularization.py",
2929
"tests/python/test_shap.py",
30+
"tests/python/test_model_io.py",
3031
"tests/python/test_with_pandas.py",
3132
"tests/python-gpu/",
3233
"tests/python-sycl/",
@@ -83,6 +84,7 @@ class LintersPaths:
8384
"tests/python/test_multi_target.py",
8485
"tests/python-gpu/test_gpu_data_iterator.py",
8586
"tests/python-gpu/load_pickle.py",
87+
"tests/python/test_model_io.py",
8688
"tests/test_distributed/test_with_spark/test_data.py",
8789
"tests/test_distributed/test_gpu_with_spark/test_data.py",
8890
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",

0 commit comments

Comments
 (0)