Skip to content

Commit e367ea3

Browse files
author
Xuye (Chris) Qin
authored
Refine MarsDMatrix & support more parameters for XGB classifier and regressor (#2498)
1 parent 3a57fe7 commit e367ea3

File tree

10 files changed

+372
-236
lines changed

10 files changed

+372
-236
lines changed

.github/workflows/core-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
source ./ci/reload-env.sh
4040
export DEFAULT_VENV=$VIRTUAL_ENV
4141
42-
if [[ ! "$PYTHON" =~ "3.9" ]]; then
42+
if [[ ! "$PYTHON" =~ "3.6" ]]; then
4343
conda install -n test --quiet --yes -c conda-forge python=$PYTHON numba
4444
fi
4545

mars/learn/contrib/xgboost/classifier.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
from xgboost.sklearn import XGBClassifierBase
2222

2323
from .... import tensor as mt
24-
from .dmatrix import MarsDMatrix
25-
from .core import evaluation_matrices
24+
from .core import wrap_evaluation_matrices
2625
from .train import train
2726
from .predict import predict
2827

@@ -31,14 +30,16 @@ class XGBClassifier(XGBScikitLearnBase, XGBClassifierBase):
3130
Implementation of the scikit-learn API for XGBoost classification.
3231
"""
3332

34-
def fit(self, X, y, sample_weights=None, eval_set=None, sample_weight_eval_set=None, **kw):
33+
def fit(self, X, y, sample_weight=None, base_margin=None,
34+
eval_set=None, sample_weight_eval_set=None, base_margin_eval_set=None, **kw):
3535
session = kw.pop('session', None)
3636
run_kwargs = kw.pop('run_kwargs', dict())
3737
if kw:
3838
raise TypeError(f"fit got an unexpected keyword argument '{next(iter(kw))}'")
3939

40-
dtrain = MarsDMatrix(X, label=y, weight=sample_weights,
41-
session=session, run_kwargs=run_kwargs)
40+
dtrain, evals = wrap_evaluation_matrices(
41+
None, X, y, sample_weight, base_margin, eval_set,
42+
sample_weight_eval_set, base_margin_eval_set)
4243
params = self.get_xgb_params()
4344

4445
self.classes_ = mt.unique(y, aggregate_size=1).to_numpy(session=session, **run_kwargs)
@@ -50,8 +51,6 @@ def fit(self, X, y, sample_weights=None, eval_set=None, sample_weight_eval_set=N
5051
else:
5152
params['objective'] = 'binary:logistic'
5253

53-
evals = evaluation_matrices(eval_set, sample_weight_eval_set,
54-
session=session, run_kwargs=run_kwargs)
5554
self.evals_result_ = dict()
5655
result = train(params, dtrain, num_boost_round=self.get_num_boosting_rounds(),
5756
evals=evals, evals_result=self.evals_result_,

mars/learn/contrib/xgboost/core.py

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Any, Callable, List, Optional, Tuple
16+
1517
try:
1618
import xgboost
1719
except ImportError:
@@ -61,34 +63,82 @@ def predict(self, data, **kw):
6163
"""
6264
raise NotImplementedError
6365

64-
def evaluation_matrices(validation_set, sample_weights, session=None, run_kwargs=None):
65-
"""
66-
Parameters
67-
----------
68-
validation_set: list of tuples
69-
Each tuple contains a validation dataset including input X and label y.
70-
E.g.:
71-
.. code-block:: python
72-
[(X_0, y_0), (X_1, y_1), ... ]
73-
sample_weights: list of arrays
74-
The weight vector for validation data.
75-
session:
76-
Session to run
77-
run_kwargs:
78-
kwargs for session.run
79-
Returns
80-
-------
81-
evals: list of validation MarsDMatrix
66+
def wrap_evaluation_matrices(
    missing: Optional[float],
    X: Any,
    y: Any,
    sample_weight: Optional[Any],
    base_margin: Optional[Any],
    eval_set: Optional[List[Tuple[Any, Any]]],
    sample_weight_eval_set: Optional[List[Any]],
    base_margin_eval_set: Optional[List[Any]],
    label_transform: Callable = lambda x: x,
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
    """
    Convert array_like training and evaluation data into MarsDMatrix.
    Perform validation on the way.

    Parameters
    ----------
    missing:
        Forwarded unchanged as the ``missing`` argument of every MarsDMatrix
        built here. May be None.
    X:
        Training input data.
    y:
        Training labels; passed through ``label_transform`` before use.
    sample_weight:
        Weight for the training data, or None.
    base_margin:
        Base margin for the training data, or None.
    eval_set: list of tuples
        Each tuple contains a validation dataset including input X and label y.
        E.g.:
        .. code-block:: python
            [(X_0, y_0), (X_1, y_1), ... ]
    sample_weight_eval_set: list
        One weight entry per item of ``eval_set``, or None.
    base_margin_eval_set: list
        One base margin entry per item of ``eval_set``, or None.
    label_transform: callable
        Applied to ``y`` and to every validation label. Identity by default.

    Returns
    -------
    train_dmatrix:
        MarsDMatrix built from the training data.
    evals: list of tuples
        ``(MarsDMatrix, name)`` pairs named ``validation_0``, ``validation_1``,
        ... in ``eval_set`` order. Empty list when ``eval_set`` is None.

    Raises
    ------
    ValueError
        If ``eval_set`` is None while ``sample_weight_eval_set`` or
        ``base_margin_eval_set`` is given, or if either of those lists does
        not have the same length as ``eval_set``.
    """
    train_dmatrix = MarsDMatrix(
        data=X,
        label=label_transform(y),
        weight=sample_weight,
        base_margin=base_margin,
        missing=missing,
    )

    if eval_set is None:
        # Per-validation meta info is meaningless without validation data.
        if sample_weight_eval_set is not None or base_margin_eval_set is not None:
            raise ValueError(
                "`eval_set` is not set but one of the other evaluation meta info is "
                "not None."
            )
        return train_dmatrix, []

    n_validation = len(eval_set)

    def validate_or_none(meta: Optional[List], name: str) -> List:
        # Normalize a missing meta list to one None per validation entry so
        # that it can be indexed uniformly below.
        if meta is None:
            return [None] * n_validation
        if len(meta) != n_validation:
            raise ValueError(
                f"{name}'s length does not equal `eval_set`'s length, "
                f"expecting {n_validation}, got {len(meta)}"
            )
        return meta

    sample_weight_eval_set = validate_or_none(
        sample_weight_eval_set, "sample_weight_eval_set"
    )
    base_margin_eval_set = validate_or_none(
        base_margin_eval_set, "base_margin_eval_set"
    )

    evals = []
    for i, (valid_X, valid_y) in enumerate(eval_set):
        valid_weight = sample_weight_eval_set[i]
        valid_margin = base_margin_eval_set[i]
        # Identity (not equality) comparison: an entry made of the very same
        # objects as the training data reuses the training matrix instead of
        # building a duplicate one.
        is_training_data = (
            valid_X is X
            and valid_y is y
            and valid_weight is sample_weight
            and valid_margin is base_margin
        )
        if is_training_data:
            dmatrix = train_dmatrix
        else:
            dmatrix = MarsDMatrix(
                data=valid_X,
                label=label_transform(valid_y),
                weight=valid_weight,
                base_margin=valid_margin,
                missing=missing,
            )
        evals.append((dmatrix, f"validation_{i}"))

    return train_dmatrix, evals

0 commit comments

Comments
 (0)