Skip to content

Commit 7475523

Browse files
committed
Fix XGBoost ext
1 parent 1528d0d commit 7475523

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

ext/XGBoostExt.jl

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,31 @@ using XGBoost
44

55
import Sole: alphabet, solemodel
66

7-
function alphabet(model::XGBoost.Booster; kwargs...)
8-
function _alphabet!(a::Vector, model::XGBoost.Booster; kwargs...)
9-
return a
10-
end
11-
function _alphabet!(a::Vector, tree::XGBoost.Node; kwargs...)
12-
# Base case: if it's a leaf node
13-
if length(tree.children) == 0
14-
return a
15-
end
16-
17-
# Recursive case: split node
18-
feature = Sole.VariableValue(tree.split isa String ? Symbol(tree.split) : tree.split)
19-
condition = ScalarCondition(feature, (<), tree.split_condition)
20-
push!(a, condition)
21-
if length(tree.children) == 2
22-
_alphabet!(a, tree.children[1]; with_stats, kwargs...)
23-
_alphabet!(a, tree.children[2]; with_stats, kwargs...)
24-
else
25-
error("Found $(length(tree.children)) children while 2 were expected: $(tree.children).")
26-
end
27-
return a
28-
end
29-
_alphabet!([], model; kwargs...)
30-
end
7+
# TODO fix and test
8+
# function alphabet(model::XGBoost.Booster; kwargs...)
9+
# function _alphabet!(a::Vector, model::XGBoost.Booster; kwargs...)
10+
# return a
11+
# end
12+
# function _alphabet!(a::Vector, tree::XGBoost.Node; kwargs...)
13+
# # Base case: if it's a leaf node
14+
# if length(tree.children) == 0
15+
# return a
16+
# end
17+
18+
# # Recursive case: split node
19+
# feature = Sole.VariableValue(tree.split isa String ? Symbol(tree.split) : tree.split)
20+
# condition = ScalarCondition(feature, (<), tree.split_condition)
21+
# push!(a, condition)
22+
# if length(tree.children) == 2
23+
# _alphabet!(a, tree.children[1]; with_stats, kwargs...)
24+
# _alphabet!(a, tree.children[2]; with_stats, kwargs...)
25+
# else
26+
# error("Found $(length(tree.children)) children while 2 were expected: $(tree.children).")
27+
# end
28+
# return a
29+
# end
30+
# _alphabet!([], model; kwargs...)
31+
# end
3132

3233

3334
# TODO fix and test. Problem: where are the tree weights? How do I write this in the multi-class case?

0 commit comments

Comments
 (0)