
Commit bac0ac9

Merge pull request #216 from JuliaAI/dev
For a 0.8.5 release
2 parents: 9fe1f52 + 1dd32b1

3 files changed: +87 −50 lines

Project.toml

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 name = "MLJTuning"
 uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
 authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
-version = "0.8.4"
+version = "0.8.5"

 [deps]
 ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
@@ -18,7 +18,7 @@ StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
 ComputationalResources = "0.3"
 Distributions = "0.22,0.23,0.24, 0.25"
 LatinHypercubeSampling = "1.7.2"
-MLJBase = "1"
+MLJBase = "1.3"
 ProgressMeter = "1.7.1"
 RecipesBase = "0.8,0.9,1"
 StatisticalMeasuresBase = "0.1.1"

src/tuned_models.jl

Lines changed: 60 additions & 37 deletions

@@ -50,6 +50,7 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
     acceleration_resampling::AbstractResource
     check_measure::Bool
     cache::Bool
+    compact_history::Bool
 end

 mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Probabilistic
@@ -69,6 +70,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
     acceleration_resampling::AbstractResource
     check_measure::Bool
     cache::Bool
+    compact_history::Bool
 end

 const EitherTunedModel{T,M} =
@@ -176,6 +178,15 @@ key | value

 plus other key/value pairs specific to the `tuning` strategy.

+Each element of `history` is a property-accessible object with these properties:
+
+key | value
+--------------------|--------------------------------------------------
+`measure` | vector of measures (metrics)
+`measurement` | vector of measurements, one per measure
+`per_fold` | vector of vectors of unaggregated per-fold measurements
+`evaluation` | full `PerformanceEvaluation`/`CompactPerformanceEvaluation` object
+
 ### Complete list of key-word options

 - `model`: `Supervised` model prototype that is cloned and mutated to
@@ -240,27 +251,35 @@ plus other key/value pairs specific to the `tuning` strategy.
   user-suplied data; set to `false` to conserve memory. Speed gains
   likely limited to the case `resampling isa Holdout`.

+- `compact_history=true`: whether to write [`CompactPerformanceEvaluation`](@ref) or
+  regular [`PerformanceEvaluation`](@ref) objects to the history (accessed via the
+  `:evaluation` key); the compact form excludes some fields to conserve memory.
+
 """
-function TunedModel(args...; model=nothing,
-                             models=nothing,
-                             tuning=nothing,
-                             resampling=MLJBase.Holdout(),
-                             measures=nothing,
-                             measure=measures,
-                             weights=nothing,
-                             class_weights=nothing,
-                             operations=nothing,
-                             operation=operations,
-                             ranges=nothing,
-                             range=ranges,
-                             selection_heuristic=NaiveSelection(),
-                             train_best=true,
-                             repeats=1,
-                             n=nothing,
-                             acceleration=default_resource(),
-                             acceleration_resampling=CPU1(),
-                             check_measure=true,
-                             cache=true)
+function TunedModel(
+    args...;
+    model=nothing,
+    models=nothing,
+    tuning=nothing,
+    resampling=MLJBase.Holdout(),
+    measures=nothing,
+    measure=measures,
+    weights=nothing,
+    class_weights=nothing,
+    operations=nothing,
+    operation=operations,
+    ranges=nothing,
+    range=ranges,
+    selection_heuristic=NaiveSelection(),
+    train_best=true,
+    repeats=1,
+    n=nothing,
+    acceleration=default_resource(),
+    acceleration_resampling=CPU1(),
+    check_measure=true,
+    cache=true,
+    compact_history=true,
+)

     # user can specify model as argument instead of kwarg:
     length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
@@ -339,7 +358,8 @@ function TunedModel(args...; model=nothing,
        acceleration,
        acceleration_resampling,
        check_measure,
-       cache
+       cache,
+       compact_history,
    )

    if M <: DeterministicTypes
@@ -582,9 +602,10 @@ function assemble_events!(metamodels,
        check_measure = resampling_machine.model.check_measure,
        repeats = resampling_machine.model.repeats,
        acceleration = resampling_machine.model.acceleration,
-       cache = resampling_machine.model.cache),
-       resampling_machine.args...; cache=false) for
-       _ in 2:length(partitions)]...]
+       cache = resampling_machine.model.cache,
+       compact = resampling_machine.model.compact
+       ), resampling_machine.args...; cache=false) for
+       _ in 2:length(partitions)]...]

    @sync for (i, parts) in enumerate(partitions)
        Threads.@spawn begin
@@ -736,21 +757,23 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},

    # instantiate resampler (`model` to be replaced with mutated
    # clones during iteration below):
-   resampler = Resampler(model=model,
-                         resampling = deepcopy(tuned_model.resampling),
-                         measure = tuned_model.measure,
-                         weights = tuned_model.weights,
-                         class_weights = tuned_model.class_weights,
-                         operation = tuned_model.operation,
-                         check_measure = tuned_model.check_measure,
-                         repeats = tuned_model.repeats,
-                         acceleration = tuned_model.acceleration_resampling,
-                         cache = tuned_model.cache)
+   resampler = Resampler(
+       model=model,
+       resampling = deepcopy(tuned_model.resampling),
+       measure = tuned_model.measure,
+       weights = tuned_model.weights,
+       class_weights = tuned_model.class_weights,
+       operation = tuned_model.operation,
+       check_measure = tuned_model.check_measure,
+       repeats = tuned_model.repeats,
+       acceleration = tuned_model.acceleration_resampling,
+       cache = tuned_model.cache,
+       compact = tuned_model.compact_history,
+   )
    resampling_machine = machine(resampler, data...; cache=false)
    history, state = build!(nothing, n, tuning, model, model_buffer, state,
                            verbosity, acceleration, resampling_machine)

-
    return finalize(
        tuned_model,
        model_buffer,
@@ -867,9 +890,9 @@ function MLJBase.reports_feature_importances(model::EitherTunedModel)
 end # This is needed in some cases (e.g tuning a `Pipeline`)

 function MLJBase.feature_importances(::EitherTunedModel, fitresult, report)
-    # fitresult here is a machine created using the best_model obtained 
+    # fitresult here is a machine created using the best_model obtained
     # from the tuning process.
-    # The line below will return `nothing` when the model being tuned doesn't 
+    # The line below will return `nothing` when the model being tuned doesn't
     # support feature_importances.
     return MLJBase.feature_importances(fitresult)
 end
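To see the new `compact_history` option and the property-accessible history entries documented above in use, here is a minimal sketch, not taken from the diff. It assumes `DeterministicConstantRegressor` (in this repo a helper defined under /test/_models/; any deterministic regressor would do) and the `mae` measure, which is assumed to be re-exported by MLJBase as in the tests:

    using MLJBase, MLJTuning

    X, y = make_regression(100, 2)
    atom = DeterministicConstantRegressor()   # placeholder model; swap in any regressor

    tmodel = TunedModel(
        models=fill(atom, 10),        # explicit list of candidate models
        resampling=Holdout(),
        measure=mae,
        compact_history=false,        # store full PerformanceEvaluation objects
    )
    mach = machine(tmodel, X, y)
    fit!(mach, verbosity=0)

    # each history element exposes the properties listed in the new docstring table:
    entry = first(report(mach).history)
    entry.measure      # vector of measures (metrics)
    entry.measurement  # vector of measurements, one per measure
    entry.per_fold     # vector of vectors of unaggregated per-fold measurements
    entry.evaluation   # PerformanceEvaluation (CompactPerformanceEvaluation when compact_history=true)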

test/tuned_models.jl

Lines changed: 25 additions & 11 deletions

@@ -13,7 +13,7 @@ using Random
 Random.seed!(1234*myid())
 using .TestUtilities

-begin 
+begin
 N = 30
 x1 = rand(N);
 x2 = rand(N);
@@ -157,14 +157,14 @@ end

 @testset_accelerated "Feature Importances" accel begin
     # the DecisionTreeClassifier in /test/_models/ supports feature importances.
-    tm0 = TunedModel(
-        model = trees[1],
-        measure = rms,
-        tuning = Grid(),
-        resampling = CV(nfolds = 5),
-        range = range(
-            trees[1], :max_depth, values = 1:10
-        )
+    tm0 = TunedModel(
+        model = trees[1],
+        measure = rms,
+        tuning = Grid(),
+        resampling = CV(nfolds = 5),
+        range = range(
+            trees[1], :max_depth, values = 1:10
+        )
     )
     @test reports_feature_importances(typeof(tm0))
     tm = TunedModel(
@@ -435,7 +435,7 @@
     model = DecisionTreeClassifier()
     tmodel = TunedModel(models=[model,])
     mach = machine(tmodel, X, y)
-    @test mach isa Machine{<:Any,false}
+    @test !MLJBase.caches_data(mach)
     fit!(mach, verbosity=-1)
     @test !isdefined(mach, :data)
     MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1])
@@ -490,7 +490,7 @@
     @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
 end

-@testset_accelerated "full evaluation object" accel begin
+@testset_accelerated "evaluation object" accel begin
     X, y = make_regression(100, 2)
     dcr = DeterministicConstantRegressor()

@@ -504,10 +504,24 @@
     fit!(homach, verbosity=0);
     horep = report(homach)
     evaluations = getproperty.(horep.history, :evaluation)
+    @test first(evaluations) isa MLJBase.CompactPerformanceEvaluation
     measurements = getproperty.(evaluations, :measurement)
     models = getproperty.(evaluations, :model)
     @test all(==(measurements[1]), measurements)
     @test all(==(dcr), models)
+
+    homodel = TunedModel(
+        models=fill(dcr, 10),
+        resampling=Holdout(rng=StableRNG(1234)),
+        acceleration_resampling=accel,
+        measure=mae,
+        compact_history=false,
+    )
+    homach = machine(homodel, X, y)
+    fit!(homach, verbosity=0);
+    horep = report(homach)
+    evaluations = getproperty.(horep.history, :evaluation)
+    @test first(evaluations) isa MLJBase.PerformanceEvaluation
 end

 true
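For background, the `compact = tuned_model.compact_history` plumbing in the source diff passes the flag to `Resampler`, which is assumed here to surface the compact-evaluation support added to MLJBase around version 1.3 (consistent with the `[compat]` bump in Project.toml). A rough sketch of the distinction at the MLJBase level, assuming `evaluate` accepts a `compact` keyword in those versions and that `mae` is re-exported by MLJBase:

    using MLJBase

    X, y = make_regression(100, 2)
    model = DeterministicConstantRegressor()   # placeholder; the tests use a local helper, any regressor works

    # regular evaluation object:
    e_full = evaluate(model, X, y; resampling=Holdout(), measure=mae)
    e_full isa MLJBase.PerformanceEvaluation          # true

    # compact form excludes some fields to conserve memory -- this is what the
    # tuning history stores by default (compact_history=true):
    e_small = evaluate(model, X, y; resampling=Holdout(), measure=mae, compact=true)
    e_small isa MLJBase.CompactPerformanceEvaluation  # true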
