Skip to content

Commit c0e6e56

Browse files
committed
Add haslistrules isensemble
1 parent 0728ba9 commit c0e6e56

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

src/SoleModels.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ export DecisionForest, trees
6262

6363
export MixedModel
6464

65-
export solemodel
65+
export haslistrules, solemodel
66+
67+
export isensemble
6668

6769
include("types/model.jl")
6870
include("types/AbstractTrees.jl")

src/rule-extraction.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,13 @@ See also [`listrules`](@ref), [`Rule`](@ref)], [`issymbolicmodel`](@ref).
4242
"""
4343
struct PlainRuleExtractor <: RuleExtractor end
4444
isexact(::PlainRuleExtractor) = true
45-
extractrules(::PlainRuleExtractor, m, args...; kwargs...) = listrules(m, args...; kwargs...)
46-
45+
function extractrules(::PlainRuleExtractor, m, args...; kwargs...)
46+
if haslistrules(m)
47+
listrules(m, args...; kwargs...)
48+
else
49+
listrules(solemodel(m), args...; kwargs...)
50+
end
51+
end
4752

4853
############################################################################################
4954
############################################################################################

src/types/api.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1+
# TODO document, together with issymbolic and listrules
2+
"""
3+
haslistrules(m::Any)
4+
5+
This function extracts symbolic final rules from a symbolic model..
6+
7+
See also [`AbstractModel`](@ref), [`listrules`](@ref)
8+
[`LeafModel`](@ref).
9+
"""
10+
haslistrules(m) = false
11+
haslistrules(m::AbstractModel) = true
112

213
# TODO document, together with issymbolic and listrules
314
"""
415
solemodel(m::Any)
516
617
This function translates a symbolic model to a symbolic model using the structures defined in SoleModel.jl.
7-
# Interface
818
919
See also [`AbstractModel`](@ref), [`ConstantModel`](@ref), [`FunctionModel`](@ref),
1020
[`LeafModel`](@ref).
@@ -58,3 +68,7 @@ natoms(m::AbstractModel) = error("Please, provide method natoms(::$(typeof(m))).
5868
nconnectives(m::AbstractModel) = error("Please, provide method nconnectives(::$(typeof(m))).")
5969
"""$doc_syntax_utils_models"""
6070
nsyntaxleaves(m::AbstractModel) = error("Please, provide method nsyntaxleaves(::$(typeof(m))).")
71+
72+
73+
# TODO
74+
isensemble(m) = false

src/utils/models/ensembles.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ struct DecisionEnsemble{O,T<:AbstractModel,A<:Base.Callable,W<:Union{Nothing,Abs
9898

9999
end
100100

101+
102+
isensemble(m::DecisionEnsemble) = true
103+
101104
modelstype(m::DecisionEnsemble{O,T}) where {O,T} = T
102105
models(m::DecisionEnsemble) = m.models
103106
nmodels(m::DecisionEnsemble) = length(models(m))
@@ -314,6 +317,8 @@ struct MaxDecisionBag{O,TO<:AbstractModel,TU<:AbstractModel
314317
end
315318
end
316319

320+
isensemble(m::MaxDecisionBag) = true
321+
317322
function apply(m::MaxDecisionBag, d::AbstractInterpretation; suppress_parity_warning = false, kwargs...)
318323
weights = [apply(wm, d; suppress_parity_warning, kwargs...) for wm in m.weight_producing_models]
319324
om = m.output_producing_models[argmax(weights)]

0 commit comments

Comments
 (0)