Skip to content

Commit 1528d0d

Browse files
committed
Add skeleton for XGBoostExt. Fix utilities. Introduce ensembles.
1 parent cdd301a commit 1528d0d

File tree

14 files changed

+537
-177
lines changed

14 files changed

+537
-177
lines changed

Project.toml

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,19 @@ version = "0.9.0"
66

77
[deps]
88
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
9-
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
109
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
1110
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
12-
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1311
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1412
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1513
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1614
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1715
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
16+
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
1817
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
1918
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2019
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2120
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
2221
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
23-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2422
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2523
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
2624
SoleBase = "4475fa32-7023-44a0-aa70-4813b230e492"
@@ -34,15 +32,18 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
3432

3533
[weakdeps]
3634
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
35+
XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
3736
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
37+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
3838
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
3939
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
4040

4141
[extensions]
4242
DecisionTreeExt = "DecisionTree"
43-
MLJExt = "MLJ"
4443
MLJDecisionTreeInterfaceExt = "MLJDecisionTreeInterface"
44+
MLJExt = "MLJ"
4545
MLJModelInterfaceExt = "MLJModelInterface"
46+
XGBoostExt = "XGBoost"
4647

4748
[compat]
4849
AbstractTrees = "0.4"
@@ -51,11 +52,16 @@ CSV = "0.10"
5152
CategoricalArrays = "0.10"
5253
DataFrames = "1"
5354
DataStructures = "0.18"
55+
DecisionTree = "0.12"
5456
FillArrays = "1"
5557
FunctionWrappers = "1"
5658
Graphs = "1.8"
5759
HTTP = "1.9"
60+
IterTools = "1"
5861
Lazy = "0.15.1"
62+
MLJ = "0.20"
63+
MLJBase = "1.6"
64+
MLJDecisionTreeInterface = "0.4"
5965
MLJModelInterface = "1.8.0"
6066
PrettyTables = "2.2"
6167
ProgressMeter = "1"
@@ -65,6 +71,7 @@ Revise = "3"
6571
SoleBase = "0.12"
6672
SoleData = "0.15, 0.16"
6773
SoleLogics = "0.11"
74+
XGBoost = "2.5"
6875
StatsBase = "0.30 - 0.34"
6976
Suppressor = "0.2"
7077
Tables = "1"
@@ -73,11 +80,15 @@ ZipFile = "0.10"
7380
julia = "1"
7481

7582
[extras]
83+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
84+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
7685
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7786
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
7887
MultiData = "8cc5100c-b3d1-4f82-90cb-0ea93d317aba"
7988
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
89+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
90+
SoleData = "123f1ae1-6307-4526-ab5b-aab3a92a2b8c"
8091
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8192

8293
[targets]
83-
test = ["Test", "Markdown", "InteractiveUtils", "PlutoUI", "MultiData"]
94+
test = ["Test", "DataFrames", "Random", "MLJ", "MultiData", "Markdown", "InteractiveUtils", "BenchmarkTools", "MLJBase", "XGBoost", "DecisionTree", "MLJDecisionTreeInterface", "SoleData"]

ext/DecisionTreeExt.jl

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,80 @@
11
module DecisionTreeExt
22

33
using SoleModels
4+
import SoleModels: solemodel
45

56
import DecisionTree as DT
67

7-
function SoleModels.solemodel(model::DT.Ensemble, args...; kwargs...)
8-
return SoleModels.DecisionForest(map(t -> SoleModels.DecisionTree(SoleModels.solemodel(t, args...; kwargs...)), model.trees))
8+
"""
    solemodel(model::DT.Ensemble, classlabels, featurenames = nothing, args...;
              keep_condensed = true, kwargs...)

Convert a `DecisionTree.Ensemble` into a `DecisionEnsemble{O}`, with
`O = eltype(classlabels)`.

# Arguments
- `model`: the ensemble; every element of `model.trees` is converted with `solemodel`.
- `classlabels`: the class labels. Although it defaults to `nothing` (kept for
  signature compatibility), it is mandatory: an error is raised when it is missing.
- `featurenames`: when not `nothing`, it is forwarded to each tree conversion and stored
  in the `info` of the resulting ensemble.
- `args...`: extra positional arguments, forwarded to each tree conversion.

# Keywords
- `keep_condensed`: when `true`, the ensemble `info` carries `apply_preprocess`
  (label -> `UInt32` position within `classlabels`) and `apply_postprocess`
  (position -> label). The individual trees are always converted with
  `keep_condensed = false`.
- `kwargs...`: forwarded to each tree conversion.

# Throws
- `ErrorException` if `classlabels` is `nothing`.
"""
function SoleModels.solemodel(
    model::DT.Ensemble,
    classlabels = nothing,
    featurenames = nothing,
    args...;
    keep_condensed = true,
    kwargs...
)
    if isnothing(classlabels)
        error("Please, provide classlabels argument, as in solemodel(forest, classlabels; kwargs...). If your forest was trained via MLJ, use `classlabels = (mach).fitresult[2][sortperm((mach).fitresult[3])]`.")
    end

    # Label <-> index translation lives on the ensemble only; see the tree conversion below.
    info = if keep_condensed
        (;
            apply_preprocess=(y -> UInt32(findfirst(x -> x == y, classlabels))),
            apply_postprocess=(y -> classlabels[y]),
        )
    else
        (;)
    end

    # The single trees are never kept condensed, whatever `keep_condensed` is for the
    # ensemble: the flag is passed explicitly rather than by reassigning the argument.
    trees = map(
        t -> SoleModels.solemodel(t, args...; keep_condensed = false, featurenames, kwargs...),
        model.trees,
    )

    if !isnothing(featurenames)
        info = merge(info, (; featurenames=featurenames))
    end

    # Concatenate the supporting predictions/labels recorded by each converted tree.
    info = merge(info, (;
        supporting_predictions=vcat([t.info[:supporting_predictions] for t in trees]...),
        supporting_labels=vcat([t.info[:supporting_labels] for t in trees]...),
    ))

    # `classlabels` is guaranteed not to be `nothing` here, so the outcome type is
    # always known.
    return DecisionEnsemble{eltype(classlabels)}(trees, info)
end
1062

11-
function SoleModels.solemodel(tree::DT.InfoNode, keep_condensed = false; use_featurenames = true, kwargs...)
63+
function SoleModels.solemodel(tree::DT.InfoNode; keep_condensed = true, featurenames = true, classlabels = tree.info.classlabels, kwargs...)
1264
# @show fieldnames(typeof(tree))
13-
use_featurenames = use_featurenames ? tree.info.featurenames : false
65+
featurenames = featurenames == true ? tree.info.featurenames : featurenames
66+
1467
root, info = begin
1568
if keep_condensed
16-
root = SoleModels.solemodel(tree.node; use_featurenames = use_featurenames, kwargs...)
69+
root = SoleModels.solemodel(tree.node; featurenames, kwargs...)
1770
info = (;
18-
apply_preprocess=(y -> UInt32(findfirst(x -> x == y, tree.info.classlabels))),
19-
apply_postprocess=(y -> tree.info.classlabels[y]),
71+
apply_preprocess=(y -> UInt32(findfirst(x -> x == y, classlabels))),
72+
apply_postprocess=(y -> classlabels[y]),
2073
)
74+
keep_condensed = !keep_condensed
2175
root, info
2276
else
23-
root = SoleModels.solemodel(tree.node; replace_classlabels = tree.info.classlabels, use_featurenames = use_featurenames, kwargs...)
77+
root = SoleModels.solemodel(tree.node; replace_classlabels = classlabels, featurenames, kwargs...)
2478
info = (;)
2579
root, info
2680
end
@@ -33,7 +87,19 @@ function SoleModels.solemodel(tree::DT.InfoNode, keep_condensed = false; use_fea
3387
supporting_labels=root.info[:supporting_labels],
3488
)
3589
)
36-
return DecisionTree(root, info)
90+
91+
# if !isnothing(classlabels)
92+
# O = eltype(classlabels)
93+
# else
94+
# O = nothing
95+
# end
96+
97+
# if isnothing(O)
98+
dt = DecisionTree(root, info)
99+
# else
100+
# dt = DecisionTree{O}(root, info)
101+
# end
102+
return dt
37103
end
38104

39105
# function SoleModels.solemodel(tree::DT.Root)
@@ -48,22 +114,24 @@ end
48114
# return DecisionTree(root, info)
49115
# end
50116

51-
function SoleModels.solemodel(tree::DT.Node; replace_classlabels = nothing, use_featurenames = false)
117+
# Convert an internal (split) node of a DecisionTree.jl tree into a `Branch` whose
# antecedent is the atom `feature < tree.featval`. The two subtrees are converted
# recursively with the same `replace_classlabels` and `featurenames`.
# The variable is `featurenames[tree.featid]` when `featurenames` is given, and the raw
# `tree.featid` otherwise. `keep_condensed = true` is rejected with an error.
function SoleModels.solemodel(tree::DT.Node; replace_classlabels = nothing, featurenames = nothing, keep_condensed = false)
    if keep_condensed
        error("Cannot keep condensed DecisionTree.Node.")
    end

    variable = isnothing(featurenames) ? tree.featid : featurenames[tree.featid]
    antecedent = Atom(ScalarCondition(VariableValue(variable), (<), tree.featval))

    # Left subtree first, then right, as in a plain sequential conversion.
    left, right = map(
        child -> SoleModels.solemodel(child; replace_classlabels, featurenames),
        (tree.left, tree.right),
    )

    # The branch inherits the supporting predictions/labels of both subtrees.
    info = (;
        supporting_predictions = [left.info[:supporting_predictions]..., right.info[:supporting_predictions]...],
        supporting_labels = [left.info[:supporting_labels]..., right.info[:supporting_labels]...],
    )
    return Branch(antecedent, left, right, info)
end
65132

66-
function SoleModels.solemodel(tree::DT.Leaf; replace_classlabels = nothing, use_featurenames = false)
133+
function SoleModels.solemodel(tree::DT.Leaf; replace_classlabels = nothing, featurenames = nothing, keep_condensed = false)
134+
keep_condensed && error("Cannot keep condensed DecisionTree.Node.")
67135
# @show fieldnames(typeof(tree))
68136
prediction = tree.majority
69137
labels = tree.values

ext/XGBoostExt.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
module XGBoostExt
2+
3+
using XGBoost
4+
5+
import Sole: alphabet, solemodel
6+
7+
"""
    alphabet(model::XGBoost.Booster; kwargs...)

Collect the split conditions appearing in the trees of `model`.

Every split node of every tree returned by `XGBoost.trees(model)` contributes one
`ScalarCondition(feature, <, tree.split_condition)`, where `feature` wraps `tree.split`
(converted to a `Symbol` when it is a `String`). Conditions are pushed in pre-order
(node, first child, second child) and are not deduplicated.

`kwargs...` are accepted for interface compatibility and are currently unused.

# Throws
- `ErrorException` if a split node does not have exactly two children.
"""
function alphabet(model::XGBoost.Booster; kwargs...)
    # NOTE(review): `Sole.VariableValue` and `ScalarCondition` must be resolvable from
    # this module; confirm against the import statements of the extension.

    # Booster: visit every tree of the ensemble.
    function _alphabet!(a::Vector, model::XGBoost.Booster)
        for tree in XGBoost.trees(model)
            _alphabet!(a, tree)
        end
        return a
    end

    # Node: record the split condition, then recurse into the two children.
    function _alphabet!(a::Vector, tree::XGBoost.Node)
        # Base case: if it's a leaf node
        if length(tree.children) == 0
            return a
        end

        # Recursive case: split node
        feature = Sole.VariableValue(tree.split isa String ? Symbol(tree.split) : tree.split)
        condition = ScalarCondition(feature, (<), tree.split_condition)
        push!(a, condition)
        if length(tree.children) == 2
            _alphabet!(a, tree.children[1])
            _alphabet!(a, tree.children[2])
        else
            error("Found $(length(tree.children)) children while 2 were expected: $(tree.children).")
        end
        return a
    end

    return _alphabet!([], model)
end
31+
32+
33+
# TODO fix and test. Problem: where are the tree weights? How do I write this in the multi-class case?
34+
35+
# # Convert an XGBoost.Booster to a Sole Ensemble
36+
# function solemodel(model::XGBoost.Booster; with_stats::Bool = true, kwargs...)
37+
# # Extract weights (global scaling factors for trees, if any)
38+
# weights = nothing # XGBoost trees usually don't have individual weights, but modify here if needed.
39+
40+
# model.params[:objective] == "multi:softprob" || error("Unexpected objective encountered: $(model.params[:objective]).")
41+
# isempty(model.feature_names) || error("Unexpected feature names encountered: $(model.feature_names).")
42+
43+
# # Convert all trees into Sole models
44+
# trees = [solemodel(tree; with_stats, kwargs...) for tree in XGBoost.trees(model; with_stats = with_stats)]
45+
46+
# # Return a Sole Ensemble
47+
# return Sole.Ensemble(trees, weights)
48+
# end
49+
50+
# # Convert a single XGBoost tree into a Sole tree
51+
# function solemodel(tree::XGBoost.Node; with_stats::Bool = true, kwargs...)
52+
# function _makeleaf(value)
53+
# SoleModels.ConstantModel(value, (; supporting_predictions=[value]))
54+
# end
55+
56+
# # Base case: if it's a leaf node
57+
# if length(tree.children) == 0
58+
# return _makeleaf(tree.leaf)
59+
# end
60+
61+
# # Recursive case: split node
62+
# feature = Sole.VariableValue(tree.split isa String ? Symbol(tree.split) : tree.split)
63+
# condition = ScalarCondition(feature, (<), tree.split_condition)
64+
# antecedent = Atom(condition)
65+
66+
# # Recursively convert left and right branches
67+
68+
# if length(tree.children) == 2
69+
# left_tree = solemodel(tree.children[1]; with_stats, kwargs...)
70+
# right_tree = solemodel(tree.children[2]; with_stats, kwargs...)
71+
# else
72+
# error("Found $(length(tree.children)) children while 2 were expected: $(tree.children).")
73+
# end
74+
75+
# # Aggregate info (e.g., supporting predictions) from children
76+
# info = (;
77+
# supporting_predictions=[left_tree.info[:supporting_predictions]..., right_tree.info[:supporting_predictions]...],
78+
# this=_makeleaf(tree.leaf),
79+
# xgboost_gain = tree.gain,
80+
# xgboost_yes = tree.yes,
81+
# xgboost_no = tree.no,
82+
# xgboost_cover = tree.cover,
83+
# )
84+
85+
# # Create and return a Sole Branch
86+
# return Branch(antecedent, left_tree, right_tree, info)
87+
# end
88+
# solemodel(model)
89+
90+
91+
end

src/SoleModels.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ export root
5656
export nnodes, nleaves
5757
export height
5858

59-
export DecisionForest
60-
export trees
59+
export DecisionEnsemble, models
60+
export DecisionForest, trees
6161

6262
export MixedModel
6363

src/print.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ end
459459

460460
function printmodel(
461461
io::IO,
462-
m::DecisionForest;
462+
m::DecisionEnsemble;
463463
header = DEFAULT_HEADER,
464464
indentation_str = "",
465465
indentation = default_indentation,
@@ -496,22 +496,22 @@ function printmodel(
496496
end
497497

498498
########################################################################################
499-
depth == 0 && show_symbols && print(io, "$(MODEL_SYMBOL) Forest of $(ntrees(m)) trees")
499+
depth == 0 && show_symbols && print(io, "$(MODEL_SYMBOL) Ensemble{$(outcometype(m))} of $(nmodels(m)) models of type $(modelstype(m))")
500500
if isnothing(max_depth) || depth < max_depth
501501
_show_rule_metrics = show_rule_metrics
502502
println(io, "$(indentation_list_children)")
503-
for (i_tree, tree) in enumerate(trees(m))
504-
if i_tree < ntrees(m)
505-
pipe = indentation_any_first*"[$i_tree/$(ntrees(m))]┐"
503+
for (i_model, model) in enumerate(models(m))
504+
if i_model < nmodels(m)
505+
pipe = indentation_any_first*"[$i_model/$(nmodels(m))]┐"
506506
pad_str = indentation_str*indentation_any_space*repeat(indentation_hspace, length(pipe)-length(indentation_any_space)-1-1)
507507
ind_str = pad_str*indentation_last_space
508508
else
509-
pipe = indentation_last_first*"[$i_tree/$(ntrees(m))]┐"
509+
pipe = indentation_last_first*"[$i_model/$(nmodels(m))]┐"
510510
ind_str = indentation_str*indentation_last_space*repeat(indentation_hspace, length(pipe)-length(indentation_last_space)-1-1)*indentation_last_space
511511
end
512512
print(io, pipe)
513513

514-
@_print_submodel io tree ind_str indentation depth max_depth show_subtree_info _show_rule_metrics show_subtree_metrics show_shortforms show_intermediate_finals tree_mode show_symbols syntaxstring_kwargs parenthesize_atoms kwargs
514+
@_print_submodel io model ind_str indentation depth max_depth show_subtree_info _show_rule_metrics show_subtree_metrics show_shortforms show_intermediate_finals tree_mode show_symbols syntaxstring_kwargs parenthesize_atoms kwargs
515515
end
516516
else
517517
depth != 0 && print(io, " ")

0 commit comments

Comments
 (0)