@@ -21,6 +21,10 @@ const ERR_MODEL_TYPE = ArgumentError(
     "Only `Deterministic` and `Probabilistic` model types supported.")
 const INFO_MODEL_IGNORED =
     "`model` being ignored. Using `model=first(range)`. "
+const ERR_TOO_MANY_ARGUMENTS =
+    ArgumentError("At most one non-keyword argument allowed. ")
+warn_double_spec(arg, model) =
+    "Using `model=$arg`. Ignoring keyword specification `model=$model`. "
 
 const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector}
 const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector}
@@ -30,7 +34,8 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
     tuning::T  # tuning strategy
     resampling # resampling strategy
     measure
-    weights::Union{Nothing,Vector{<:Real}}
+    weights::Union{Nothing,AbstractVector{<:Real}}
+    class_weights::Union{Nothing,AbstractDict}
     operation
     range
     selection_heuristic
@@ -49,6 +54,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
     resampling # resampling strategy
     measure
     weights::Union{Nothing,AbstractVector{<:Real}}
+    class_weights::Union{Nothing,AbstractDict}
     operation
     range
     selection_heuristic
@@ -64,7 +70,8 @@
 const EitherTunedModel{T,M} =
     Union{DeterministicTunedModel{T,M},ProbabilisticTunedModel{T,M}}
 
-# todo update:
+MLJBase.caches_data_by_default(::Type{<:EitherTunedModel}) = false
+
 """
     tuned_model = TunedModel(; model=<model to be mutated>,
                              tuning=RandomSearch(),
@@ -114,6 +121,8 @@ Calling `fit!(mach)` on a machine `mach=machine(tuned_model, X, y)` or
 internal machine. The final train can be suppressed by setting
 `train_best=false`.
 
+### Search space
+
 The `range` objects supported depend on the `tuning` strategy
 specified. Query the `strategy` docstring for details. To optimize
 over an explicit list `v` of models of the same type, use
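Aside, not part of the diff: a minimal sketch of a typical `range` construction for the search space described above, assuming MLJ.jl and DecisionTree.jl are installed; the names `Tree`, `tree`, `r` and `self_tuning` are illustrative only:

    using MLJ

    # assumed atomic model; any model with a `max_depth` hyperparameter works:
    Tree = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
    tree = Tree()

    # one-dimensional numeric range over a hyperparameter of `tree`:
    r = range(tree, :max_depth, lower=1, upper=10)

    self_tuning = TunedModel(model=tree, tuning=Grid(), range=r,
                             resampling=CV(nfolds=3),
                             measure=misclassification_rate)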
@@ -124,28 +133,26 @@ then `MLJTuning.default_n(tuning, range)` is used. When `n` is
 increased and `fit!(mach)` called again, the old search history is
 re-instated and the search continues where it left off.
 
-If `measure` supports weights (`supports_weights(measure) == true`)
-then any `weights` specified will be passed to the measure. If more
-than one `measure` is specified, then only the first is optimized
-(unless `strategy` is multi-objective) but the performance against
-every measure specified will be computed and reported in
-`report(mach).best_performance` and other relevant attributes of the
-generated report.
+### Measures (metrics)
 
-Specify `repeats > 1` for repeated resampling per model
-evaluation. See [`evaluate!`](@ref) options for details.
+If more than one `measure` is specified, then only the first is
+optimized (unless `strategy` is multi-objective) but the performance
+against every measure specified will be computed and reported in
+`report(mach).best_performance` and other relevant attributes of the
+generated report. Options exist to pass per-observation weights or
+class weights to measures; see below.
 
-*Important.* If a custom `measure` is used, and the measure is
-a score, rather than a loss, be sure to check that
-`MLJ.orientation(measure) == :score` to ensure maximization of the
+*Important.* If a custom measure, `my_measure`, is used, and the
+measure is a score, rather than a loss, be sure to check that
+`MLJ.orientation(my_measure) == :score` to ensure maximization of the
 measure, rather than minimization. Override an incorrect value with
-`MLJ.orientation(::typeof(measure)) = :score`.
+`MLJ.orientation(::typeof(my_measure)) = :score`.
+
+### Accessing the fitted parameters and other training (tuning) outcomes
 
 A Plots.jl plot of performance estimates is returned by `plot(mach)`
 or `heatmap(mach)`.
 
-### Accessing the fitted parameters and other training (tuning) outcomes
-
 Once a tuning machine `mach` has been trained as above, then
 `fitted_params(mach)` has these keys/values:
 
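Aside, not part of the diff: the orientation override described above, as a runnable sketch; `my_accuracy` is a hypothetical user-defined score:

    using MLJ

    # hypothetical custom score: fraction of correct predictions
    my_accuracy(ŷ, y) = sum(ŷ .== y) / length(y)

    # without this declaration the measure could be treated as a loss
    # and minimized; declaring it a :score ensures maximization:
    MLJ.orientation(::typeof(my_accuracy)) = :score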
@@ -165,7 +172,7 @@ key | value
 
 plus other key/value pairs specific to the `tuning` strategy.
 
-### Summary of key-word arguments
+### Complete list of key-word options
 
 - `model`: `Supervised` model prototype that is cloned and mutated to
   generate models for evaluation
@@ -185,11 +192,15 @@ plus other key/value pairs specific to the `tuning` strategy.
   evaluations; only the first used in optimization (unless the
   strategy is multi-objective) but all reported to the history
 
-- `weights`: sample weights to be passed the measure(s) in performance
-  evaluations, if supported.
+- `weights`: per-observation weights to be passed to the measure(s) in performance
+  evaluations, where supported. Check support with `supports_weights(measure)`.
+
+- `class_weights`: class weights to be passed to the measure(s) in
+  performance evaluations, where supported. Check support with
+  `supports_class_weights(measure)`.
 
 - `repeats=1`: for generating train/test sets multiple times in
-  resampling; see [`evaluate!`](@ref) for details
+  resampling ("Monte Carlo" resampling); see [`evaluate!`](@ref) for details
 
 - `operation`/`operations` - One of
   $(MLJBase.PREDICT_OPERATIONS_STRING), or a vector of these of the
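Aside, not part of the diff: a sketch of the two weighting options just listed, assuming `tree` and `r` as defined earlier and a target taking the illustrative values "fraud" and "legitimate":

    using MLJ

    # per-class weights, keyed on the levels of the target `y`:
    cweights = Dict("fraud" => 10.0, "legitimate" => 1.0)

    # check measure support first (expected true for this measure):
    supports_class_weights(misclassification_rate)

    tuned = TunedModel(model=tree, range=r,
                       measure=misclassification_rate,
                       class_weights=cweights)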
@@ -226,13 +237,14 @@ plus other key/value pairs specific to the `tuning` strategy.
   likely limited to the case `resampling isa Holdout`.
 
 """
-function TunedModel(; model=nothing,
+function TunedModel(args...; model=nothing,
                     models=nothing,
                     tuning=nothing,
                     resampling=MLJBase.Holdout(),
                     measures=nothing,
                     measure=measures,
                     weights=nothing,
+                    class_weights=nothing,
                     operations=nothing,
                     operation=operations,
                     ranges=nothing,
@@ -246,8 +258,17 @@ function TunedModel(; model=nothing,
                     check_measure=true,
                     cache=true)
 
+    # user can specify model as argument instead of kwarg:
+    length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
+    if length(args) === 1
+        arg = first(args)
+        model === nothing ||
+            @warn warn_double_spec(arg, model)
+        model = arg
+    end
+
     # either `models` is specified and `tuning` is set to `Explicit`,
-    # or `models` is unspecified and tuning will fall back to `Grid()`
+    # or `models` is unspecified and tuning will fall back to `RandomSearch()`
     # unless it is itself specified:
     if models !== nothing
         if tuning === nothing
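Aside, not part of the diff: the calling conventions the block above introduces, sketched with the assumed names `tree`, `other_tree` and `r`:

    # these two are equivalent:
    TunedModel(tree, range=r)
    TunedModel(model=tree, range=r)

    # positional argument plus `model=...` warns and the positional
    # argument wins:
    TunedModel(tree, model=other_tree, range=r)

    # two or more positional arguments throw ERR_TOO_MANY_ARGUMENTS:
    TunedModel(tree, other_tree, range=r)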
@@ -295,9 +316,24 @@ function TunedModel(; model=nothing,
     # get the tuning type parameter:
     T = typeof(tuning)
 
-    args = (model, tuning, resampling, measure, weights, operation, range,
-        selection_heuristic, train_best, repeats, n, acceleration, acceleration_resampling,
-        check_measure, cache)
+    args = (
+        model,
+        tuning,
+        resampling,
+        measure,
+        weights,
+        class_weights,
+        operation,
+        range,
+        selection_heuristic,
+        train_best,
+        repeats,
+        n,
+        acceleration,
+        acceleration_resampling,
+        check_measure,
+        cache
+    )
 
     if M <: DeterministicTypes
         tuned_model = DeterministicTunedModel{T,M}(args...)
@@ -531,6 +567,7 @@ function assemble_events!(metamodels,
                 resampling = resampling_machine.model.resampling,
                 measure = resampling_machine.model.measure,
                 weights = resampling_machine.model.weights,
+                class_weights = resampling_machine.model.class_weights,
                 operation = resampling_machine.model.operation,
                 check_measure = resampling_machine.model.check_measure,
                 repeats = resampling_machine.model.repeats,
@@ -693,6 +730,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
         resampling = deepcopy(tuned_model.resampling),
         measure = tuned_model.measure,
         weights = tuned_model.weights,
+        class_weights = tuned_model.class_weights,
         operation = tuned_model.operation,
         check_measure = tuned_model.check_measure,
         repeats = tuned_model.repeats,
@@ -784,6 +822,8 @@
 MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
 MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
     MLJBase.supports_weights(M)
+MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
+    MLJBase.supports_class_weights(M)
 MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) =
     "MLJTuning.ProbabilisticTunedModel"
 MLJBase.load_path(::Type{<:DeterministicTunedModel}) =
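Aside, not part of the diff: the effect of the trait forwarding above, sketched with `tree` and `r` assumed defined as before; the wrapper advertises the same class-weight support as the atomic model it wraps:

    tuned = TunedModel(model=tree, range=r)

    # true by definition of the forwarded trait:
    MLJBase.supports_class_weights(typeof(tuned)) ==
        MLJBase.supports_class_weights(typeof(tree))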