Skip to content

Commit b3d22ee

Browse files
committed
Fix extensions
1 parent 4bf1e07 commit b3d22ee

File tree

4 files changed

+79
-33
lines changed

4 files changed

+79
-33
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8383
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
8484
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
8585
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
86+
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
8687
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
8788

8889
[targets]
89-
test = ["Test", "DataFrames", "Random", "MLJ", "MultiData", "Markdown", "InteractiveUtils", "BenchmarkTools", "MLJBase", "XGBoost", "DecisionTree", "MLJDecisionTreeInterface", "SoleData"]
90+
test = ["Test", "DataFrames", "Random", "MLJ", "MLJXGBoostInterface", "MultiData", "Markdown", "InteractiveUtils", "BenchmarkTools", "MLJBase", "XGBoost", "DecisionTree", "MLJDecisionTreeInterface", "SoleData"]

ext/DecisionTreeExt.jl

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,56 @@ module DecisionTreeExt
22

33
using SoleModels
44
import SoleModels: solemodel
5+
import SoleModels: alphabet
56

67
import DecisionTree as DT
78

9+
function get_condition(featid, featval, featurenames)
10+
test_operator = (<)
11+
# @show fieldnames(typeof(tree))
12+
feature = !isnothing(featurenames) ? VariableValue(featurenames[featid]) : VariableValue(featid)
13+
return ScalarCondition(feature, test_operator, featval)
14+
end
15+
16+
function SoleModels.alphabet(
17+
model::Union{
18+
DT.Ensemble,
19+
DT.InfoNode,
20+
DT.Node,
21+
DT.Leaf,
22+
},
23+
args...;
24+
kwargs...
25+
)
26+
27+
function _alphabet!(a::Vector, model::DT.Ensemble, args...; kwargs...)
28+
map(t -> _alphabet!(a, t, args...; kwargs...), model.trees)
29+
return a
30+
end
31+
32+
function _alphabet!(a::Vector, model::DT.InfoNode, args...; kwargs...)
33+
_alphabet!(a, model.left, args...; kwargs...)
34+
_alphabet!(a, model.right, args...; kwargs...)
35+
return a
36+
end
37+
38+
function _alphabet!(a::Vector, model::DT.Node, args...;
39+
featurenames = true,
40+
kwargs...
41+
)
42+
featurenames = featurenames == true ? model.info.featurenames : featurenames
43+
push!(a, Atom(get_condition(model.featid, model.featval, featurenames)))
44+
return a
45+
end
46+
47+
function _alphabet!(a::Vector, model::DT.Leaf, args...; kwargs...)
48+
return a
49+
end
50+
51+
return SoleData.scalaralphabet(_alphabet!(Atom{ScalarCondition}[], model, args...; kwargs...))
52+
end
53+
54+
855
function SoleModels.solemodel(
956
model::DT.Ensemble,
1057
classlabels = nothing,
@@ -116,10 +163,7 @@ end
116163

117164
function SoleModels.solemodel(tree::DT.Node; replace_classlabels = nothing, featurenames = nothing, keep_condensed = false)
118165
keep_condensed && error("Cannot keep condensed DecisionTree.Node.")
119-
test_operator = (<)
120-
# @show fieldnames(typeof(tree))
121-
feature = !isnothing(featurenames) ? VariableValue(featurenames[tree.featid]) : VariableValue(tree.featid)
122-
cond = ScalarCondition(feature, test_operator, tree.featval)
166+
cond = get_condition(tree.featid, tree.featval, featurenames)
123167
antecedent = Atom(cond)
124168
lefttree = SoleModels.solemodel(tree.left; replace_classlabels, featurenames)
125169
righttree = SoleModels.solemodel(tree.right; replace_classlabels, featurenames)

ext/XGBoostExt.jl

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,35 @@
11
module XGBoostExt
22

3+
using SoleModels
34
using XGBoost
45

5-
import Sole: alphabet, solemodel
6-
7-
# TODO fix and test
8-
# function alphabet(model::XGBoost.Booster; kwargs...)
9-
# function _alphabet!(a::Vector, model::XGBoost.Booster; kwargs...)
10-
# return a
11-
# end
12-
# function _alphabet!(a::Vector, tree::XGBoost.Node; kwargs...)
13-
# # Base case: if it's a leaf node
14-
# if length(tree.children) == 0
15-
# return a
16-
# end
17-
18-
# # Recursive case: split node
19-
# feature = Sole.VariableValue(tree.split isa String ? Symbol(tree.split) : tree.split)
20-
# condition = ScalarCondition(feature, (<), tree.split_condition)
21-
# push!(a, condition)
22-
# if length(tree.children) == 2
23-
# _alphabet!(a, tree.children[1]; with_stats, kwargs...)
24-
# _alphabet!(a, tree.children[2]; with_stats, kwargs...)
25-
# else
26-
# error("Found $(length(tree.children)) children while 2 were expected: $(tree.children).")
27-
# end
28-
# return a
29-
# end
30-
# _alphabet!([], model; kwargs...)
31-
# end
6+
import SoleModels: alphabet, solemodel
7+
8+
function alphabet(model::XGBoost.Booster; kwargs...)
9+
error("TODO fix and test.")
10+
function _alphabet!(a::Vector, model::XGBoost.Booster; kwargs...)
11+
return a
12+
end
13+
function _alphabet!(a::Vector, tree::XGBoost.Node; kwargs...)
14+
# Base case: if it's a leaf node
15+
if length(tree.children) == 0
16+
return a
17+
end
18+
19+
# Recursive case: split node
20+
feature = Sole.VariableValue(tree.split isa String ? Symbol(tree.split) : tree.split)
21+
condition = ScalarCondition(feature, (<), tree.split_condition)
22+
push!(a, condition)
23+
if length(tree.children) == 2
24+
_alphabet!(a, tree.children[1]; with_stats, kwargs...)
25+
_alphabet!(a, tree.children[2]; with_stats, kwargs...)
26+
else
27+
error("Found $(length(tree.children)) children while 2 were expected: $(tree.children).")
28+
end
29+
return a
30+
end
31+
_alphabet!(Atom{ScalarCondition}[], model; kwargs...)
32+
end
3233

3334

3435
# TODO fix and test. Problem: where are the tree weights? How do I write this in the multi-class case?

test/XGBoostExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ println("Test Accuracy: $acc")
3838

3939

4040

41-
using Sole
41+
using SoleModels
4242

4343
@test_nowarn alphabet(fitted_params(mach).fitresult[1])
4444

0 commit comments

Comments
 (0)