Skip to content

Commit 93800fd

Browse files
committed
Merge branch 'dev' into gio/refractoring
2 parents 86cb6b7 + b808555 commit 93800fd

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ Graphs = "1.8"
5252
HTTP = "1.9"
5353
IterTools = "1"
5454
Lazy = "0.15.1"
55-
MLJ = "0.20"
56-
MLJBase = "1.6"
55+
MLJ = "0.19 - 0.20"
56+
MLJBase = "1.6 - 1.7"
5757
MLJDecisionTreeInterface = "0.4"
58-
MLJModelInterface = "1.8.0"
58+
MLJModelInterface = "1.8"
5959
PrettyTables = "2.2"
6060
ProgressMeter = "1"
6161
Random = "1"

src/rule-extraction.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -386,20 +386,22 @@ _listrules(m::DecisionTree; kwargs...) = _listrules(root(m); kwargs...)
386386

387387
function _listrules(
388388
m::DecisionEnsemble;
389-
suppress_parity_warning = false,
389+
# weights::Union{Nothing, AbstractVector} = nothing,
390+
suppress_parity_warning = true,
390391
kwargs...
391392
)
392-
error("TODO check method & implement more efficient strategies for specific cases.")
393-
modelrules = [listrules(subm; kwargs...) for subm in models(m)]
393+
# error("TODO check method & implement more efficient strategies for specific cases.")
394+
modelrules = [_listrules(subm; kwargs...) for subm in models(m)]
394395
@assert all(r->consequent(r) isa ConstantModel, Iterators.flatten(modelrules))
395396

396397
IterTools.imap(rulecombination->begin
397398
rulecombination = collect(rulecombination)
398399
ant = join_antecedents(antecedent.(rulecombination))
399-
cons = bestguess(outcome.(consequent.(rulecombination)); suppress_parity_warning)
400-
infos = info.(rulecombination)
401-
# TODO @show infos; info = (;)
402-
Rule(ant, cons)
400+
o_cons = bestguess(outcome.(consequent.(rulecombination)), m.weights; suppress_parity_warning)
401+
i_cons = merge(info.(consequent.(rulecombination))...)
402+
cons = ConstantModel(o_cons, i_cons)
403+
infos = merge(info.(rulecombination)...)
404+
Rule(ant, cons, infos)
403405
end, Iterators.product(modelrules...)
404406
)
405407
end

0 commit comments

Comments
 (0)