1
- # # TYPES AND CONSTRUCTOR
1
+ # TYPES AND CONSTRUCTOR
2
2
3
3
const ERR_SPECIFY_MODEL = ArgumentError (
4
4
" You need to specify `model=...`, unless `tuning=Explicit()`. " )
@@ -687,7 +687,7 @@ function finalize(tuned_model,
687
687
history,
688
688
state,
689
689
verbosity,
690
- rm ,
690
+ resampling_machine ,
691
691
data... )
692
692
model = tuned_model. model
693
693
tuning = tuned_model. tuning
@@ -713,7 +713,7 @@ function finalize(tuned_model,
713
713
end
714
714
715
715
report = merge (report1, tuning_report (tuning, history, state))
716
- meta_state = (history, deepcopy (tuned_model), model_buffer, state, rm )
716
+ meta_state = (history, deepcopy (tuned_model), model_buffer, state, resampling_machine )
717
717
718
718
return fitresult, meta_state, report
719
719
end
@@ -749,9 +749,16 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
749
749
history, state = build! (nothing , n, tuning, model, model_buffer, state,
750
750
verbosity, acceleration, resampling_machine)
751
751
752
- rm = resampling_machine
753
- return finalize (tuned_model, model_buffer,
754
- history, state, verbosity, rm, data... )
752
+
753
+ return finalize (
754
+ tuned_model,
755
+ model_buffer,
756
+ history,
757
+ state,
758
+ verbosity,
759
+ resampling_machine,
760
+ data... ,
761
+ )
755
762
756
763
end
757
764
@@ -784,9 +791,15 @@ function MLJBase.update(tuned_model::EitherTunedModel,
784
791
history, state = build! (history, n!, tuning, model, model_buffer, state,
785
792
verbosity, acceleration, resampling_machine)
786
793
787
- rm = resampling_machine
788
- return finalize (tuned_model, model_buffer,
789
- history, state, verbosity, rm, data... )
794
+ return finalize (
795
+ tuned_model,
796
+ model_buffer,
797
+ history,
798
+ state,
799
+ verbosity,
800
+ resampling_machine,
801
+ data... ,
802
+ )
790
803
else
791
804
return fit (tuned_model, verbosity, data... )
792
805
end
@@ -806,6 +819,24 @@ function MLJBase.fitted_params(tuned_model::EitherTunedModel, fitresult)
806
819
end
807
820
808
821
822
+ # # FORWARD SERIALIZATION METHODS FROM ATOMIC MODEL
823
+
824
+ const ERR_SERIALIZATION = ErrorException (
825
+ " Attempting to serialize a `TunedModel` instance whose best model has not " *
826
+ " been trained. It appears as if it was trained with `train_best=false`. " *
827
+ " Try re-training using `train_best=true`. "
828
+ )
829
+
830
+ # `fitresult` is `machine(best_model, data...)`, trained iff `train_best` hyperparameter
831
+ # is `true`.
832
+ function MLJBase. save (tmodel:: EitherTunedModel , fitresult)
833
+ MLJBase. age (fitresult) > 0 || throw (ERR_SERIALIZATION)
834
+ return MLJBase. serializable (fitresult)
835
+ end
836
+ MLJBase. restore (tmodel:: EitherTunedModel , serializable_fitresult) =
837
+ MLJBase. restore! (serializable_fitresult)
838
+
839
+
809
840
# # SUPPORT FOR MLJ ITERATION API
810
841
811
842
MLJBase. iteration_parameter (:: Type{<:EitherTunedModel} ) = :n
0 commit comments