Commit 9fe1f52

Merge pull request #214 from JuliaAI/dev
For a 0.8.4 release
2 parents 4f1dd71 + c597f27

6 files changed, +82 -129 lines changed


.github/codecov.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ coverage:
     project:
       default:
         threshold: 0.5%
+        removed_code_behavior: fully_covered_patch
     patch:
       default:
         target: 80%

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJTuning"
 uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
 authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
-version = "0.8.3"
+version = "0.8.4"
 
 [deps]
 ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"

src/tuned_models.jl

Lines changed: 16 additions & 0 deletions
@@ -857,6 +857,22 @@ function MLJBase.training_losses(tuned_model::EitherTunedModel, _report)
     return ret
 end
 
+## Support for Feature Importances
+function MLJBase.reports_feature_importances(::Type{<:EitherTunedModel{<:Any,M}}) where {M}
+    return MLJBase.reports_feature_importances(M)
+end
+
+function MLJBase.reports_feature_importances(model::EitherTunedModel)
+    return MLJBase.reports_feature_importances(model.model)
+end # This is needed in some cases (e.g. tuning a `Pipeline`)
+
+function MLJBase.feature_importances(::EitherTunedModel, fitresult, report)
+    # `fitresult` here is a machine created using the `best_model` obtained
+    # from the tuning process.
+    # The line below will return `nothing` when the model being tuned doesn't
+    # support feature_importances.
+    return MLJBase.feature_importances(fitresult)
+end
 
 ## METADATA
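With these methods, a `TunedModel` simply delegates feature-importance queries to the machine trained on its best model. A minimal usage sketch, not part of this diff (it assumes an atomic model supporting feature importances, here `DecisionTreeClassifier` from MLJDecisionTreeInterface, chosen purely for illustration):

    using MLJBase, MLJTuning, MLJDecisionTreeInterface

    X, y = @load_iris
    tree = MLJDecisionTreeInterface.DecisionTreeClassifier()
    r = range(tree, :max_depth, values=1:5)
    tm = TunedModel(model=tree, range=r, tuning=Grid(),
                    resampling=CV(nfolds=3), measure=cross_entropy)

    mach = machine(tm, X, y)
    fit!(mach, verbosity=0)

    # delegates to the machine trained on `best_model`; returns `nothing`
    # for atomic models without feature-importance support:
    feature_importances(mach)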
test/models/DecisionTree.jl

Lines changed: 16 additions & 2 deletions
@@ -95,8 +95,11 @@ function MLJBase.fit(model::DecisionTreeClassifier, verbosity::Int, X, y)
     #> empty values):
 
     cache = nothing
-    report = (classes_seen=classes_seen,
-              print_tree=TreePrinter(tree))
+    report = (
+        classes_seen=classes_seen,
+        print_tree=TreePrinter(tree),
+        features=collect(Tables.columnnames(Tables.columns(X)))
+    )
 
     return fitresult, cache, report
 end
@@ -134,6 +137,17 @@ function MLJBase.predict(model::DecisionTreeClassifier
             for i in 1:size(y_probabilities, 1)]
 end
 
+MLJBase.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true
+
+function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report)
+    features = report.features
+    fi = DecisionTree.impurity_importance(first(fitresult), normalize=true)
+    fi_pairs = Pair.(features, fi)
+    # sort descending
+    sort!(fi_pairs, by = x -> -x[2])
+
+    return fi_pairs
+end
 
 ## REGRESSOR
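The two new methods follow MLJ's standard opt-in pattern for any atomic model: declare the `reports_feature_importances` trait, then implement `feature_importances` returning `feature => importance` pairs. A generic sketch of that pattern, assuming `fit` has recorded the training feature names in `report.features` as the test model above does; `MyModel` and `compute_importances` are hypothetical placeholders, not part of this diff:

    import MLJModelInterface as MMI

    struct MyModel <: MMI.Deterministic end

    MMI.reports_feature_importances(::Type{<:MyModel}) = true

    function MMI.feature_importances(::MyModel, fitresult, report)
        # `compute_importances` is a hypothetical helper returning one
        # score per feature; any `Symbol => Real` pairs are acceptable
        fi_pairs = Pair.(report.features, compute_importances(fitresult))
        # sort descending by importance
        return sort!(fi_pairs, by = x -> -last(x))
    end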
test/schizo.md

Lines changed: 0 additions & 113 deletions
This file was deleted.

test/tuned_models.jl

Lines changed: 48 additions & 13 deletions
@@ -13,19 +13,24 @@ using Random
 Random.seed!(1234*myid())
 using .TestUtilities
 
-N = 30
-x1 = rand(N);
-x2 = rand(N);
-x3 = rand(N);
-X = (; x1, x2, x3);
-y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.4*rand(N);
-
-m(K) = KNNRegressor(; K)
-r = [m(K) for K in 13:-1:2]
-
-# TODO: replace the above with the line below and post an issue on
-# the failure (a bug in Distributed, I reckon):
-# r = (m(K) for K in 13:-1:2)
+begin
+    N = 30
+    x1 = rand(N);
+    x2 = rand(N);
+    x3 = rand(N);
+    X = (; x1, x2, x3);
+    y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.4*rand(N);
+
+    m(K) = KNNRegressor(; K)
+    r = [m(K) for K in 13:-1:2]
+
+    Xtree, yhat = @load_iris
+    trees = [DecisionTreeClassifier(pruning_purity = rand()) for _ in 13:-1:2]
+
+    # TODO: replace the above with the line below and post an issue on
+    # the failure (a bug in Distributed, I reckon):
+    # r = (m(K) for K in 13:-1:2)
+end
 
 @testset "constructor" begin
     @test_throws(MLJTuning.ERR_SPECIFY_RANGE,
@@ -105,6 +110,10 @@ end
     @test _report.best_model == collect(r)[best_index]
     @test _report.history[5] == MLJTuning.delete(history[5], :metadata)
 
+    # feature_importances:
+    # This should return `nothing`, as `KNNRegressor` doesn't support feature_importances.
+    @test feature_importances(tm, fitresult, _report) === nothing
+
     # training_losses:
     losses = training_losses(tm, _report)
     @test all(eachindex(losses)) do i
@@ -146,6 +155,32 @@ end
     @test results4 ≈ results
 end
 
+@testset_accelerated "Feature Importances" accel begin
+    # the DecisionTreeClassifier in /test/_models/ supports feature importances.
+    tm0 = TunedModel(
+        model = trees[1],
+        measure = rms,
+        tuning = Grid(),
+        resampling = CV(nfolds = 5),
+        range = range(
+            trees[1], :max_depth, values = 1:10
+        )
+    )
+    @test reports_feature_importances(typeof(tm0))
+    tm = TunedModel(
+        models = trees,
+        resampling = CV(nfolds=2),
+        measures = cross_entropy,
+        acceleration = CPU1(),
+        acceleration_resampling = accel
+    )
+    @test reports_feature_importances(tm)
+    fitresult, _, report = MLJBase.fit(tm, 0, Xtree, yhat)
+    features = first.(feature_importances(tm, fitresult, report))
+    @test Set(features) == Set(keys(Xtree))
+
+end
+
 @testset_accelerated(
     "under/over supply of models",
     accel,
