Skip to content

Commit cdd301a

Browse files
committed
Merge branch 'dev' of github.com:alberto-paparella/SoleModels.jl into alberto-paparella-dev
2 parents a514802 + 2cbb8cb commit cdd301a

File tree

10 files changed

+237
-7
lines changed

10 files changed

+237
-7
lines changed

Project.toml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
1818
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
1919
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2020
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
21-
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
2221
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
2322
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2423
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -33,6 +32,18 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3332
ThreadSafeDicts = "4239201d-c60e-5e0a-9702-85d713665ba7"
3433
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
3534

35+
[weakdeps]
36+
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
37+
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
38+
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
39+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
40+
41+
[extensions]
42+
DecisionTreeExt = "DecisionTree"
43+
MLJExt = "MLJ"
44+
MLJDecisionTreeInterfaceExt = "MLJDecisionTreeInterface"
45+
MLJModelInterfaceExt = "MLJModelInterface"
46+
3647
[compat]
3748
AbstractTrees = "0.4"
3849
BenchmarkTools = "1"
@@ -63,13 +74,10 @@ julia = "1"
6374

6475
[extras]
6576
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
66-
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
67-
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
6877
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
6978
MultiData = "8cc5100c-b3d1-4f82-90cb-0ea93d317aba"
7079
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
71-
SoleDecisionTreeInterface = "de8eae22-3630-40e0-868c-abfc4c1bb3da"
7280
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7381

7482
[targets]
75-
test = ["Test", "Markdown", "InteractiveUtils", "PlutoUI", "MultiData", "MLJ", "MLJDecisionTreeInterface", "SoleDecisionTreeInterface"]
83+
test = ["Test", "Markdown", "InteractiveUtils", "PlutoUI", "MultiData"]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ end
8585

8686
Then, port it to Sole and play with it:
8787
```julia
88-
Pkg.add("SoleDecisionTreeInterface"); using SoleDecisionTreeInterface
88+
Pkg.add("DecisionTree"); import DecisionTree as DT
8989

9090
# Convert to 🌞-compliant model
9191
🌲 = solemodel(🌱);

ext/DecisionTreeExt.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
module DecisionTreeExt
2+
3+
using SoleModels
4+
5+
import DecisionTree as DT
6+
7+
"""
    SoleModels.solemodel(model::DT.Ensemble, args...; kwargs...)

Convert a `DecisionTree.jl` ensemble into a `SoleModels.DecisionForest`.

Every element of `model.trees` is converted with `SoleModels.solemodel`, wrapped in a
`SoleModels.DecisionTree`, and the wrapped trees are collected into the forest.

# Arguments
- `model`: the ensemble to convert.
- `args...`: positional arguments for the per-tree conversion. They are forwarded only
  when a `solemodel` method accepting them exists for the tree at hand; otherwise they
  are ignored for that tree. The `DT.Node` and `DT.Leaf` methods of this extension take
  no positional arguments, so unconditionally forwarding e.g. `solemodel(model, true)`
  used to fail with a `MethodError`.

# Keywords
- `kwargs...`: forwarded unchanged to the per-tree conversion.
"""
function SoleModels.solemodel(model::DT.Ensemble, args...; kwargs...)
    trees = map(model.trees) do t
        # `applicable` only inspects positional arguments, which is exactly the part
        # of the call that may not be supported by the per-tree method.
        converted = if applicable(SoleModels.solemodel, t, args...)
            SoleModels.solemodel(t, args...; kwargs...)
        else
            SoleModels.solemodel(t; kwargs...)
        end
        SoleModels.DecisionTree(converted)
    end
    return SoleModels.DecisionForest(trees)
end
10+
11+
"""
    SoleModels.solemodel(tree::DT.InfoNode, keep_condensed = false; use_featurenames = true, kwargs...)

Convert a `DecisionTree.jl` tree wrapped with metadata into a `DecisionTree`.

# Arguments
- `tree`: the wrapped tree. `tree.node` is the structure that gets converted;
  `tree.info.featurenames` and `tree.info.classlabels` are read from the wrapper.
- `keep_condensed`: when `false`, `tree.info.classlabels` is passed to the conversion of
  `tree.node` as `replace_classlabels`. When `true`, it is not; instead the returned
  model's info gains `apply_preprocess` (maps a label to its `UInt32` position in
  `tree.info.classlabels`) and `apply_postprocess` (maps a position back to the label).

# Keywords
- `use_featurenames`: when `true`, `tree.info.featurenames` is handed to the conversion
  of `tree.node`; when `false`, `false` is handed over instead.
- `kwargs...`: forwarded to the conversion of `tree.node`.

The returned model's info always contains `featurenames`, `supporting_predictions` and
`supporting_labels`; the last two are taken from the converted root.
"""
function SoleModels.solemodel(tree::DT.InfoNode, keep_condensed = false; use_featurenames = true, kwargs...)
    names = use_featurenames ? tree.info.featurenames : false

    if keep_condensed
        root = SoleModels.solemodel(tree.node; use_featurenames = names, kwargs...)
        # The closures read the class labels lazily, at call time.
        extra = (;
            apply_preprocess = (y -> UInt32(findfirst(x -> x == y, tree.info.classlabels))),
            apply_postprocess = (y -> tree.info.classlabels[y]),
        )
    else
        root = SoleModels.solemodel(
            tree.node;
            replace_classlabels = tree.info.classlabels,
            use_featurenames = names,
            kwargs...
        )
        extra = (;)
    end

    summary = (;
        featurenames = tree.info.featurenames,
        supporting_predictions = root.info[:supporting_predictions],
        supporting_labels = root.info[:supporting_labels],
    )
    return DecisionTree(root, merge(extra, summary))
end
38+
39+
# function SoleModels.solemodel(tree::DT.Root)
40+
# root = SoleModels.solemodel(tree.node)
41+
# # @show fieldnames(typeof(tree))
42+
# info = (;
43+
# n_feat = tree.n_feat,
44+
# featim = tree.featim,
45+
# supporting_predictions = root.info[:supporting_predictions],
46+
# supporting_labels = root.info[:supporting_labels],
47+
# )
48+
# return DecisionTree(root, info)
49+
# end
50+
51+
"""
    SoleModels.solemodel(tree::DT.Node; replace_classlabels = nothing, use_featurenames = false)

Convert an internal `DecisionTree.jl` node, and recursively both of its subtrees, into a
`Branch`.

The antecedent of the branch is the atom `variable < tree.featval`. The conversion of
`tree.left` becomes the first child of the branch and that of `tree.right` the second.

# Keywords
- `replace_classlabels`: passed unchanged to the conversion of both subtrees.
- `use_featurenames`: `false` to identify the split variable by `tree.featid`; any other
  value is indexed as `use_featurenames[tree.featid]` to obtain the variable's name.

The `info` of the returned branch holds `supporting_predictions` and
`supporting_labels`, each being the left child's entries followed by the right child's.
"""
function SoleModels.solemodel(tree::DT.Node; replace_classlabels = nothing, use_featurenames = false)
    test_operator = (<)
    variable = (use_featurenames != false) ? use_featurenames[tree.featid] : tree.featid
    antecedent = Atom(ScalarCondition(VariableValue(variable), test_operator, tree.featval))

    lefttree = SoleModels.solemodel(tree.left; replace_classlabels = replace_classlabels, use_featurenames = use_featurenames)
    righttree = SoleModels.solemodel(tree.right; replace_classlabels = replace_classlabels, use_featurenames = use_featurenames)

    # Concatenate with `vcat` rather than splatting both collections into an array
    # literal: the collections grow with the size of the subtree, and splatting
    # collections of unknown length is repeated at every level of the recursion.
    info = (;
        supporting_predictions = vcat(
            lefttree.info[:supporting_predictions],
            righttree.info[:supporting_predictions],
        ),
        supporting_labels = vcat(
            lefttree.info[:supporting_labels],
            righttree.info[:supporting_labels],
        ),
    )
    return Branch(antecedent, lefttree, righttree, info)
end
65+
66+
"""
    SoleModels.solemodel(tree::DT.Leaf; replace_classlabels = nothing, use_featurenames = false)

Convert a `DecisionTree.jl` leaf into a `SoleModels.ConstantModel`.

The model's outcome is `tree.majority`, and its `info` stores `supporting_labels`
(`tree.values`) together with `supporting_predictions` (the outcome repeated once per
supporting label).

# Keywords
- `replace_classlabels`: when not `nothing`, both `tree.majority` and `tree.values` are
  used as indices into it, and the looked-up values are stored instead.
- `use_featurenames`: not used by this method.
"""
function SoleModels.solemodel(tree::DT.Leaf; replace_classlabels = nothing, use_featurenames = false)
    if isnothing(replace_classlabels)
        outcome, support = tree.majority, tree.values
    else
        outcome = replace_classlabels[tree.majority]
        support = replace_classlabels[tree.values]
    end

    leafinfo = (;
        supporting_predictions = fill(outcome, length(support)),
        supporting_labels = support,
    )
    return SoleModels.ConstantModel(outcome, leafinfo)
end
80+
81+
end

ext/MLJDecisionTreeInterfaceExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module MLJDecisionTreeInterfaceExt
2+
3+
using SoleModels, MLJDecisionTreeInterface
4+
5+
end

ext/MLJExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module MLJExt
2+
3+
using SoleModels, MLJ
4+
5+
end

ext/MLJModelInterfaceExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module MLJModelInterfaceExt
2+
3+
using SoleModels, MLJModelInterface
4+
5+
end

test/DecisionTreeExt/forest.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Train a random forest on iris through MLJ, convert it with `solemodel`, and exercise
# the converted model (apply!, printmodel, listrules).
#
# Both the train/test split and the forest training are seeded so that every run of
# this script works on the same data and the same forest.
using Test

using MLJ
using MLJBase
using DataFrames

using MLJDecisionTreeInterface
using BenchmarkTools
using Sole

import DecisionTree as DT

X, y = @load_iris
X = DataFrame(X)

train_ratio = 0.8

# Fixed seed: the shuffled split must not change from run to run.
train, test = partition(eachindex(y), train_ratio, shuffle=true, rng=1)
X_train, y_train = X[train, :], y[train]
X_test, y_test = X[test, :], y[test]

println("Training set size: ", size(X_train), " - ", size(y_train))
println("Test set size: ", size(X_test), " - ", size(y_test))
println("Training set type: ", typeof(X_train), " - ", typeof(y_train))
println("Test set type: ", typeof(X_test), " - ", typeof(y_test))

Forest = MLJ.@load RandomForestClassifier pkg=DecisionTree

model = Forest(
    max_depth=3,
    min_samples_leaf=1,
    min_samples_split=2,
    n_trees = 10,
    # Fixed seed: the trained forest must not change from run to run.
    rng = 1,
)

# Bind the model and data into a machine
mach = machine(model, X_train, y_train)
# Fit the model
fit!(mach)


sole_forest = solemodel(fitted_params(mach).forest)

@test SoleData.scalarlogiset(X_test; allow_propositional = true) isa PropositionalLogiset

# Make test instances flow into the model
apply!(sole_forest, X_test, y_test)

# apply!(sole_forest, X_test, y_test, mode = :append)

sole_forest = @test_nowarn @btime solemodel(fitted_params(mach).forest, true)
sole_forest = @test_nowarn @btime solemodel(fitted_params(mach).forest, false)

printmodel(sole_forest; max_depth = 7, show_intermediate_finals = true, show_metrics = true)

printmodel.(listrules(sole_forest, min_lift = 1.0, min_ninstances = 0); show_metrics = true);

test/DecisionTreeExt/tree.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Train a decision tree on iris through MLJ, convert it with `solemodel`, and exercise
# the converted model (apply!, printmodel, listrules, readmetrics, joinrules).
#
# The train/test split is seeded so that every run of this script works on the same data.
using Test

using MLJ
using MLJBase
using DataFrames

using MLJDecisionTreeInterface
using BenchmarkTools
using Sole

import DecisionTree as DT

X, y = @load_iris
X = DataFrame(X)

train_ratio = 0.8

# Fixed seed: the shuffled split must not change from run to run.
train, test = partition(eachindex(y), train_ratio, shuffle=true, rng=1)
X_train, y_train = X[train, :], y[train]
X_test, y_test = X[test, :], y[test]

println("Training set size: ", size(X_train), " - ", size(y_train))
println("Test set size: ", size(X_test), " - ", size(y_test))
println("Training set type: ", typeof(X_train), " - ", typeof(y_train))
println("Test set type: ", typeof(X_test), " - ", typeof(y_test))

Tree = MLJ.@load DecisionTreeClassifier pkg=DecisionTree

model = Tree(
    max_depth=-1,
    min_samples_leaf=1,
    min_samples_split=2,
)

# Bind the model and data into a machine
mach = machine(model, X_train, y_train)
# Fit the model
fit!(mach)


sole_dt = solemodel(fitted_params(mach).tree)

@test SoleData.scalarlogiset(X_test; allow_propositional = true) isa PropositionalLogiset

# Make test instances flow into the model
apply!(sole_dt, X_test, y_test)

# apply!(sole_dt, X_test, y_test, mode = :append)

sole_dt = @test_nowarn @btime solemodel(fitted_params(mach).tree, true)
sole_dt = @test_nowarn @btime solemodel(fitted_params(mach).tree, false)

printmodel(sole_dt; max_depth = 7, show_intermediate_finals = true, show_metrics = true)

printmodel.(listrules(sole_dt, min_lift = 1.0, min_ninstances = 0); show_metrics = true);

printmodel.(listrules(sole_dt, min_lift = 1.0, min_ninstances = 0); show_metrics = true, show_subtree_metrics = true);

printmodel.(listrules(sole_dt, min_lift = 1.0, min_ninstances = 0); show_metrics = true, show_subtree_metrics= true, tree_mode=true);

readmetrics.(listrules(sole_dt; min_lift=1.0, min_ninstances = 0))

printmodel.(listrules(sole_dt, min_lift = 1.0, min_ninstances = 0); show_metrics = true);

interesting_rules = listrules(sole_dt; min_lift=1.0, min_ninstances = 0, custom_thresholding_callback = (ms)->ms.coverage*ms.ninstances >= 4)
# printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_digits = nothing, ));
printmodel.(sort(interesting_rules, by = readmetrics); show_metrics = (; round_digits = nothing, additional_metrics = (; length = r->natoms(antecedent(r)))));

@test_broken joinrules(interesting_rules) == "Check this result."

test/juliacon2024.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ end
3434
🌱 = fitted_params(mach).tree
3535

3636
# Convert to 🌞-compliant model
37-
using SoleDecisionTreeInterface
37+
import DecisionTree as DT
3838
🌲 = solemodel(🌱);
3939

4040
# Print model

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ test_suites = [
2121
("Rules", ["juliacon2024.jl", ]),
2222
("Linear forms", ["linear-form-utilities.jl", ]),
2323
("Pluto Demo", ["$(dirname(dirname(pathof(SoleModels))))/pluto-demo.jl", ]),
24+
("DecisionTreeExt", ["DecisionTreeExt/tree.jl"]) #, "DecisionTreeExt/forest.jl"])
2425
]
2526

2627
@testset "SoleModels.jl" begin

0 commit comments

Comments
 (0)