Skip to content

Commit c333d5c

Browse files
committed
Add explanations to evaluaterule. Add print for DecisionSet
1 parent a09d5ba commit c333d5c

File tree

3 files changed

+81
-2
lines changed

3 files changed

+81
-2
lines changed

src/evaluate.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,18 +316,34 @@ function evaluaterule(
316316
rule::Rule{L},
317317
X::AbstractInterpretationSet,
318318
Y::AbstractVector{<:Label};
319+
compute_explanations = false,
319320
kwargs...,
320321
) where {L<:CLabel}
321322
classmask = (Y .== outcome(consequent(rule)))
322-
checkmask = checkantecedent(rule, X; kwargs...)
323+
checkmask, explanations = begin
324+
if compute_explanations
325+
disjs = SoleLogics.disjuncts(SoleLogics.LeftmostDisjunctiveForm(antecedent(rule)))
326+
checkmatrix = hcat([check(disj, X; kwargs...) for disj in disjs]...)
327+
# @show checkmatrix
328+
checkmask = map(any, eachrow(checkmatrix))
329+
explanations = [disjs[checkrow] for checkrow in eachrow(checkmatrix)]
330+
checkmask, explanations
331+
else
332+
checkmask = checkantecedent(rule, X; kwargs...)
333+
explanations = nothing
334+
checkmask, explanations
335+
end
336+
end
323337
class_checkmask = checkmask[classmask]
324338
anticlass_checkmask = checkmask[(!).(classmask)]
325-
return (;
339+
out = (;
326340
classmask = classmask,
327341
checkmask = checkmask,
328342
sensitivity = sum(class_checkmask)/length(class_checkmask),
329343
specificity = 1-(sum(anticlass_checkmask)/length(anticlass_checkmask)),
344+
explanations = explanations,
330345
)
346+
return out
331347
end
332348

333349
# TODO: if delays not in info(m) ?

src/print.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,68 @@ function printmodel(
457457
nothing
458458
end
459459

460+
function printmodel(
461+
io::IO,
462+
m::DecisionSet;
463+
header = DEFAULT_HEADER,
464+
indentation_str = "",
465+
indentation = default_indentation,
466+
depth = 0,
467+
max_depth = nothing,
468+
show_subtree_info = false,
469+
show_rule_metrics = true,
470+
show_subtree_metrics = false,
471+
show_metrics = false,
472+
show_shortforms = false,
473+
show_intermediate_finals = false,
474+
tree_mode = false,
475+
show_symbols = true,
476+
syntaxstring_kwargs = (;),
477+
#
478+
parenthesize_atoms = true,
479+
kwargs...,
480+
)
481+
(
482+
indentation_list_children,
483+
indentation_hspace,
484+
indentation_any_first,
485+
indentation_any_space,
486+
indentation_last_first,
487+
indentation_last_space
488+
) = indentation
489+
if header != false
490+
_typestr = string(header == true ? typeof(m) :
491+
header == :brief ? nameof(typeof(m)) :
492+
error("Unexpected value for parameter header: $(header).")
493+
)
494+
println(io, "$(indentation_str)$(_typestr)$((length(info(m)) == 0) ?
495+
"" : "\n$(indentation_str)Info: $(info(m))")")
496+
end
497+
depth == 0 && show_symbols && print(io, MODEL_SYMBOL)
498+
########################################################################################
499+
_show_rule_metrics = show_rule_metrics
500+
# TODO show this metrics if show_metrics
501+
########################################################################################
502+
if isnothing(max_depth) || depth < max_depth
503+
println(io, "$(indentation_list_children)")
504+
for (i_rule, rule) in enumerate(rules(m))
505+
# pipe = indentation_any_first
506+
pipe = (i_rule != nrules(m) ? indentation_any_first : indentation_last_first)*"[$(i_rule)/$(nrules(m))]"
507+
# println(io, "$(indentation_str*pipe)$(syntaxstring(antecedent(rule); (haskey(info(rule), :syntaxstring_kwargs) ? info(rule).syntaxstring_kwargs : (;))..., syntaxstring_kwargs..., parenthesize_atoms = parenthesize_atoms, kwargs...))")
508+
pad_str = indentation_str*indentation_any_space*repeat(indentation_hspace, length(pipe)-length(indentation_any_space)-1)
509+
# print(io, "$(pad_str*indentation_last_first)")
510+
ind_str = pad_str*indentation_last_space
511+
# @_print_submodel io consequent(rule) ind_str indentation depth max_depth show_subtree_info false show_subtree_metrics show_shortforms show_intermediate_finals tree_mode show_symbols syntaxstring_kwargs parenthesize_atoms kwargs
512+
print(io, pipe)
513+
@_print_submodel io rule ind_str indentation depth max_depth show_subtree_info _show_rule_metrics show_subtree_metrics show_shortforms show_intermediate_finals tree_mode show_symbols syntaxstring_kwargs parenthesize_atoms kwargs
514+
end
515+
else
516+
depth != 0 && print(io, " ")
517+
println(io, "[...]")
518+
end
519+
nothing
520+
end
521+
460522
function printmodel(
461523
io::IO,
462524
m::DecisionEnsemble;

src/utils/models/other.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ struct DecisionSet{O} <: AbstractModel{O}
336336
end
337337

338338
rules(m::DecisionSet) = m.rules
339+
nrules(m::DecisionSet) = length(rules(m))
339340

340341
iscomplete(m::DecisionSet) = m.iscomplete
341342
isnonoverlapping(m::DecisionSet) = m.isnonoverlapping

0 commit comments

Comments
 (0)