
Commit 69268df

[EM] Drop support for the device cache. (#11446)
1 parent 2fad970 commit 69268df

20 files changed (+58 −270 lines)

doc/tutorials/external_memory.rst

Lines changed: 5 additions & 6 deletions

@@ -147,9 +147,8 @@ stages the cache on CPU memory by default. Users can change the backing storage
 specifying the ``on_host`` parameter in the :py:class:`~xgboost.DataIter`. However, using
 the disk is not recommended as it's likely to make the GPU slower than the CPU. The option
 is here for experimentation purposes only. In addition,
-:py:class:`~xgboost.ExtMemQuantileDMatrix` parameters ``max_num_device_pages``,
-``min_cache_page_bytes``, and ``max_quantile_batches`` can help control the data placement
-and memory usage.
+:py:class:`~xgboost.ExtMemQuantileDMatrix` parameters ``min_cache_page_bytes``, and
+``max_quantile_batches`` can help control the data placement and memory usage.

 Inputs to the :py:class:`~xgboost.ExtMemQuantileDMatrix` (through the iterator) must be on
 the GPU. Following is a snippet from :ref:`sphx_glr_python_examples_external_memory.py`:

@@ -194,9 +193,9 @@ memory. XGBoost relies on the asynchronous memory pool to reduce the overhead of
 fetching. In addition, the open source `NVIDIA Linux driver
 <https://developer.nvidia.com/blog/nvidia-transitions-fully-towards-open-source-gpu-kernel-modules/>`__
 is required for ``Heterogeneous memory management (HMM)`` support. Usually, users need not
-to change :py:class:`~xgboost.ExtMemQuantileDMatrix` parameters ``max_num_device_pages``
-and ``min_cache_page_bytes``, they are automatically configured based on the device and
-don't change model accuracy. However, the ``max_quantile_batches`` can be useful if
+to change :py:class:`~xgboost.ExtMemQuantileDMatrix` parameters like
+``min_cache_page_bytes``, they are automatically configured based on the device and don't
+change model accuracy. However, the ``max_quantile_batches`` can be useful if
 :py:class:`~xgboost.ExtMemQuantileDMatrix` is running out of device memory during
 construction, see :py:class:`~xgboost.QuantileDMatrix` and the following sections for more
 info. Currently, we focus on devices with ``NVLink-C2C`` support for GPU-based external
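
With the device cache gone, the two knobs the tutorial text still mentions are ``min_cache_page_bytes`` (configured on the :py:class:`~xgboost.DataIter`) and ``max_quantile_batches`` (passed to :py:class:`~xgboost.ExtMemQuantileDMatrix`). A minimal sketch of how they are wired together after this change — the iterator class and the synthetic batches below are illustrative, not part of the commit:

    import os

    import cupy as cp  # inputs to ExtMemQuantileDMatrix must already be on the GPU
    import xgboost

    class Batches(xgboost.DataIter):
        """Hypothetical iterator over pre-partitioned (X, y) device batches."""

        def __init__(self, batches):
            self._batches = batches
            self._it = 0
            # min_cache_page_bytes is left unset here so that XGBoost picks the
            # automatic, device-dependent default described above.
            super().__init__(cache_prefix=os.path.join(".", "cache"), on_host=True)

        def next(self, input_data):
            if self._it == len(self._batches):
                return False  # signal the end of an epoch
            X, y = self._batches[self._it]
            input_data(data=X, label=y)
            self._it += 1
            return True

        def reset(self):
            self._it = 0

    batches = [(cp.random.rand(1024, 8), cp.random.rand(1024)) for _ in range(4)]
    # max_quantile_batches caps how many batches the quantile sketcher holds on
    # the device at once; lower it if construction runs out of device memory.
    Xy = xgboost.ExtMemQuantileDMatrix(Batches(batches), max_bin=256, max_quantile_batches=2)
    booster = xgboost.train({"device": "cuda", "tree_method": "hist"}, Xy, num_boost_round=10)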

include/xgboost/data.h

Lines changed: 1 addition & 4 deletions

@@ -538,19 +538,16 @@ struct ExtMemConfig {
   std::int64_t min_cache_page_bytes{0};
   // Missing value.
   float missing{std::numeric_limits<float>::quiet_NaN()};
-  // Maximum number of pages cached in device.
-  std::int64_t max_num_device_pages{0};
   // The number of CPU threads.
   std::int32_t n_threads{0};

   ExtMemConfig() = default;
   ExtMemConfig(std::string cache, bool on_host, std::int64_t min_cache, float missing,
-               std::int64_t max_num_d, std::int32_t n_threads)
+               std::int32_t n_threads)
       : cache{std::move(cache)},
         on_host{on_host},
         min_cache_page_bytes{min_cache},
         missing{missing},
-        max_num_device_pages{max_num_d},
         n_threads{n_threads} {}
 };

jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/ExtMemQuantileDMatrix.java

Lines changed: 5 additions & 9 deletions

@@ -30,7 +30,6 @@ public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
                                int maxBin,
                                DMatrix ref,
                                int nthread,
-                               int maxNumDevicePages,
                                int maxQuantileBatches,
                                int minCachePageBytes) throws XGBoostError {
     long[] out = new long[1];

@@ -39,8 +38,8 @@ public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
       refHandle = new long[1];
       refHandle[0] = ref.getHandle();
     }
-    String conf = this.getConfig(missing, maxBin, nthread, maxNumDevicePages,
-        maxQuantileBatches, minCachePageBytes);
+    String conf = this.getConfig(missing, maxBin, nthread,
+        maxQuantileBatches, minCachePageBytes);
     XGBoostJNI.checkCall(XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback(
         iter, refHandle, conf, out));
     handle = out[0];

@@ -51,7 +50,7 @@ public ExtMemQuantileDMatrix(
       float missing,
       int maxBin,
       DMatrix ref) throws XGBoostError {
-    this(iter, missing, maxBin, ref, 0, -1, -1, -1);
+    this(iter, missing, maxBin, ref, 0, -1, -1);
   }

   public ExtMemQuantileDMatrix(

@@ -61,16 +60,13 @@ public ExtMemQuantileDMatrix(
     this(iter, missing, maxBin, null);
   }

-  private String getConfig(float missing, int maxBin, int nthread, int maxNumDevicePages,
-      int maxQuantileBatches, int minCachePageBytes) {
+  private String getConfig(float missing, int maxBin, int nthread,
+      int maxQuantileBatches, int minCachePageBytes) {
     Map<String, Object> conf = new java.util.HashMap<>();
     conf.put("missing", missing);
     conf.put("max_bin", maxBin);
     conf.put("nthread", nthread);

-    if (maxNumDevicePages > 0) {
-      conf.put("max_num_device_pages", maxNumDevicePages);
-    }
     if (maxQuantileBatches > 0) {
       conf.put("max_quantile_batches", maxQuantileBatches);
     }
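
For reference, getConfig above is what serializes these options into the JSON document handed to XGExtMemQuantileDMatrixCreateFromCallback. A rough Python mirror of its logic after this change (the values are made up; only the key set matters):

    import json

    # Optional keys are emitted only when set to a positive value, matching the
    # JVM convention that -1 means "let XGBoost decide". The key
    # "max_num_device_pages" is simply never produced any more.
    def get_config(missing, max_bin, nthread, max_quantile_batches, min_cache_page_bytes):
        conf = {"missing": missing, "max_bin": max_bin, "nthread": nthread}
        if max_quantile_batches > 0:
            conf["max_quantile_batches"] = max_quantile_batches
        if min_cache_page_bytes > 0:
            conf["min_cache_page_bytes"] = min_cache_page_bytes
        return json.dumps(conf)

    print(get_config(float("nan"), 256, 0, -1, -1))
    # {"missing": NaN, "max_bin": 256, "nthread": 0}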

jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/ExtMemQuantileDMatrix.scala

Lines changed: 1 addition & 2 deletions

@@ -27,12 +27,11 @@ class ExtMemQuantileDMatrix private[scala](
       maxBin: Int,
       ref: Option[QuantileDMatrix],
       nthread: Int,
-      maxNumDevicePages: Int,
       maxQuantileBatches: Int,
       minCachePageBytes: Int) {
     this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin,
       ref.map(_.jDMatrix).orNull,
-      nthread, maxNumDevicePages, maxQuantileBatches, minCachePageBytes))
+      nthread, maxQuantileBatches, minCachePageBytes))
   }

   def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int) {

jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPlugin.scala

Lines changed: 1 addition & 2 deletions

@@ -134,7 +134,6 @@ class GpuXGBoostPlugin extends XGBoostPlugin {

     val maxQuantileBatches = estimator.getMaxQuantileBatches
     val minCachePageBytes = estimator.getMinCachePageBytes
-    val maxNumDevicePages = estimator.getMaxNumDevicePages

     /** build QuantileDMatrix on the executor side */
     def buildQuantileDMatrix(input: Iterator[Table],

@@ -143,7 +142,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
       extMemPath match {
         case Some(_) =>
           val itr = new ExternalMemoryIterator(input, indices, extMemPath)
-          new ExtMemQuantileDMatrix(itr, missing, maxBin, ref, nthread, maxNumDevicePages,
+          new ExtMemQuantileDMatrix(itr, missing, maxBin, ref, nthread,
             maxQuantileBatches, minCachePageBytes)

         case None =>

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/XGBoostParams.scala

Lines changed: 1 addition & 8 deletions

@@ -188,11 +188,6 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe

   final def getUseExternalMemory: Boolean = $(useExternalMemory)

-  final val maxNumDevicePages = new IntParam(this, "maxNumDevicePages", "Maximum number of " +
-    "pages cached in device")
-
-  final def getMaxNumDevicePages: Int = $(maxNumDevicePages)
-
   final val maxQuantileBatches = new IntParam(this, "maxQuantileBatches", "Maximum quantile " +
     "batches")

@@ -207,7 +202,7 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe
     numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
     featuresCols -> Array.empty, customObj -> null, customEval -> null,
     featureNames -> Array.empty, featureTypes -> Array.empty, useExternalMemory -> false,
-    maxNumDevicePages -> -1, maxQuantileBatches -> -1, minCachePageBytes -> -1)
+    maxQuantileBatches -> -1, minCachePageBytes -> -1)

   addNonXGBoostParam(numWorkers, numRound, numEarlyStoppingRounds, inferBatchSize, featuresCol,
     labelCol, baseMarginCol, weightCol, predictionCol, leafPredictionCol, contribPredictionCol,

@@ -251,8 +246,6 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe

   def setUseExternalMemory(value: Boolean): T = set(useExternalMemory, value).asInstanceOf[T]

-  def setMaxNumDevicePages(value: Int): T = set(maxNumDevicePages, value).asInstanceOf[T]
-
   def setMaxQuantileBatches(value: Int): T = set(maxQuantileBatches, value).asInstanceOf[T]

   def setMinCachePageBytes(value: Int): T = set(minCachePageBytes, value).asInstanceOf[T]

python-package/xgboost/core.py

Lines changed: 15 additions & 28 deletions

@@ -1820,7 +1820,6 @@ def __init__(  # pylint: disable=super-init-not-called
         max_bin: Optional[int] = None,
         ref: Optional[DMatrix] = None,
         enable_categorical: bool = False,
-        max_num_device_pages: Optional[int] = None,
         max_quantile_batches: Optional[int] = None,
     ) -> None:
         """

@@ -1829,15 +1828,6 @@ def __init__(  # pylint: disable=super-init-not-called
         data :
             A user-defined :py:class:`DataIter` for loading data.

-        max_num_device_pages :
-            For a GPU-based validation dataset, XGBoost can optionally cache some pages
-            in device memory instead of host memory to reduce data transfer. Each cached
-            page has size of `min_cache_page_bytes`. Set this to 0 if you don't want
-            pages to be cached in the device memory. This can be useful for preventing
-            OOM error where there are more than one validation datasets. The default
-            number of device-based page is 1. Lastly, XGBoost infers whether a dataset
-            is used for valdiation by checking whether ref is not None.
-
         max_quantile_batches :
             See :py:class:`QuantileDMatrix`.

@@ -1850,7 +1840,6 @@ def __init__(  # pylint: disable=super-init-not-called
             data,
             ref,
             enable_categorical=enable_categorical,
-            max_num_device_pages=max_num_device_pages,
             max_quantile_blocks=max_quantile_batches,
         )
         assert self.handle is not None

@@ -1861,7 +1850,6 @@ def _init(
         ref: Optional[DMatrix],
         *,
         enable_categorical: bool,
-        max_num_device_pages: Optional[int] = None,
         max_quantile_blocks: Optional[int] = None,
     ) -> None:
         args = make_jcargs(

@@ -1871,7 +1859,6 @@ def _init(
             on_host=it.on_host,
             max_bin=self.max_bin,
             min_cache_page_bytes=it.min_cache_page_bytes,
-            max_num_device_pages=max_num_device_pages,
             # It's called blocks internally due to block-based quantile sketching.
             max_quantile_blocks=max_quantile_blocks,
         )

@@ -2559,9 +2546,9 @@ def predict(
             prediction. Note the final column is the bias term.

         approx_contribs :
-            Approximate the contributions of each feature. Used when ``pred_contribs`` or
-            ``pred_interactions`` is set to True. Changing the default of this parameter
-            (False) is not recommended.
+            Approximate the contributions of each feature. Used when ``pred_contribs``
+            or ``pred_interactions`` is set to True. Changing the default of this
+            parameter (False) is not recommended.

         pred_interactions :
             When this is True the output will be a matrix of size (nsample,

@@ -2579,10 +2566,10 @@ def predict(

         training :
             Whether the prediction value is used for training. This can effect `dart`
-            booster, which performs dropouts during training iterations but use all trees
-            for inference. If you want to obtain result with dropouts, set this parameter
-            to `True`. Also, the parameter is set to true when obtaining prediction for
-            custom objective function.
+            booster, which performs dropouts during training iterations but use all
+            trees for inference. If you want to obtain result with dropouts, set this
+            parameter to `True`. Also, the parameter is set to true when obtaining
+            prediction for custom objective function.

         .. versionadded:: 1.0.0

@@ -2595,8 +2582,8 @@ def predict(
         .. versionadded:: 1.4.0

         strict_shape :
-            When set to True, output shape is invariant to whether classification is used.
-            For both value and margin prediction, the output shape is (n_samples,
+            When set to True, output shape is invariant to whether classification is
+            used. For both value and margin prediction, the output shape is (n_samples,
             n_groups), n_groups == 1 when multi-class is not used. Default to False, in
             which case the output shape can be (n_samples, ) if multi-class is not used.

@@ -3116,8 +3103,8 @@ def get_fscore(self, fmap: PathLike = "") -> Dict[str, Union[float, List[float]]

         .. note:: Zero-importance features will not be included

-        Keep in mind that this function does not include zero-importance feature, i.e.
-        those features that have not been used in any split conditions.
+        Keep in mind that this function does not include zero-importance feature,
+        i.e. those features that have not been used in any split conditions.

         Parameters
         ----------

@@ -3141,13 +3128,13 @@ def get_score(

         .. note::

-            For linear model, only "weight" is defined and it's the normalized coefficients
-            without bias.
+            For linear model, only "weight" is defined and it's the normalized
+            coefficients without bias.

         .. note:: Zero-importance features will not be included

-        Keep in mind that this function does not include zero-importance feature, i.e.
-        those features that have not been used in any split conditions.
+        Keep in mind that this function does not include zero-importance feature,
+        i.e. those features that have not been used in any split conditions.

         Parameters
         ----------
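
Aside from dropping ``max_num_device_pages``, the remaining hunks above are docstring reflows. The ``strict_shape`` behaviour they describe is easiest to see in a tiny example (synthetic data, unrelated to this commit):

    import numpy as np
    import xgboost

    X, y = np.random.rand(32, 4), np.random.randint(0, 2, 32)
    Xy = xgboost.DMatrix(X, label=y)
    booster = xgboost.train({"objective": "binary:logistic"}, Xy, num_boost_round=5)

    # Default: binary classification yields a flat vector of probabilities.
    assert booster.predict(Xy).shape == (32,)
    # strict_shape pins the output to (n_samples, n_groups) for every task,
    # so binary classification becomes a single-group 2-D array.
    assert booster.predict(Xy, strict_shape=True).shape == (32, 1)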

src/c_api/c_api.cc

Lines changed: 2 additions & 7 deletions

@@ -330,14 +330,12 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
   xgboost_CHECK_C_ARG_PTR(reset);
   xgboost_CHECK_C_ARG_PTR(out);

-  auto config = ExtMemConfig{
-      cache, on_host, min_cache_page_bytes, missing, /*max_num_device_pages=*/0, n_threads};
+  auto config = ExtMemConfig{cache, on_host, min_cache_page_bytes, missing, n_threads};
   *out = new std::shared_ptr<xgboost::DMatrix>{
       xgboost::DMatrix::Create(iter, proxy, reset, next, config)};
   API_END();
 }

-
 namespace {
 std::shared_ptr<DMatrix> GetRefDMatrix(DataIterHandle ref) {
   std::shared_ptr<DMatrix> _ref{nullptr};

@@ -393,17 +391,14 @@ XGB_DLL int XGExtMemQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatr
   std::string cache = RequiredArg<String>(jconfig, "cache_prefix", __func__);
   auto min_cache_page_bytes = OptionalArg<Integer, std::int64_t>(jconfig, "min_cache_page_bytes",
                                                                  cuda_impl::AutoCachePageBytes());
-  auto max_num_device_pages = OptionalArg<Integer, std::int64_t>(jconfig, "max_num_device_pages",
-                                                                 cuda_impl::MaxNumDevicePages());
   auto max_quantile_blocks = OptionalArg<Integer, std::int64_t>(
       jconfig, "max_quantile_blocks", std::numeric_limits<std::int64_t>::max());

   xgboost_CHECK_C_ARG_PTR(next);
   xgboost_CHECK_C_ARG_PTR(reset);
   xgboost_CHECK_C_ARG_PTR(out);

-  auto config =
-      ExtMemConfig{cache, on_host, min_cache_page_bytes, missing, max_num_device_pages, n_threads};
+  auto config = ExtMemConfig{cache, on_host, min_cache_page_bytes, missing, n_threads};
   *out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
       iter, proxy, p_ref, reset, next, max_bin, max_quantile_blocks, config)};
   API_END();

src/data/batch_utils.h

Lines changed: 0 additions & 2 deletions

@@ -39,8 +39,6 @@ void CheckParam(BatchParam const& init, BatchParam const& param);
 namespace xgboost::cuda_impl {
 // Indicator for XGBoost to not concatenate any page.
 constexpr std::int64_t MatchingPageBytes() { return 0; }
-// Maxmimum number of pages from the validation dataset to be cached in the device memory.
-constexpr std::int32_t MaxNumDevicePages() { return 1; }
 // Default size of the cached page
 constexpr double CachePageRatio() { return 0.125; }
 // Indicator for XGBoost to automatically concatenate pages.

src/data/data.cc

Lines changed: 2 additions & 6 deletions

@@ -950,12 +950,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
     CHECK(data_split_mode != DataSplitMode::kCol)
         << "Column-wise data split is not supported for external memory.";
     data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart)};
-    auto config = ExtMemConfig{cache_file,
-                               false,
-                               cuda_impl::MatchingPageBytes(),
-                               std::numeric_limits<float>::quiet_NaN(),
-                               cuda_impl::MaxNumDevicePages(),
-                               1};
+    auto config = ExtMemConfig{cache_file, false, cuda_impl::MatchingPageBytes(),
+                               std::numeric_limits<float>::quiet_NaN(), 1};
     dmat = new data::SparsePageDMatrix{&iter, iter.Proxy(), data::fileiter::Reset,
                                        data::fileiter::Next, config};
   }
