@@ -50,6 +50,7 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
50
50
acceleration_resampling:: AbstractResource
51
51
check_measure:: Bool
52
52
cache:: Bool
53
+ compact_history:: Bool
53
54
end
54
55
55
56
mutable struct ProbabilisticTunedModel{T,M<: ProbabilisticTypes } <: MLJBase.Probabilistic
@@ -69,6 +70,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
69
70
acceleration_resampling:: AbstractResource
70
71
check_measure:: Bool
71
72
cache:: Bool
73
+ compact_history:: Bool
72
74
end
73
75
74
76
const EitherTunedModel{T,M} =
@@ -176,6 +178,15 @@ key | value
176
178
177
179
plus other key/value pairs specific to the `tuning` strategy.
178
180
181
+ Each element of `history` is a property-accessible object with these properties:
182
+
183
+ key | value
184
+ --------------------|--------------------------------------------------
185
+ `measure` | vector of measures (metrics)
186
+ `measurement` | vector of measurements, one per measure
187
+ `per_fold` | vector of vectors of unaggregated per-fold measurements
188
+ `evaluation` | full `PerformanceEvaluation`/`CompactPerformaceEvaluation` object
189
+
179
190
### Complete list of key-word options
180
191
181
192
- `model`: `Supervised` model prototype that is cloned and mutated to
@@ -240,27 +251,35 @@ plus other key/value pairs specific to the `tuning` strategy.
240
251
user-suplied data; set to `false` to conserve memory. Speed gains
241
252
likely limited to the case `resampling isa Holdout`.
242
253
254
+ - `compact_history=true`: whether to write `CompactPerformanceEvaluation`](@ref) or
255
+ regular [`PerformanceEvaluation`](@ref) objects to the history (accessed via the
256
+ `:evaluation` key); the compact form excludes some fields to conserve memory.
257
+
243
258
"""
244
- function TunedModel (args... ; model= nothing ,
245
- models= nothing ,
246
- tuning= nothing ,
247
- resampling= MLJBase. Holdout (),
248
- measures= nothing ,
249
- measure= measures,
250
- weights= nothing ,
251
- class_weights= nothing ,
252
- operations= nothing ,
253
- operation= operations,
254
- ranges= nothing ,
255
- range= ranges,
256
- selection_heuristic= NaiveSelection (),
257
- train_best= true ,
258
- repeats= 1 ,
259
- n= nothing ,
260
- acceleration= default_resource (),
261
- acceleration_resampling= CPU1 (),
262
- check_measure= true ,
263
- cache= true )
259
+ function TunedModel (
260
+ args... ;
261
+ model= nothing ,
262
+ models= nothing ,
263
+ tuning= nothing ,
264
+ resampling= MLJBase. Holdout (),
265
+ measures= nothing ,
266
+ measure= measures,
267
+ weights= nothing ,
268
+ class_weights= nothing ,
269
+ operations= nothing ,
270
+ operation= operations,
271
+ ranges= nothing ,
272
+ range= ranges,
273
+ selection_heuristic= NaiveSelection (),
274
+ train_best= true ,
275
+ repeats= 1 ,
276
+ n= nothing ,
277
+ acceleration= default_resource (),
278
+ acceleration_resampling= CPU1 (),
279
+ check_measure= true ,
280
+ cache= true ,
281
+ compact_history= true ,
282
+ )
264
283
265
284
# user can specify model as argument instead of kwarg:
266
285
length (args) < 2 || throw (ERR_TOO_MANY_ARGUMENTS)
@@ -339,7 +358,8 @@ function TunedModel(args...; model=nothing,
339
358
acceleration,
340
359
acceleration_resampling,
341
360
check_measure,
342
- cache
361
+ cache,
362
+ compact_history,
343
363
)
344
364
345
365
if M <: DeterministicTypes
@@ -582,9 +602,10 @@ function assemble_events!(metamodels,
582
602
check_measure = resampling_machine. model. check_measure,
583
603
repeats = resampling_machine. model. repeats,
584
604
acceleration = resampling_machine. model. acceleration,
585
- cache = resampling_machine. model. cache),
586
- resampling_machine. args... ; cache= false ) for
587
- _ in 2 : length (partitions)]. .. ]
605
+ cache = resampling_machine. model. cache,
606
+ compact = resampling_machine. model. compact
607
+ ), resampling_machine. args... ; cache= false ) for
608
+ _ in 2 : length (partitions)]. .. ]
588
609
589
610
@sync for (i, parts) in enumerate (partitions)
590
611
Threads. @spawn begin
@@ -736,21 +757,23 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
736
757
737
758
# instantiate resampler (`model` to be replaced with mutated
738
759
# clones during iteration below):
739
- resampler = Resampler (model= model,
740
- resampling = deepcopy (tuned_model. resampling),
741
- measure = tuned_model. measure,
742
- weights = tuned_model. weights,
743
- class_weights = tuned_model. class_weights,
744
- operation = tuned_model. operation,
745
- check_measure = tuned_model. check_measure,
746
- repeats = tuned_model. repeats,
747
- acceleration = tuned_model. acceleration_resampling,
748
- cache = tuned_model. cache)
760
+ resampler = Resampler (
761
+ model= model,
762
+ resampling = deepcopy (tuned_model. resampling),
763
+ measure = tuned_model. measure,
764
+ weights = tuned_model. weights,
765
+ class_weights = tuned_model. class_weights,
766
+ operation = tuned_model. operation,
767
+ check_measure = tuned_model. check_measure,
768
+ repeats = tuned_model. repeats,
769
+ acceleration = tuned_model. acceleration_resampling,
770
+ cache = tuned_model. cache,
771
+ compact = tuned_model. compact_history,
772
+ )
749
773
resampling_machine = machine (resampler, data... ; cache= false )
750
774
history, state = build! (nothing , n, tuning, model, model_buffer, state,
751
775
verbosity, acceleration, resampling_machine)
752
776
753
-
754
777
return finalize (
755
778
tuned_model,
756
779
model_buffer,
@@ -867,9 +890,9 @@ function MLJBase.reports_feature_importances(model::EitherTunedModel)
867
890
end # This is needed in some cases (e.g tuning a `Pipeline`)
868
891
869
892
function MLJBase. feature_importances (:: EitherTunedModel , fitresult, report)
870
- # fitresult here is a machine created using the best_model obtained
893
+ # fitresult here is a machine created using the best_model obtained
871
894
# from the tuning process.
872
- # The line below will return `nothing` when the model being tuned doesn't
895
+ # The line below will return `nothing` when the model being tuned doesn't
873
896
# support feature_importances.
874
897
return MLJBase. feature_importances (fitresult)
875
898
end
0 commit comments