1
1
module DecisionTreeExt
2
2
3
3
using SoleModels
4
+ import SoleModels: solemodel
4
5
5
6
import DecisionTree as DT
6
7
7
- function SoleModels. solemodel (model:: DT.Ensemble , args... ; kwargs... )
8
- return SoleModels. DecisionForest (map (t -> SoleModels. DecisionTree (SoleModels. solemodel (t, args... ; kwargs... )), model. trees))
8
+ function SoleModels. solemodel (
9
+ model:: DT.Ensemble ,
10
+ classlabels = nothing ,
11
+ featurenames = nothing ,
12
+ args... ;
13
+ keep_condensed = true ,
14
+ kwargs...
15
+ )
16
+ if isnothing (classlabels)
17
+ error (" Please, provide classlabels argument, as in solemodel(forest, classlabels; kwargs...). If your forest was trained via MLJ, use `classlabels = (mach).fitresult[2][sortperm((mach).fitresult[3])]`." )
18
+ end
19
+ if keep_condensed
20
+ info = (;
21
+ apply_preprocess= (y -> UInt32 (findfirst (x -> x == y, classlabels))),
22
+ apply_postprocess= (y -> classlabels[y]),
23
+ )
24
+ keep_condensed = ! keep_condensed
25
+ # O = UInt32
26
+ else
27
+ info = (;)
28
+ # O = UInt32
29
+ end
30
+ trees = map (t -> SoleModels. solemodel (t, args... ; keep_condensed, featurenames, kwargs... ), model. trees)
31
+ # trees = map(t -> let b = SoleModels.solemodel(t, args...; keep_condensed, featurenames, kwargs...); SoleModels.DecisionTree(b,
32
+ # (;
33
+ # supporting_predictions=b.info[:supporting_predictions],
34
+ # supporting_labels=b.info[:supporting_labels],
35
+ # )
36
+ # ) end, model.trees)
37
+
38
+ if ! isnothing (featurenames)
39
+ info = merge (info, (; featurenames= featurenames, ))
40
+ end
41
+
42
+ info = merge (info, (;
43
+ supporting_predictions= vcat ([t. info[:supporting_predictions ] for t in trees]. .. ),
44
+ supporting_labels= vcat ([t. info[:supporting_labels ] for t in trees]. .. ),
45
+ )
46
+ )
47
+
48
+ if ! isnothing (classlabels)
49
+ O = eltype (classlabels)
50
+ # O = eltype(levels(classnames))
51
+ else
52
+ O = nothing
53
+ end
54
+
55
+ if isnothing (O)
56
+ m = DecisionEnsemble (trees, info)
57
+ else
58
+ m = DecisionEnsemble {O} (trees, info)
59
+ end
60
+ return m
9
61
end
10
62
11
- function SoleModels. solemodel (tree:: DT.InfoNode , keep_condensed = false ; use_featurenames = true , kwargs... )
63
+ function SoleModels. solemodel (tree:: DT.InfoNode ; keep_condensed = true , featurenames = true , classlabels = tree . info . classlabels , kwargs... )
12
64
# @show fieldnames(typeof(tree))
13
- use_featurenames = use_featurenames ? tree. info. featurenames : false
65
+ featurenames = featurenames == true ? tree. info. featurenames : featurenames
66
+
14
67
root, info = begin
15
68
if keep_condensed
16
- root = SoleModels. solemodel (tree. node; use_featurenames = use_featurenames , kwargs... )
69
+ root = SoleModels. solemodel (tree. node; featurenames , kwargs... )
17
70
info = (;
18
- apply_preprocess= (y -> UInt32 (findfirst (x -> x == y, tree . info . classlabels))),
19
- apply_postprocess= (y -> tree . info . classlabels[y]),
71
+ apply_preprocess= (y -> UInt32 (findfirst (x -> x == y, classlabels))),
72
+ apply_postprocess= (y -> classlabels[y]),
20
73
)
74
+ keep_condensed = ! keep_condensed
21
75
root, info
22
76
else
23
- root = SoleModels. solemodel (tree. node; replace_classlabels = tree . info . classlabels, use_featurenames = use_featurenames , kwargs... )
77
+ root = SoleModels. solemodel (tree. node; replace_classlabels = classlabels, featurenames , kwargs... )
24
78
info = (;)
25
79
root, info
26
80
end
@@ -33,7 +87,19 @@ function SoleModels.solemodel(tree::DT.InfoNode, keep_condensed = false; use_fea
33
87
supporting_labels= root. info[:supporting_labels ],
34
88
)
35
89
)
36
- return DecisionTree (root, info)
90
+
91
+ # if !isnothing(classlabels)
92
+ # O = eltype(classlabels)
93
+ # else
94
+ # O = nothing
95
+ # end
96
+
97
+ # if isnothing(O)
98
+ dt = DecisionTree (root, info)
99
+ # else
100
+ # dt = DecisionTree{O}(root, info)
101
+ # end
102
+ return dt
37
103
end
38
104
39
105
# function SoleModels.solemodel(tree::DT.Root)
48
114
# return DecisionTree(root, info)
49
115
# end
50
116
51
- function SoleModels. solemodel (tree:: DT.Node ; replace_classlabels = nothing , use_featurenames = false )
117
+ function SoleModels. solemodel (tree:: DT.Node ; replace_classlabels = nothing , featurenames = nothing , keep_condensed = false )
118
+ keep_condensed && error (" Cannot keep condensed DecisionTree.Node." )
52
119
test_operator = (< )
53
120
# @show fieldnames(typeof(tree))
54
- feature = (use_featurenames != false ) ? VariableValue (use_featurenames [tree. featid]) : VariableValue (tree. featid)
121
+ feature = ! isnothing (featurenames ) ? VariableValue (featurenames [tree. featid]) : VariableValue (tree. featid)
55
122
cond = ScalarCondition (feature, test_operator, tree. featval)
56
123
antecedent = Atom (cond)
57
- lefttree = SoleModels. solemodel (tree. left; replace_classlabels = replace_classlabels, use_featurenames = use_featurenames )
58
- righttree = SoleModels. solemodel (tree. right; replace_classlabels = replace_classlabels, use_featurenames = use_featurenames )
124
+ lefttree = SoleModels. solemodel (tree. left; replace_classlabels, featurenames )
125
+ righttree = SoleModels. solemodel (tree. right; replace_classlabels, featurenames )
59
126
info = (;
60
127
supporting_predictions = [lefttree. info[:supporting_predictions ]. .. , righttree. info[:supporting_predictions ]. .. ],
61
128
supporting_labels = [lefttree. info[:supporting_labels ]. .. , righttree. info[:supporting_labels ]. .. ],
62
129
)
63
130
return Branch (antecedent, lefttree, righttree, info)
64
131
end
65
132
66
- function SoleModels. solemodel (tree:: DT.Leaf ; replace_classlabels = nothing , use_featurenames = false )
133
+ function SoleModels. solemodel (tree:: DT.Leaf ; replace_classlabels = nothing , featurenames = nothing , keep_condensed = false )
134
+ keep_condensed && error (" Cannot keep condensed DecisionTree.Node." )
67
135
# @show fieldnames(typeof(tree))
68
136
prediction = tree. majority
69
137
labels = tree. values
0 commit comments