@@ -33,7 +33,7 @@ warn_double_spec(arg, model) =
 const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector}
 const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector}
 
-mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deterministic
+mutable struct DeterministicTunedModel{T,M<:DeterministicTypes,L} <: MLJBase.Deterministic
     model::M
     tuning::T  # tuning strategy
     resampling # resampling strategy
@@ -51,9 +51,10 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
     check_measure::Bool
     cache::Bool
     compact_history::Bool
+    logger::L
 end
 
-mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Probabilistic
+mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes,L} <: MLJBase.Probabilistic
     model::M
     tuning::T  # tuning strategy
     resampling # resampling strategy
@@ -71,10 +72,11 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
     check_measure::Bool
     cache::Bool
     compact_history::Bool
+    logger::L
 end
 
-const EitherTunedModel{T,M} =
-    Union{DeterministicTunedModel{T,M},ProbabilisticTunedModel{T,M}}
+const EitherTunedModel{T,M,L} =
+    Union{DeterministicTunedModel{T,M,L},ProbabilisticTunedModel{T,M,L}}
 
 
 MLJBase.caches_data_by_default(::Type{<:EitherTunedModel}) = false
 
@@ -279,6 +281,7 @@ function TunedModel(
     check_measure=true,
     cache=true,
     compact_history=true,
+    logger=nothing
 )
 
     # user can specify model as argument instead of kwarg:
@@ -342,6 +345,9 @@ function TunedModel(
     # get the tuning type parameter:
     T = typeof(tuning)
 
+    # get the logger type parameter:
+    L = typeof(logger)
+
     args = (
         model,
         tuning,
@@ -360,12 +366,13 @@ function TunedModel(
         check_measure,
         cache,
         compact_history,
+        logger
     )
 
     if M <: DeterministicTypes
-        tuned_model = DeterministicTunedModel{T,M}(args...)
+        tuned_model = DeterministicTunedModel{T,M,L}(args...)
     elseif M <: ProbabilisticTypes
-        tuned_model = ProbabilisticTunedModel{T,M}(args...)
+        tuned_model = ProbabilisticTunedModel{T,M,L}(args...)
     else
         throw(ERR_MODEL_TYPE)
     end
@@ -591,7 +598,7 @@ function assemble_events!(metamodels,
         end
     end
     # One resampling_machine per task
-    machs = [resampling_machine,
+    machs = [resampling_machine,
             [machine(Resampler(
                 model = resampling_machine.model.model,
                 resampling = resampling_machine.model.resampling,
@@ -603,9 +610,9 @@ function assemble_events!(metamodels,
                 repeats = resampling_machine.model.repeats,
                 acceleration = resampling_machine.model.acceleration,
                 cache = resampling_machine.model.cache,
-                compact = resampling_machine.model.compact
-            ), resampling_machine.args...; cache=false) for
-                _ in 2:length(partitions)]...]
+                compact = resampling_machine.model.compact,
+                logger = resampling_machine.model.logger),
+            resampling_machine.args...; cache=false) for _ in 2:length(partitions)]...]
 
     @sync for (i, parts) in enumerate(partitions)
         Threads.@spawn begin
@@ -740,8 +747,8 @@ function finalize(tuned_model,
     return fitresult, meta_state, report
 end
 
-function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
-                     verbosity::Integer, data...) where {T,M}
+function MLJBase.fit(tuned_model::EitherTunedModel{T,M,L},
+                     verbosity::Integer, data...) where {T,M,L}
     tuning = tuned_model.tuning
     model = tuned_model.model
     _range = tuned_model.range
@@ -769,6 +776,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
         acceleration = tuned_model.acceleration_resampling,
         cache = tuned_model.cache,
         compact = tuned_model.compact_history,
+        logger = tuned_model.logger
     )
     resampling_machine = machine(resampler, data...; cache=false)
     history, state = build!(nothing, n, tuning, model, model_buffer, state,
@@ -900,9 +908,9 @@
 ## METADATA
 
 MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
-MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
+MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
     MLJBase.supports_weights(M)
-MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
+MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} =
     MLJBase.supports_class_weights(M)
 MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) =
     "MLJTuning.ProbabilisticTunedModel"
@@ -914,9 +922,9 @@ MLJBase.package_uuid(::Type{<:EitherTunedModel}) =
 MLJBase.package_url(::Type{<:EitherTunedModel}) =
     "https://github.com/alan-turing-institute/MLJTuning.jl"
 MLJBase.package_license(::Type{<:EitherTunedModel}) = "MIT"
-MLJBase.is_pure_julia(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
+MLJBase.is_pure_julia(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
     MLJBase.is_pure_julia(M)
-MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
+MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
     MLJBase.input_scitype(M)
-MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M}}) where {T,M} =
+MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} =
     MLJBase.target_scitype(M)
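For context only, a minimal usage sketch of the new `logger` keyword as it surfaces to users; this is not part of the diff. It assumes MLJ and MLJDecisionTreeInterface are installed, and the model, range, and MLflow URI below are purely illustrative.

using MLJ

Tree = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
tree = Tree()
r = range(tree, :max_depth, lower=1, upper=10)

# Default is `logger=nothing`, so the new type parameter `L` is `Nothing`:
tm = TunedModel(model=tree, tuning=Grid(), resampling=CV(nfolds=3), range=r)
@assert tm.logger === nothing

# Any logger object is accepted; `typeof(logger)` becomes `L`, and during `fit`
# the logger is passed on to each internal `Resampler`. For example, with an
# MLJFlow.jl logger (assumed available; URI is illustrative):
# using MLJFlow
# tm = TunedModel(model=tree, tuning=Grid(), resampling=CV(nfolds=3), range=r,
#                 logger=MLJFlow.Logger("http://localhost:5000"))

Carrying `L` as a struct type parameter keeps the wrapper concretely typed for whatever logger is supplied, and the trait definitions above simply thread the extra parameter through without otherwise changing behaviour.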