
Commit d87788a

Merge pull request #176 from JuliaAI/dev
For a 0.7.1 release
2 parents ebf0983 + 72a9f8d commit d87788a

File tree

3 files changed: +133 −29 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJTuning"
 uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
 authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
-version = "0.7.0"
+version = "0.7.1"
 
 [deps]
 ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"

src/tuned_models.jl

Lines changed: 66 additions & 26 deletions
@@ -21,6 +21,10 @@ const ERR_MODEL_TYPE = ArgumentError(
     "Only `Deterministic` and `Probabilistic` model types supported.")
 const INFO_MODEL_IGNORED =
     "`model` being ignored. Using `model=first(range)`. "
+const ERR_TOO_MANY_ARGUMENTS =
+    ArgumentError("At most one non-keyword argument allowed. ")
+warn_double_spec(arg, model) =
+    "Using `model=$arg`. Ignoring keyword specification `model=$model`. "
 
 const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector}
 const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector}
@@ -30,7 +34,8 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
     tuning::T  # tuning strategy
     resampling # resampling strategy
     measure
-    weights::Union{Nothing,Vector{<:Real}}
+    weights::Union{Nothing,AbstractVector{<:Real}}
+    class_weights::Union{Nothing,AbstractDict}
     operation
     range
     selection_heuristic
@@ -49,6 +54,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
     resampling # resampling strategy
     measure
     weights::Union{Nothing,AbstractVector{<:Real}}
+    class_weights::Union{Nothing,AbstractDict}
     operation
     range
     selection_heuristic
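These two struct hunks widen `weights` from `Vector{<:Real}` to `AbstractVector{<:Real}` and give both tuned-model types the new `class_weights::Union{Nothing,AbstractDict}` field. A minimal sketch of values the widened fields now admit (the data here is made up for illustration):

    # any AbstractDict keyed on class labels:
    class_w = Dict("a" => 2.0, "b" => 100.0)

    # any AbstractVector{<:Real}, not just a concrete Vector:
    w = view([1.0, 1.0, 100.0, 1.0], 1:4)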
@@ -64,7 +70,8 @@ end
 const EitherTunedModel{T,M} =
     Union{DeterministicTunedModel{T,M},ProbabilisticTunedModel{T,M}}
 
-#todo update:
+MLJBase.caches_data_by_default(::Type{<:EitherTunedModel}) = false
+
 """
     tuned_model = TunedModel(; model=<model to be mutated>,
                              tuning=RandomSearch(),
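The trait declaration replacing the `#todo` comment switches off data caching for machines bound directly to a `TunedModel`; the wrapped resampling machines still receive the data. A sketch of the observable effect, assuming a supervised `model` and data `X`, `y` are in scope (compare the new caching test at the bottom of this diff):

    tuned = TunedModel(models=[model])
    mach = machine(tuned, X, y)   # Machine{<:Any,false}: the second type
                                  # parameter records that caching is off
    fit!(mach)                    # trains without ever setting `mach.data`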
@@ -114,6 +121,8 @@ Calling `fit!(mach)` on a machine `mach=machine(tuned_model, X, y)` or
 internal machine. The final train can be supressed by setting
 `train_best=false`.
 
+### Search space
+
 The `range` objects supported depend on the `tuning` strategy
 specified. Query the `strategy` docstring for details. To optimize
 over an explicit list `v` of models of the same type, use
@@ -124,28 +133,26 @@ then `MLJTuning.default_n(tuning, range)` is used. When `n` is
 increased and `fit!(mach)` called again, the old search history is
 re-instated and the search continues where it left off.
 
-If `measure` supports weights (`supports_weights(measure) == true`)
-then any `weights` specified will be passed to the measure. If more
-than one `measure` is specified, then only the first is optimized
-(unless `strategy` is multi-objective) but the performance against
-every measure specified will be computed and reported in
-`report(mach).best_performance` and other relevant attributes of the
-generated report.
+### Measures (metrics)
 
-Specify `repeats > 1` for repeated resampling per model
-evaluation. See [`evaluate!`](@ref) options for details.
+If more than one `measure` is specified, then only the first is
+optimized (unless `strategy` is multi-objective) but the performance
+against every measure specified will be computed and reported in
+`report(mach).best_performance` and other relevant attributes of the
+generated report. Options exist to pass per-observation weights or
+class weights to measures; see below.
 
-*Important.* If a custom `measure` is used, and the measure is
-a score, rather than a loss, be sure to check that
-`MLJ.orientation(measure) == :score` to ensure maximization of the
+*Important.* If a custom measure, `my_measure` is used, and the
+measure is a score, rather than a loss, be sure to check that
+`MLJ.orientation(my_measure) == :score` to ensure maximization of the
 measure, rather than minimization. Override an incorrect value with
-`MLJ.orientation(::typeof(measure)) = :score`.
+`MLJ.orientation(::typeof(my_measure)) = :score`.
+
+### Accessing the fitted parameters and other training (tuning) outcomes
 
 A Plots.jl plot of performance estimates is returned by `plot(mach)`
 or `heatmap(mach)`.
 
-### Accessing the fitted parameters and other training (tuning) outcomes
-
 Once a tuning machine `mach` has bee trained as above, then
 `fitted_params(mach)` has these keys/values:
@@ -165,7 +172,7 @@ key | value
 
 plus other key/value pairs specific to the `tuning` strategy.
 
-### Summary of key-word arguments
+### Complete list of key-word options
 
 - `model`: `Supervised` model prototype that is cloned and mutated to
   generate models for evaluation
@@ -185,11 +192,15 @@ plus other key/value pairs specific to the `tuning` strategy.
   evaluations; only the first used in optimization (unless the
   strategy is multi-objective) but all reported to the history
 
-- `weights`: sample weights to be passed the measure(s) in performance
-  evaluations, if supported.
+- `weights`: per-observation weights to be passed the measure(s) in performance
+  evaluations, where supported. Check support with `supports_weights(measure)`.
+
+- `class_weights`: class weights to be passed the measure(s) in
+  performance evaluations, where supported. Check support with
+  `supports_class_weights(measure)`.
 
 - `repeats=1`: for generating train/test sets multiple times in
-  resampling; see [`evaluate!`](@ref) for details
+  resampling ("Monte Carlo" resampling); see [`evaluate!`](@ref) for details
 
 - `operation`/`operations` - One of
   $(MLJBase.PREDICT_OPERATIONS_STRING), or a vector of these of the
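A hedged construction sketch combining the two new options (`model`, `r`, and `w` are hypothetical; the measures are the ones used in the new tests below, the first supporting per-observation weights and the second class weights):

    tuned = TunedModel(
        model=model,
        range=r,
        measures=[MisclassificationRate(), MulticlassFScore()],
        weights=w,                                     # per-observation
        class_weights=Dict("a" => 2.0, "b" => 100.0), # per-class
    )

Support can be checked per measure with `supports_weights(measure)` and `supports_class_weights(measure)`, as noted above.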
@@ -226,13 +237,14 @@ plus other key/value pairs specific to the `tuning` strategy.
 likely limited to the case `resampling isa Holdout`.
 
 """
-function TunedModel(; model=nothing,
+function TunedModel(args...; model=nothing,
                     models=nothing,
                     tuning=nothing,
                     resampling=MLJBase.Holdout(),
                     measures=nothing,
                     measure=measures,
                     weights=nothing,
+                    class_weights=nothing,
                     operations=nothing,
                     operation=operations,
                     ranges=nothing,
@@ -246,8 +258,17 @@ function TunedModel(; model=nothing,
                     check_measure=true,
                     cache=true)
 
+    # user can specify model as argument instead of kwarg:
+    length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
+    if length(args) === 1
+        arg = first(args)
+        model === nothing ||
+            @warn warn_double_spec(arg, model)
+        model = arg
+    end
+
     # either `models` is specified and `tuning` is set to `Explicit`,
-    # or `models` is unspecified and tuning will fallback to `Grid()`
+    # or `models` is unspecified and tuning will fallback to `RandomSearch()`
     # unless it is itself specified:
     if models !== nothing
         if tuning === nothing
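The new argument handling gives three calling conventions, exercised by the new tests further down. A sketch, with `model1` and `model2` hypothetical model instances:

    TunedModel(model1, range=r)                 # same as model=model1
    TunedModel(model1, model=model2, range=r)   # warns; positional argument wins
    TunedModel(model1, model2, range=r)         # throws ERR_TOO_MANY_ARGUMENTS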
@@ -295,9 +316,24 @@ function TunedModel(; model=nothing,
     # get the tuning type parameter:
     T = typeof(tuning)
 
-    args = (model, tuning, resampling, measure, weights, operation, range,
-            selection_heuristic, train_best, repeats, n, acceleration, acceleration_resampling,
-            check_measure, cache)
+    args = (
+        model,
+        tuning,
+        resampling,
+        measure,
+        weights,
+        class_weights,
+        operation,
+        range,
+        selection_heuristic,
+        train_best,
+        repeats,
+        n,
+        acceleration,
+        acceleration_resampling,
+        check_measure,
+        cache
+    )
 
     if M <: DeterministicTypes
         tuned_model = DeterministicTunedModel{T,M}(args...)
@@ -531,6 +567,7 @@ function assemble_events!(metamodels,
         resampling = resampling_machine.model.resampling,
         measure = resampling_machine.model.measure,
         weights = resampling_machine.model.weights,
+        class_weights = resampling_machine.model.class_weights,
         operation = resampling_machine.model.operation,
         check_measure = resampling_machine.model.check_measure,
         repeats = resampling_machine.model.repeats,
@@ -693,6 +730,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
         resampling = deepcopy(tuned_model.resampling),
         measure = tuned_model.measure,
         weights = tuned_model.weights,
+        class_weights = tuned_model.class_weights,
         operation = tuned_model.operation,
         check_measure = tuned_model.check_measure,
         repeats = tuned_model.repeats,
@@ -784,6 +822,8 @@ end
 MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
 MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
     MLJBase.supports_weights(M)
+MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
+    MLJBase.supports_class_weights(M)
 MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) =
     "MLJTuning.ProbabilisticTunedModel"
 MLJBase.load_path(::Type{<:DeterministicTunedModel}) =
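The added method mirrors the existing `supports_weights` forwarding: a tuned model reports class-weight support exactly when its atomic model type `M` does (MLJModelInterface supplies a `false` fallback for models that never define the trait). A sketch:

    tuned = TunedModel(models=[DecisionTreeClassifier()])
    MLJBase.supports_class_weights(typeof(tuned)) ==
        MLJBase.supports_class_weights(DecisionTreeClassifier)   # true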

test/tuned_models.jl

Lines changed: 66 additions & 2 deletions
@@ -35,7 +35,7 @@ r = [m(K) for K in 13:-1:2]
 @test_throws(MLJTuning.ERR_BOTH_DISALLOWED,
              TunedModel(model=first(r),
                         models=r, tuning=Explicit(), measure=rms))
-tm = TunedModel(models=r, tuning=Explicit(), measure=rms)
+tm = @test_logs TunedModel(models=r, tuning=Explicit(), measure=rms)
 @test tm.tuning isa Explicit && tm.range == r && tm.model == first(r)
 @test input_scitype(tm) == Unknown
 @test TunedModel(models=r, measure=rms) == tm
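`@test_logs` with no log patterns asserts that the wrapped call emits no log records and, on success, returns the call's value, so the rewritten line still binds `tm` while also checking that construction is silent. A tiny sketch of the idiom:

    using Test
    f() = 42              # logs nothing
    x = @test_logs f()    # passes: no records expected, none produced
    @assert x == 42       # value of the tested expression is returned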
@@ -54,7 +54,16 @@ r = [m(K) for K in 13:-1:2]
              TunedModel(tuning=Explicit(), measure=rms))
 @test_throws(MLJTuning.ERR_NEED_EXPLICIT,
              TunedModel(models=r, tuning=Grid()))
-tm = TunedModel(model=first(r), range=r, measure=rms)
+@test_logs TunedModel(first(r), range=r, measure=rms)
+@test_logs(
+    (:warn, MLJTuning.warn_double_spec(first(r), last(r))),
+    TunedModel(first(r), model=last(r), range=r, measure=rms),
+)
+@test_throws(
+    MLJTuning.ERR_TOO_MANY_ARGUMENTS,
+    TunedModel(first(r), last(r), range=r, measure=rms),
+)
+tm = @test_logs TunedModel(model=first(r), range=r, measure=rms)
 @test tm.tuning isa RandomSearch
 @test input_scitype(tm) == Table(Continuous)
 end
@@ -341,4 +350,59 @@ end
 
 end
 
+@testset_accelerated "weights and class_weights are being passed" accel begin
+    # we'll be tuning using 50/50 holdout
+    X = (x=fill(1.0, 6),)
+    y = coerce(["a", "a", "b", "a", "a", "b"], OrderedFactor)
+    w = [1.0, 1.0, 100.0, 1.0, 1.0, 100.0]
+    class_w = Dict("a" => 2.0, "b" => 100.0)
+
+    model = DecisionTreeClassifier()
+
+    # the first supports weights, the second class weights:
+    ms = [MisclassificationRate(), MulticlassFScore()]
+
+    resampling = Holdout(fraction_train=0.5)
+
+    # without weights:
+    tmodel = TunedModel(
+        resampling=resampling,
+        models=fill(model, 5),
+        measures=ms,
+        acceleration=accel
+    )
+    mach = machine(tmodel, X, y)
+    fit!(mach, verbosity=0)
+    measurement = report(mach).best_history_entry.measurement
+    e = evaluate(model, X, y, measures=ms, resampling=resampling)
+    @test measurement == e.measurement
+
+    # with weights:
+    tmodel.weights = w
+    tmodel.class_weights = class_w
+    fit!(mach, verbosity=0)
+    measurement_weighted = report(mach).best_history_entry.measurement
+    e_weighted = evaluate(model, X, y;
+                          measures=ms,
+                          resampling=resampling,
+                          weights=w,
+                          class_weights=class_w,
+                          verbosity=-1)
+    @test measurement_weighted == e_weighted.measurement
+
+    # check both measures are different when they are weighted:
+    @test !any(measurement .== measurement_weighted)
+end
+
+@testset "data caching at outer level suppressed" begin
+    X, y = make_blobs()
+    model = DecisionTreeClassifier()
+    tmodel = TunedModel(models=[model,])
+    mach = machine(tmodel, X, y)
+    @test mach isa Machine{<:Any,false}
+    fit!(mach, verbosity=-1)
+    @test !isdefined(mach, :data)
+    MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1])
+end
+
 true
