Skip to content

Commit ce2b36f

Browse files
authored
add get_basis_names for massunivariate models (#208)
1 parent 875b04f commit ce2b36f

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

src/condense.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,28 @@ function get_basis_colnames(rhs::AbstractTerm)
175175
return colnames(rhs.basisfunction)
176176
end
177177

178+
"""
179+
get_basisnames(model::UnfoldModel)
180+
181+
Return the basisnames for all predictor terms as a vector.
182+
183+
The returned vector contains the name of the event type/basis, repeated by their actual coefficient number (after StatsModels.apply_schema / timeexpansion).
184+
If a model has more than one event type (e.g. stimulus and fixation), the vectors are concatenated.
185+
"""
186+
@traitfn function get_basis_names(m::T) where {T <: UnfoldModel; !ContinuousTimeTrait{T}}
187+
188+
# Extract the event names from the design
189+
design_keys = first.((Unfold.design(m)))
190+
191+
# Create a list of the basis names corresponding to each model term
192+
basisnames = String[]
193+
for (ix, event) in enumerate(design_keys)
194+
push!(basisnames, repeat([event], size(modelmatrix(m)[ix], 2))...)
195+
end
196+
return basisnames
197+
end
198+
199+
178200
@traitfn get_basis_names(m::T) where {T <: UnfoldModel; ContinuousTimeTrait{T}} =
179201
get_basis_names.(formulas(m))
180202
function get_basis_names(m::FormulaTerm)

0 commit comments

Comments
 (0)