Skip to content

Commit b808555

Browse files
committed
ensembletrees listrules
1 parent c0e6e56 commit b808555

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ Graphs = "1.8"
5151
HTTP = "1.9"
5252
IterTools = "1"
5353
Lazy = "0.15.1"
54-
MLJ = "0.20"
55-
MLJBase = "1.6"
54+
MLJ = "0.19 - 0.20"
55+
MLJBase = "1.6 - 1.7"
5656
MLJDecisionTreeInterface = "0.4"
57-
MLJModelInterface = "1.8.0"
57+
MLJModelInterface = "1.8"
5858
PrettyTables = "2.2"
5959
ProgressMeter = "1"
6060
Random = "1"

src/rule-extraction.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -373,20 +373,22 @@ _listrules(m::DecisionTree; kwargs...) = _listrules(root(m); kwargs...)
373373

374374
function _listrules(
375375
m::DecisionEnsemble;
376-
suppress_parity_warning = false,
376+
# weights::Union{Nothing, AbstractVector} = nothing,
377+
suppress_parity_warning = true,
377378
kwargs...
378379
)
379-
error("TODO check method & implement more efficient strategies for specific cases.")
380-
modelrules = [listrules(subm; kwargs...) for subm in models(m)]
380+
# error("TODO check method & implement more efficient strategies for specific cases.")
381+
modelrules = [_listrules(subm; kwargs...) for subm in models(m)]
381382
@assert all(r->consequent(r) isa ConstantModel, Iterators.flatten(modelrules))
382383

383384
IterTools.imap(rulecombination->begin
384385
rulecombination = collect(rulecombination)
385386
ant = join_antecedents(antecedent.(rulecombination))
386-
cons = bestguess(outcome.(consequent.(rulecombination)); suppress_parity_warning)
387-
infos = info.(rulecombination)
388-
# TODO @show infos; info = (;)
389-
Rule(ant, cons)
387+
o_cons = bestguess(outcome.(consequent.(rulecombination)), m.weights; suppress_parity_warning)
388+
i_cons = merge(info.(consequent.(rulecombination))...)
389+
cons = ConstantModel(o_cons, i_cons)
390+
infos = merge(info.(rulecombination)...)
391+
Rule(ant, cons, infos)
390392
end, Iterators.product(modelrules...)
391393
)
392394
end

0 commit comments

Comments
 (0)