|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +from typing import Any, Callable, List, Optional, Tuple |
| 16 | + |
15 | 17 | try:
|
16 | 18 | import xgboost
|
17 | 19 | except ImportError:
|
@@ -61,34 +63,82 @@ def predict(self, data, **kw):
|
61 | 63 | """
|
62 | 64 | raise NotImplementedError
|
63 | 65 |
|
def wrap_evaluation_matrices(
    missing: float,
    X: Any,
    y: Any,
    sample_weight: Optional[Any],
    base_margin: Optional[Any],
    eval_set: Optional[List[Tuple[Any, Any]]],
    sample_weight_eval_set: Optional[List[Any]],
    base_margin_eval_set: Optional[List[Any]],
    label_transform: Callable = lambda x: x,
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
    """Wrap training and evaluation data into ``MarsDMatrix`` objects.

    Parameters
    ----------
    missing:
        Forwarded as ``missing=`` to every ``MarsDMatrix`` that is built.
    X, y:
        Training data and labels. ``y`` is passed through
        ``label_transform`` before being used as the label.
    sample_weight, base_margin:
        Optional per-training-set meta info, forwarded as ``weight=`` and
        ``base_margin=`` respectively.
    eval_set:
        Optional list of ``(X, y)`` pairs to evaluate on.
    sample_weight_eval_set, base_margin_eval_set:
        Optional lists of meta info, one entry per ``eval_set`` pair.
        ``None`` means "no meta info for any evaluation pair".
    label_transform:
        Callable applied to every label (training and evaluation) before
        the matrix is constructed. Defaults to the identity.

    Returns
    -------
    train_dmatrix, evals
        The training matrix, and a list of ``(matrix, "validation_<i>")``
        tuples in ``eval_set`` order. The list is empty when ``eval_set``
        is ``None`` or empty.

    Raises
    ------
    ValueError
        If an evaluation meta list has a different length than
        ``eval_set``, or if meta info is supplied without ``eval_set``.
    """

    def build_matrix(data: Any, label: Any, weight: Any, margin: Any) -> Any:
        # Single construction path shared by training and evaluation data.
        return MarsDMatrix(
            data=data,
            label=label_transform(label),
            weight=weight,
            base_margin=margin,
            missing=missing,
        )

    # The training matrix is built first, before any evaluation-side checks.
    train_dmatrix = build_matrix(X, y, sample_weight, base_margin)

    if eval_set is None:
        # Meta info without an evaluation set is a caller mistake.
        if sample_weight_eval_set is not None or base_margin_eval_set is not None:
            raise ValueError(
                "`eval_set` is not set but one of the other evaluation meta info is "
                "not None."
            )
        return train_dmatrix, []

    n_validation = len(eval_set)

    def validate_or_none(meta: Optional[List], name: str) -> List:
        # Expand a missing meta list to one ``None`` per evaluation pair;
        # otherwise require exactly one entry per pair.
        if meta is None:
            return [None] * n_validation
        if len(meta) != n_validation:
            raise ValueError(
                f"{name}'s length does not equal `eval_set`'s length, "
                f"expecting {n_validation}, got {len(meta)}"
            )
        return meta

    eval_weights = validate_or_none(sample_weight_eval_set, "sample_weight_eval_set")
    eval_margins = validate_or_none(base_margin_eval_set, "base_margin_eval_set")

    evals = []
    for i, (valid_X, valid_y) in enumerate(eval_set):
        weight_i = eval_weights[i]
        margin_i = eval_margins[i]
        # An entry made of the very same objects as the training inputs
        # (identity, not equality) reuses the training matrix instead of
        # building a duplicate.
        same_as_train = (
            valid_X is X
            and valid_y is y
            and weight_i is sample_weight
            and margin_i is base_margin
        )
        if same_as_train:
            matrix = train_dmatrix
        else:
            matrix = build_matrix(valid_X, valid_y, weight_i, margin_i)
        evals.append((matrix, f"validation_{i}"))

    return train_dmatrix, evals
0 commit comments