Skip to content

Commit e938e5b

Browse files
authored
Merge pull request #112 from alan-turing-institute/dev
For a 0.6.1 release
2 parents fed846c + 4c611f8 commit e938e5b

File tree

7 files changed

+182
-53
lines changed

7 files changed

+182
-53
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJTuning"
22
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
@@ -18,8 +18,8 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1818
ComputationalResources = "^0.3"
1919
Distributions = "^0.22,^0.23,^0.24"
2020
LatinHypercubeSampling = "^1.7.2"
21-
MLJBase = "^0.15,^0.16"
22-
MLJModelInterface = "^0.3"
21+
MLJBase = "^0.17"
22+
MLJModelInterface = "^0.3.7,^0.4"
2323
ProgressMeter = "^1.3"
2424
RecipesBase = "^0.8,^0.9,^1"
2525
julia = "^1"

src/tuned_models.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ mutable struct DeterministicTunedModel{T,M<:Deterministic} <: MLJBase.Determinis
1616
acceleration::AbstractResource
1717
acceleration_resampling::AbstractResource
1818
check_measure::Bool
19+
cache::Bool
1920
end
2021

2122
mutable struct ProbabilisticTunedModel{T,M<:Probabilistic} <: MLJBase.Probabilistic
@@ -33,6 +34,7 @@ mutable struct ProbabilisticTunedModel{T,M<:Probabilistic} <: MLJBase.Probabilis
3334
acceleration::AbstractResource
3435
acceleration_resampling::AbstractResource
3536
check_measure::Bool
37+
cache::Bool
3638
end
3739

3840
const EitherTunedModel{T,M} =
@@ -55,7 +57,8 @@ MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
5557
train_best=true,
5658
acceleration=default_resource(),
5759
acceleration_resampling=CPU1(),
58-
check_measure=true)
60+
check_measure=true,
61+
cache=true)
5962
6063
Construct a model wrapper for hyperparameter optimization of a
6164
supervised learner.
@@ -171,9 +174,13 @@ plus other key/value pairs specific to the `tuning` strategy.
171174
- `acceleration_resampling=CPU1()`: mode of parallelization for
172175
resampling
173176
174-
- `check_measure`: whether to check `measure` is compatible with the
177+
- `check_measure=true`: whether to check `measure` is compatible with the
175178
specified `model` and `operation`)
176179
180+
- `cache=true`: whether to cache model-specific representations of
181+
user-supplied data; set to `false` to conserve memory. Speed gains
182+
likely limited to the case `resampling isa Holdout`.
183+
177184
"""
178185
function TunedModel(; model=nothing,
179186
tuning=Grid(),
@@ -190,7 +197,8 @@ function TunedModel(; model=nothing,
190197
n=nothing,
191198
acceleration=default_resource(),
192199
acceleration_resampling=CPU1(),
193-
check_measure=true)
200+
check_measure=true,
201+
cache=true)
194202

195203
range === nothing && error("You need to specify `range=...`.")
196204
model == nothing && error("You need to specify model=... .\n"*
@@ -204,15 +212,17 @@ function TunedModel(; model=nothing,
204212
train_best, repeats, n,
205213
acceleration,
206214
acceleration_resampling,
207-
check_measure)
215+
check_measure,
216+
cache)
208217
elseif model isa Probabilistic
209218
tuned_model = ProbabilisticTunedModel(model, tuning, resampling,
210219
measure, weights, operation,
211220
range, selection_heuristic,
212221
train_best, repeats, n,
213222
acceleration,
214223
acceleration_resampling,
215-
check_measure)
224+
check_measure,
225+
cache)
216226
else
217227
error("Only `Deterministic` and `Probabilistic` "*
218228
"model types supported.")
@@ -432,7 +442,7 @@ function assemble_events(metamodels,
432442
ProgressMeter.updateProgress!(p)
433443
end
434444
end
435-
# One tresampling_machine per task
445+
# One resampling_machine per task
436446
machs = [resampling_machine,
437447
[machine(Resampler(
438448
model= resampling_machine.model.model,
@@ -442,8 +452,9 @@ function assemble_events(metamodels,
442452
operation = resampling_machine.model.operation,
443453
check_measure = resampling_machine.model.check_measure,
444454
repeats = resampling_machine.model.repeats,
445-
acceleration = resampling_machine.model.acceleration),
446-
resampling_machine.args...) for _ in 2:length(partitions)]...]
455+
acceleration = resampling_machine.model.acceleration,
456+
cache = resampling_machine.model.cache),
457+
resampling_machine.args...) for _ in 2:length(partitions)]...]
447458

448459
@sync for (i, parts) in enumerate(partitions)
449460
Threads.@spawn begin
@@ -566,7 +577,8 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
566577
operation = tuned_model.operation,
567578
check_measure = tuned_model.check_measure,
568579
repeats = tuned_model.repeats,
569-
acceleration = tuned_model.acceleration_resampling)
580+
acceleration = tuned_model.acceleration_resampling,
581+
cache = tuned_model.cache)
570582
resampling_machine = machine(resampler, data...)
571583
history, state = build(nothing, n, tuning, model, state,
572584
verbosity, acceleration, resampling_machine)

test/models.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,3 @@ include("models/simple_composite_model.jl")
2020
include("models/ensembles.jl")
2121

2222
end
23-

test/models/Constant.jl

Lines changed: 96 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
## THE CONSTANT REGRESSOR
22

3+
const MMI = MLJModelInterface
34
export ConstantClassifier, ConstantRegressor,
4-
DeterministicConstantClassifier,
5-
ProbabilisticConstantClassifer
5+
DeterministicConstantClassifier,
6+
ProbabilisticConstantClassifer
67

7-
import MLJBase
88
import Distributions
99

1010
"""
@@ -14,49 +14,61 @@ A regressor that, for any new input pattern, predicts the univariate
1414
probability distribution best fitting the training target data. Use
1515
`predict_mean` to predict the mean value instead.
1616
"""
17-
struct ConstantRegressor{D} <: MLJBase.Probabilistic
18-
distribution_type::Type{D}
19-
end
17+
struct ConstantRegressor{D} <: MMI.Probabilistic end
2018

2119
function ConstantRegressor(; distribution_type=Distributions.Normal)
22-
model = ConstantRegressor(distribution_type)
20+
model = ConstantRegressor{distribution_type}()
2321
message = clean!(model)
2422
isempty(message) || @warn message
2523
return model
2624
end
2725

28-
function clean!(model::ConstantRegressor)
26+
function MMI.clean!(model::ConstantRegressor{D}) where D
2927
message = ""
30-
MLJBase.isdistribution(model.distribution_type) ||
28+
D <: Distributions.Sampleable ||
3129
error("$model.distribution_type is not a valid distribution_type.")
3230
return message
3331
end
3432

35-
function MLJBase.fit(::ConstantRegressor{D}, verbosity::Int, X, y) where D
33+
MMI.reformat(::ConstantRegressor, X) = (MMI.matrix(X),)
34+
MMI.reformat(::ConstantRegressor, X, y) = (MMI.matrix(X), y)
35+
MMI.selectrows(::ConstantRegressor, I, A) = (view(A, I, :),)
36+
MMI.selectrows(::ConstantRegressor, I, A, y) = (view(A, I, :), y[I])
37+
38+
function MMI.fit(::ConstantRegressor{D}, verbosity::Int, A, y) where D
3639
fitresult = Distributions.fit(D, y)
3740
cache = nothing
3841
report = NamedTuple()
3942
return fitresult, cache, report
4043
end
4144

42-
MLJBase.fitted_params(::ConstantRegressor, fitresult) = (target_distribution=fitresult,)
45+
MMI.fitted_params(::ConstantRegressor, fitresult) =
46+
(target_distribution=fitresult,)
4347

44-
MLJBase.predict(::ConstantRegressor, fitresult, Xnew) = fill(fitresult, nrows(Xnew))
48+
MMI.predict(::ConstantRegressor, fitresult, Xnew) =
49+
fill(fitresult, nrows(Xnew))
4550

4651
##
4752
## THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
4853
##
4954

50-
struct DeterministicConstantRegressor <: MLJBase.Deterministic end
55+
struct DeterministicConstantRegressor <: MMI.Deterministic end
5156

52-
function MLJBase.fit(::DeterministicConstantRegressor, verbosity::Int, X, y)
57+
function MMI.fit(::DeterministicConstantRegressor, verbosity::Int, X, y)
5358
fitresult = mean(y)
5459
cache = nothing
5560
report = NamedTuple()
5661
return fitresult, cache, report
5762
end
5863

59-
MLJBase.predict(::DeterministicConstantRegressor, fitresult, Xnew) = fill(fitresult, nrows(Xnew))
64+
MMI.reformat(::DeterministicConstantRegressor, X) = (MMI.matrix(X),)
65+
MMI.reformat(::DeterministicConstantRegressor, X, y) = (MMI.matrix(X), y)
66+
MMI.selectrows(::DeterministicConstantRegressor, I, A) = (view(A, I, :),)
67+
MMI.selectrows(::DeterministicConstantRegressor, I, A, y) =
68+
(view(A, I, :), y[I])
69+
70+
MMI.predict(::DeterministicConstantRegressor, fitresult, Xnew) =
71+
fill(fitresult, nrows(Xnew))
6072

6173
##
6274
## THE CONSTANT CLASSIFIER
@@ -71,39 +83,89 @@ training target data. So, `pdf(d, level)` is the proportion of levels
7183
in the training data coinciding with `level`. Use `predict_mode` to
7284
obtain the training target mode instead.
7385
"""
74-
struct ConstantClassifier <: MLJBase.Probabilistic end
86+
mutable struct ConstantClassifier <: MMI.Probabilistic
87+
testing::Bool
88+
bogus::Int
89+
end
90+
91+
ConstantClassifier(; testing=false, bogus=0) =
92+
ConstantClassifier(testing, bogus)
93+
94+
function MMI.reformat(model::ConstantClassifier, X)
95+
model.testing && @info "reformatting X"
96+
return (MMI.matrix(X),)
97+
end
98+
99+
function MMI.reformat(model::ConstantClassifier, X, y)
100+
model.testing && @info "reformatting X, y"
101+
return (MMI.matrix(X), y)
102+
end
103+
104+
function MMI.reformat(model::ConstantClassifier, X, y, w)
105+
model.testing && @info "reformatting X, y, w"
106+
return (MMI.matrix(X), y, w)
107+
end
108+
109+
function MMI.selectrows(model::ConstantClassifier, I, A)
110+
model.testing && @info "resampling X"
111+
return (view(A, I, :),)
112+
end
113+
114+
function MMI.selectrows(model::ConstantClassifier, I, A, y)
115+
model.testing && @info "resampling X, y"
116+
return (view(A, I, :), y[I])
117+
end
118+
119+
function MMI.selectrows(model::ConstantClassifier, I, A, y, ::Nothing)
120+
model.testing && @info "resampling X, y, nothing"
121+
return (view(A, I, :), y[I], nothing)
122+
end
123+
124+
function MMI.selectrows(model::ConstantClassifier, I, A, y, w)
125+
    model.testing && @info "resampling X, y, w"
126+
return (view(A, I, :), y[I], w[I])
127+
end
75128

76129
# here `args` is `y` or `y, w`:
77-
function MLJBase.fit(::ConstantClassifier, verbosity::Int, X, y, w=nothing)
130+
function MMI.fit(::ConstantClassifier, verbosity::Int, A, y, w=nothing)
78131
fitresult = Distributions.fit(MLJBase.UnivariateFinite, y, w)
79132
cache = nothing
80133
report = NamedTuple
81134
return fitresult, cache, report
82135
end
83136

84-
MLJBase.fitted_params(::ConstantClassifier, fitresult) = (target_distribution=fitresult,)
137+
MMI.fitted_params(::ConstantClassifier, fitresult) =
138+
(target_distribution=fitresult,)
85139

86-
MLJBase.predict(::ConstantClassifier, fitresult, Xnew) = fill(fitresult, nrows(Xnew))
140+
MMI.predict(::ConstantClassifier, fitresult, Xnew) =
141+
fill(fitresult, nrows(Xnew))
87142

88143
##
89144
## DETERMINISTIC CONSTANT CLASSIFIER (FOR TESTING)
90145
##
91146

92-
struct DeterministicConstantClassifier <: MLJBase.Deterministic end
147+
struct DeterministicConstantClassifier <: MMI.Deterministic end
93148

94-
function MLJBase.fit(::DeterministicConstantClassifier, verbosity::Int, X, y)
149+
function MMI.fit(::DeterministicConstantClassifier, verbosity::Int, X, y)
95150
# dump missing target values and make into a regular array:
96-
fitresult = mode(skipmissing(y) |> collect) # a CategoricalValue or CategoricalString
151+
fitresult = mode(skipmissing(y) |> collect) # a CategoricalValue
97152
cache = nothing
98153
report = NamedTuple()
99154
return fitresult, cache, report
100155
end
101156

102-
MLJBase.predict(::DeterministicConstantClassifier, fitresult, Xnew) = fill(fitresult, nrows(Xnew))
157+
MMI.reformat(::DeterministicConstantClassifier, X) = (MMI.matrix(X),)
158+
MMI.reformat(::DeterministicConstantClassifier, X, y) = (MMI.matrix(X), y)
159+
MMI.selectrows(::DeterministicConstantClassifier, I, A) = (view(A, I, :),)
160+
MMI.selectrows(::DeterministicConstantClassifier, I, A, y) =
161+
(view(A, I, :), y[I])
103162

104-
##
105-
## METADATA
106-
##
163+
MMI.predict(::DeterministicConstantClassifier, fitresult, Xnew) =
164+
fill(fitresult, nrows(Xnew))
165+
166+
#
167+
# METADATA
168+
#
107169

108170
metadata_pkg.((ConstantRegressor, ConstantClassifier,
109171
DeterministicConstantRegressor, DeterministicConstantClassifier),
@@ -115,29 +177,29 @@ metadata_pkg.((ConstantRegressor, ConstantClassifier,
115177
is_wrapper=false)
116178

117179
metadata_model(ConstantRegressor,
118-
input=MLJBase.Table(MLJBase.Scientific),
119-
target=AbstractVector{MLJBase.Continuous},
180+
input=MMI.Table,
181+
target=AbstractVector{MMI.Continuous},
120182
weights=false,
121183
descr="Constant regressor (Probabilistic).",
122184
path="MLJModels.ConstantRegressor")
123185

124186
metadata_model(DeterministicConstantRegressor,
125-
input=MLJBase.Table(MLJBase.Scientific),
126-
target=AbstractVector{MLJBase.Continuous},
187+
input=MMI.Table,
188+
target=AbstractVector{MMI.Continuous},
127189
weights=false,
128190
descr="Constant regressor (Deterministic).",
129191
path="MLJModels.DeterministicConstantRegressor")
130192

131193
metadata_model(ConstantClassifier,
132-
input=MLJBase.Table(MLJBase.Scientific),
133-
target=AbstractVector{<:MLJBase.Finite},
194+
input=MMI.Table,
195+
target=AbstractVector{<:MMI.Finite},
134196
weights=true,
135197
descr="Constant classifier (Probabilistic).",
136198
path="MLJModels.ConstantClassifier")
137199

138200
metadata_model(DeterministicConstantClassifier,
139-
input=MLJBase.Table(MLJBase.Scientific),
140-
target=AbstractVector{<:MLJBase.Finite},
201+
input=MMI.Table,
202+
target=AbstractVector{<:MMI.Finite},
141203
weights=false,
142204
descr="Constant classifier (Deterministic).",
143205
path="MLJModels.DeterministicConstantClassifier")

test/models/DecisionTree.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using MLJScientificTypes
44

55
using CategoricalArrays
66

7-
import DecisionTree
7+
import DecisionTree
88

99
## DESCRIPTIONS
1010

@@ -50,7 +50,7 @@ from the DecisionTree.jl algorithm).
5050
5151
For post-fit pruning, set `post-prune=true` and set
5252
`min_purity_threshold` appropriately. Other hyperparameters as per
53-
package documentation cited above.
53+
package documentation cited above.
5454
5555
5656
"""
@@ -214,4 +214,3 @@ metadata_model(DecisionTreeRegressor,
214214
target=AbstractVector{MLJBase.Continuous},
215215
weights=false,
216216
descr=DTR_DESCR)
217-

test/strategies/grid.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ end
147147
resampling=holdout, measure=rms,
148148
range=r)
149149

150-
MLJBase.info_dict(tuned_model)
151-
152150
tuned = machine(tuned_model, X, y)
153151

154152
fit!(tuned, verbosity=0)

0 commit comments

Comments
 (0)