
Commit 39d6cb4

Merge pull request #209 from JuliaAI/dev
For a 0.8.2 release
2 parents 60ad344 + c081bd0 commit 39d6cb4

File tree: 3 files changed, +90 -10 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJTuning"
 uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
 authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
-version = "0.8.1"
+version = "0.8.2"

 [deps]
 ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"

src/tuned_models.jl

Lines changed: 40 additions & 9 deletions
@@ -1,4 +1,4 @@
-## TYPES AND CONSTRUCTOR
+# TYPES AND CONSTRUCTOR

 const ERR_SPECIFY_MODEL = ArgumentError(
     "You need to specify `model=...`, unless `tuning=Explicit()`. ")
@@ -687,7 +687,7 @@ function finalize(tuned_model,
                   history,
                   state,
                   verbosity,
-                  rm,
+                  resampling_machine,
                   data...)
     model = tuned_model.model
     tuning = tuned_model.tuning
@@ -713,7 +713,7 @@ function finalize(tuned_model,
     end

     report = merge(report1, tuning_report(tuning, history, state))
-    meta_state = (history, deepcopy(tuned_model), model_buffer, state, rm)
+    meta_state = (history, deepcopy(tuned_model), model_buffer, state, resampling_machine)

     return fitresult, meta_state, report
 end
@@ -749,9 +749,16 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
     history, state = build!(nothing, n, tuning, model, model_buffer, state,
                             verbosity, acceleration, resampling_machine)

-    rm = resampling_machine
-    return finalize(tuned_model, model_buffer,
-                    history, state, verbosity, rm, data...)
+
+    return finalize(
+        tuned_model,
+        model_buffer,
+        history,
+        state,
+        verbosity,
+        resampling_machine,
+        data...,
+    )

 end

@@ -784,9 +791,15 @@ function MLJBase.update(tuned_model::EitherTunedModel,
     history, state = build!(history, n!, tuning, model, model_buffer, state,
                             verbosity, acceleration, resampling_machine)

-    rm = resampling_machine
-    return finalize(tuned_model, model_buffer,
-                    history, state, verbosity, rm, data...)
+    return finalize(
+        tuned_model,
+        model_buffer,
+        history,
+        state,
+        verbosity,
+        resampling_machine,
+        data...,
+    )
     else
         return fit(tuned_model, verbosity, data...)
     end
@@ -806,6 +819,24 @@ function MLJBase.fitted_params(tuned_model::EitherTunedModel, fitresult)
 end


+## FORWARD SERIALIZATION METHODS FROM ATOMIC MODEL
+
+const ERR_SERIALIZATION = ErrorException(
+    "Attempting to serialize a `TunedModel` instance whose best model has not "*
+    "been trained. It appears as if it was trained with `train_best=false`. "*
+    "Try re-training using `train_best=true`. "
+)
+
+# `fitresult` is `machine(best_model, data...)`, trained iff `train_best` hyperparameter
+# is `true`.
+function MLJBase.save(tmodel::EitherTunedModel, fitresult)
+    MLJBase.age(fitresult) > 0 || throw(ERR_SERIALIZATION)
+    return MLJBase.serializable(fitresult)
+end
+MLJBase.restore(tmodel::EitherTunedModel, serializable_fitresult) =
+    MLJBase.restore!(serializable_fitresult)
+
+
 ## SUPPORT FOR MLJ ITERATION API

 MLJBase.iteration_parameter(::Type{<:EitherTunedModel}) = :n
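
Note: with the forwarded methods above, a machine wrapping a `TunedModel` serializes like any other MLJ machine, provided the best model has actually been trained (`train_best=true`, as the error message advises). The following is a minimal usage sketch, not part of this diff, assuming the environment of test/tuned_models.jl below (where the toy `EphemeralRegressor`, `l2` and `Holdout` are in scope):

    using MLJBase, MLJTuning

    X, y = (; x = rand(10)), fill(42.0, 3)
    tmodel = TunedModel(
        models=fill(EphemeralRegressor(), 2),  # explicit list of candidate models
        measure=l2,
        resampling=Holdout(),
        train_best=true,       # best model must be trained for `save` to succeed
    )
    mach = machine(tmodel, X, y)
    fit!(mach, verbosity=0)

    io = IOBuffer()            # a file path works the same way
    MLJBase.save(io, mach)     # forwards to `MLJBase.serializable(fitresult)`
    seekstart(io)
    mach2 = machine(io)        # deserialize; `restore` forwards to `MLJBase.restore!`
    close(io)
    MLJBase.predict(mach2, (; x = rand(2)))   # ≈ fill(42.0, 2)

If the machine was instead trained with `train_best=false`, `MLJBase.save` now throws `ERR_SERIALIZATION` rather than silently writing a machine whose best model cannot predict.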

test/tuned_models.jl

Lines changed: 49 additions & 0 deletions
@@ -406,4 +406,53 @@ end
     MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1])
 end

+# define a supervised model with ephemeral `fitresult`, but which overcomes this by
+# overloading `save`/`restore`:
+thing = []
+struct EphemeralRegressor <: Deterministic end
+function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
+    # if I serialize/deserialized `thing` then `id` below changes:
+    id = objectid(thing)
+    fitresult = (thing, id, mean(y))
+    return fitresult, nothing, NamedTuple()
+end
+function MLJBase.predict(::EphemeralRegressor, fitresult, X)
+    thing, id, μ = fitresult
+    return id == objectid(thing) ? fill(μ, nrows(X)) :
+        throw(ErrorException("dead fitresult"))
+end
+function MLJBase.save(::EphemeralRegressor, fitresult)
+    thing, _, μ = fitresult
+    return (thing, μ)
+end
+function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
+    thing, μ = serialized_fitresult
+    id = objectid(thing)
+    return (thing, id, μ)
+end
+
+@testset "save and restore" begin
+    # https://github.com/JuliaAI/MLJTuning.jl/issues/207
+    X, y = (; x = rand(10)), fill(42.0, 3)
+    tmodel = TunedModel(
+        models=fill(EphemeralRegressor(), 2),
+        measure=l2,
+        resampling=Holdout(),
+        train_best=false,
+    )
+    mach = machine(tmodel, X, y)
+    fit!(mach, verbosity=0)
+    io = IOBuffer()
+    @test_throws MLJTuning.ERR_SERIALIZATION MLJBase.save(io, mach)
+    close(io)
+    tmodel.train_best = true
+    fit!(mach, verbosity=0)
+    io = IOBuffer()
+    @test_logs MLJBase.save(io, mach)
+    seekstart(io)
+    mach2 = machine(io)
+    close(io)
+    @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
+end
+
 true
