Skip to content

Commit ebf0983

Browse files
authored
Merge pull request #170 from JuliaAI/dev
For a 0.7.0 release
2 parents 31f0255 + 32b502e commit ebf0983

File tree

11 files changed

+223
-102
lines changed

11 files changed

+223
-102
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
version:
22-
- '1.3'
22+
- '1.6'
2323
- '1'
2424
os:
2525
- ubuntu-latest

Project.toml

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJTuning"
22
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
4-
version = "0.6.16"
4+
version = "0.7.0"
55

66
[deps]
77
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
@@ -17,27 +17,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1717
ComputationalResources = "0.3"
1818
Distributions = "0.22,0.23,0.24, 0.25"
1919
LatinHypercubeSampling = "1.7.2"
20-
MLJBase = "0.18.19, 0.19"
21-
MLJModelInterface = "0.4.1, 1.1.1"
20+
MLJBase = "0.20"
2221
ProgressMeter = "1.7.1"
2322
RecipesBase = "0.8,0.9,1"
24-
julia = "1.3"
25-
26-
[extras]
27-
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
28-
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
29-
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
30-
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
31-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
32-
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33-
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
34-
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
35-
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
36-
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
37-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
38-
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
39-
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
40-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
41-
42-
[targets]
43-
test = ["CategoricalArrays", "DecisionTree", "Distances", "Distributions", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "NearestNeighbors", "ScientificTypes", "StableRNGs", "Statistics", "StatsBase", "Tables", "Test"]
23+
julia = "1.6"

src/MLJTuning.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ export learning_curve!, learning_curve
2121
import MLJBase
2222
using MLJBase
2323
import MLJBase: Bounded, Unbounded, DoublyUnbounded,
24-
LeftUnbounded, RightUnbounded, _process_accel_settings, chunks
24+
LeftUnbounded, RightUnbounded, _process_accel_settings, chunks,
25+
restore, save
2526
using RecipesBase
2627
using Distributed
2728
import Distributions

src/learning_curves.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,6 @@ Other key-word options are documented at [`TunedModel`](@ref).
9292
learning_curve(mach::Machine{<:Supervised}; kwargs...) =
9393
learning_curve(mach.model, mach.args...; kwargs...)
9494

95-
# for backwards compatibility
96-
function learning_curve!(mach::Machine{<:Supervised}; kwargs...)
97-
Base.depwarn("`learning_curve!` is deprecated, use `learning_curve` instead. ",
98-
Core.Typeof(learning_curve!).name.mt.name)
99-
learning_curve(mach; kwargs...)
100-
end
101-
10295
function learning_curve(model::Supervised, args...;
10396
resolution=30,
10497
resampling=Holdout(),
@@ -299,8 +292,12 @@ end
299292

300293
n_threads = Threads.nthreads()
301294
if n_threads == 1
302-
return _tuning_results(rngs, CPU1(),
303-
tuned, rng_name, verbosity)
295+
return _tuning_results(rngs,
296+
CPU1(),
297+
tuned,
298+
rows,
299+
rng_name,
300+
verbosity)
304301
end
305302

306303
old_rng = recursive_getproperty(tuned.model.model, rng_name)

src/serialization.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
MLJModelInterface.save(::MLJTuning.EitherTunedModel, fitresult::Machine) =
2+
serializable(fitresult)
3+
4+
function MLJModelInterface.restore(::MLJTuning.EitherTunedModel, fitresult)
5+
fitresult.fitresult = restore(fitresult.model, fitresult.fitresult)
6+
return fitresult
7+
end

src/tuned_models.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ const EitherTunedModel{T,M} =
6767
#todo update:
6868
"""
6969
tuned_model = TunedModel(; model=<model to be mutated>,
70-
tuning=Grid(),
70+
tuning=RandomSearch(),
7171
resampling=Holdout(),
7272
range=nothing,
7373
measure=nothing,
@@ -173,7 +173,10 @@ plus other key/value pairs specific to the `tuning` strategy.
173173
- `models`: Alternatively, an iterator of MLJ models to be explicitly
174174
evaluated. These may have varying types.
175175
176-
- `tuning=Grid()`: tuning strategy to be applied (eg, `RandomSearch()`)
176+
- `tuning=RandomSearch()`: tuning strategy to be applied (eg, `Grid()`). See
177+
the [Tuning
178+
Models](https://alan-turing-institute.github.io/MLJ.jl/dev/tuning_models/#Tuning-Models)
179+
section of the MLJ manual for a complete list of options.
177180
178181
- `resampling=Holdout()`: resampling strategy (eg, `Holdout()`, `CV()`,
179182
`StratifiedCV()`) to be applied in performance evaluations
@@ -253,7 +256,7 @@ function TunedModel(; model=nothing,
253256
throw(ERR_NEED_EXPLICIT)
254257
end
255258
else
256-
tuning === nothing && (tuning = Grid())
259+
tuning === nothing && (tuning = RandomSearch())
257260
end
258261

259262
# either a `model` is specified or we are in the case

test/Project.toml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
[deps]
2+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
3+
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
4+
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
5+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
6+
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
7+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8+
LatinHypercubeSampling = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d"
9+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
11+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
12+
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
13+
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
14+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
16+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
17+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
18+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
19+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
20+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
21+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22+
23+
[compat]
24+
CategoricalArrays = "0.10"
25+
ComputationalResources = "0.3"
26+
DecisionTree = "0.10"
27+
Distances = "0.10"
28+
Distributions = "0.25"
29+
MLJBase = "0.20"
30+
MLJModelInterface = "1.3"
31+
MultivariateStats = "0.9"
32+
NearestNeighbors = "0.4"
33+
ScientificTypes = "3.0"
34+
StableRNGs = "1.0"
35+
StatsBase = "0.33"
36+
Tables = "1.6"

test/learning_curves.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,5 @@ end
206206

207207
end
208208

209-
@testset "deprecation of learning_curve!" begin
210-
atom = KNNRegressor()
211-
mach = machine(atom, X, y)
212-
r = range(atom, :K, lower=1, upper=2)
213-
@test_deprecated learning_curve!(mach;
214-
range=r,
215-
measure=LPLoss(),
216-
verbosity=0)
217-
218-
end
219-
220209
end # module
221210
true

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ end
5959
@test include("learning_curves.jl")
6060
end
6161

62+
@testset "Serialization" begin
63+
@test include("serialization.jl")
64+
end
65+
6266
# @testset "julia bug" begin
6367
# @test include("julia_bug.jl")
6468
# end

test/serialization.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
2+
module TestSerialization
3+
4+
using Test
5+
using MLJBase
6+
using Serialization
7+
using MLJTuning
8+
using ..Models
9+
10+
function test_args(mach)
11+
# Check source nodes are empty if any
12+
for arg in mach.args
13+
if arg isa Source
14+
@test arg == source()
15+
end
16+
end
17+
end
18+
19+
function test_data(mach)
20+
@test !isdefined(mach, :old_rows)
21+
@test !isdefined(mach, :data)
22+
@test !isdefined(mach, :resampled_data)
23+
@test !isdefined(mach, :cache)
24+
end
25+
26+
function generic_tests(mach₁, mach₂)
27+
test_args(mach₂)
28+
test_data(mach₂)
29+
@test mach₂.state == -1
30+
for field in (:frozen, :model, :old_model, :old_upstream_state, :fit_okay)
31+
@test getfield(mach₁, field) == getfield(mach₂, field)
32+
end
33+
end
34+
35+
36+
@testset "Test TunedModel" begin
37+
filename = "tuned_model.jls"
38+
X, y = make_regression(100)
39+
base_model = DecisionTreeRegressor()
40+
tuned_model = TunedModel(
41+
model=base_model,
42+
tuning=Grid(),
43+
range=[range(base_model, :min_samples_split, values=[2,3,4])],
44+
)
45+
mach = machine(tuned_model, X, y)
46+
fit!(mach, rows=1:50, verbosity=0)
47+
smach = MLJBase.serializable(mach)
48+
@test smach.fitresult isa Machine
49+
@test smach.report == mach.report
50+
generic_tests(mach, smach)
51+
52+
Serialization.serialize(filename, smach)
53+
smach = Serialization.deserialize(filename)
54+
MLJBase.restore!(smach)
55+
56+
@test MLJBase.predict(smach, X) == MLJBase.predict(mach, X)
57+
@test fitted_params(smach) isa NamedTuple
58+
@test report(smach) == report(mach)
59+
60+
rm(filename)
61+
62+
# End to end
63+
MLJBase.save(filename, mach)
64+
smach = machine(filename)
65+
@test predict(smach, X) == predict(mach, X)
66+
67+
rm(filename)
68+
69+
end
70+
71+
end
72+
73+
true

0 commit comments

Comments
 (0)