Skip to content

Commit cbdaac4

Browse files
paldaykleinschmidt
andauthored
support StatsModels 0.7 (#13)
* support StatsModels 0.7 * detect InteractionTerms and throw error * update doc * tupleterm * unused imports * add a few more tests * add test * mark as broken * brokener --------- Co-authored-by: Dave Kleinschmidt <dave.f.kleinschmidt@gmail.com>
1 parent 342b034 commit cbdaac4

File tree

6 files changed

+45
-26
lines changed

6 files changed

+45
-26
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RegressionFormulae"
22
uuid = "545c379f-4ec2-4339-9aea-38f2fb6a8ba2"
33
authors = ["Dave Kleinschmidt", "Phillip Alday"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -10,7 +10,7 @@ StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
1010
[compat]
1111
Combinatorics = "1"
1212
StatsBase = "0.33"
13-
StatsModels = "0.6.7"
13+
StatsModels = "0.7"
1414
julia = "1.6"
1515

1616
[extras]

src/RegressionFormulae.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ using StatsModels
44
using Combinatorics
55
using Base.Iterators
66

7-
using StatsModels: apply_schema
7+
using StatsModels: apply_schema, TupleTerm
88

9-
const TermTuple = NTuple{N, AbstractTerm} where N
109
const Schemas = Union{StatsModels.Schema, StatsModels.FullRank}
1110

1211
include("fulldummy.jl")

src/fulldummy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function StatsModels.apply_schema(
88
sch::StatsModels.FullRank,
99
Mod::Type{<:RegressionModel},
1010
)
11-
fulldummy(apply_schema.(t.args_parsed, Ref(sch), Mod)...)
11+
fulldummy(apply_schema.(t.args, Ref(sch), Mod)...)
1212
end
1313

1414
function fulldummy(t::CategoricalTerm)

src/nesting.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ function Base.:(/)(outer::CategoricalTerm, inner::AbstractTerm)
2121
return outer + fulldummy(outer) & inner
2222
end
2323

24-
function Base.:(/)(outer::CategoricalTerm, inner::TermTuple)
24+
function Base.:(/)(outer::CategoricalTerm, inner::TupleTerm)
2525
fd = fulldummy(outer)
2626
return mapfoldl(x -> fd & x, +, inner; init=outer)
2727
end
2828

29-
function Base.:(/)(outer::TermTuple, inner::Union{AbstractTerm, TermTuple})
29+
function Base.:(/)(outer::TupleTerm, inner::Union{AbstractTerm, TupleTerm})
3030
return outer[1:end-1] + last(outer) / inner
3131
end
3232

@@ -37,7 +37,7 @@ function Base.:(/)(outer::InteractionTerm, inner::AbstractTerm)
3737
return outer + outer & inner
3838
end
3939

40-
function Base.:(/)(outer::InteractionTerm, inner::TermTuple)
40+
function Base.:(/)(outer::InteractionTerm, inner::TupleTerm)
4141
# we should only get here via expansion where the interaction term,
4242
# but who knows what devious things users will try
4343
_fulldummycheck(outer)
@@ -54,13 +54,12 @@ function StatsModels.apply_schema(
5454
sch::StatsModels.FullRank,
5555
Mod::Type{<:RegressionModel},
5656
)
57-
length(t.args_parsed) == 2 ||
57+
length(t.args) == 2 ||
5858
throw(ArgumentError("malformed nesting term: $t (Exactly two arguments required)"))
5959

60-
any(x -> isa(x, ConstantTerm), t.args_parsed) && return t
60+
any(x -> isa(x, ConstantTerm), t.args) && return t
6161

62-
args = apply_schema.(t.args_parsed, Ref(sch), Mod)
62+
args = apply_schema.(t.args, Ref(sch), Mod)
6363

6464
return first(args) / last(args)
6565
end
66-

src/power.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ combinations_upto(x, n) = Iterators.flatten(combinations(x, i) for i in 1:n)
66
Generate all interactions of terms up to order ``n``.
77
88
!!! warning
9-
If any term is an `InteractionTerm`, then nonsensical interactions may
10-
arise, e.g. `a & a & b`.
9+
Embedded `InteractionTerms` (i.e. `(a + b + c & d)^2`) are not currently
10+
supported and result in an error.
1111
"""
12-
function Base.:(^)(args::TermTuple, deg::ConstantTerm{<:Integer})
12+
function Base.:(^)(args::TupleTerm, deg::ConstantTerm{<:Integer})
1313
deg.n > 0 || throw(ArgumentError("power should be greater than zero (got $deg)"))
14+
any(t isa InteractionTerm for t in args) &&
15+
throw(ArgumentError("powers of interaction terms not currently supported"))
1416
tuple(((&)(terms...) for terms in combinations_upto(args, deg.n))...)
1517
end
1618

17-
function Base.:(^)(::TermTuple, deg::AbstractTerm)
19+
function Base.:(^)(::TupleTerm, deg::AbstractTerm)
1820
throw(ArgumentError("power should be an integer constant (got $deg)"))
1921
end
2022

@@ -23,10 +25,12 @@ function StatsModels.apply_schema(
2325
sch::StatsModels.FullRank,
2426
ctx::Type{<:RegressionModel}
2527
)
26-
length(t.args_parsed) == 2 ||
28+
length(t.args) == 2 ||
2729
throw(ArgumentError("invalid term $t: should have exactly two arguments"))
28-
first, second = t.args_parsed
29-
second isa ConstantTerm{<:Integer} ||
30-
throw(ArgumentError("invalid term $t: power should be an integer (got $second)"))
31-
apply_schema.(first^second, Ref(sch), ctx)
30+
first, second = t.args
31+
32+
base = apply_schema(first, sch, ctx)
33+
return base^second
3234
end
35+
36+
# StatsModels

test/power.jl

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ include("dummymod.jl")
66

77
dat = (; y=zeros(3), a=1:3, b=11:13, c=21:23, d=31:33, e=["u", "i", "o"])
88

9-
109
@testset "error checking" begin
1110
@test_throws ArgumentError (term(:b),) ^ term(:a)
1211
@test_throws ArgumentError (term(:b),) ^ term(2.5)
@@ -31,11 +30,29 @@ end
3130
m = fit(DummyMod, @formula(y ~ (a + b + e)^2), dat)
3231
@test coefnames(m) == ["(Intercept)", "a", "b", "e: o", "e: u",
3332
"a & b", "a & e: o", "a & e: u", "b & e: o", "b & e: u"]
33+
34+
# make sure inner function terms work
35+
m = fit(DummyMod, @formula(y ~ (a + b + log(a + b))^3), dat)
36+
@test coefnames(m) == ["(Intercept)", "a", "b", "log(a + b)",
37+
"a & b", "a & log(a + b)", "b & log(a + b)",
38+
"a & b & log(a + b)"]
39+
40+
# cursed but should technically work
41+
m = fit(DummyMod, @formula(y ~ (a + b + (c + d)^1)^2), dat)
42+
@test coefnames(m) == ["(Intercept)", "a", "b", "c", "d",
43+
"a & b", "a & c", "a & d",
44+
"b & c", "b & d", "c & d"]
45+
# not actually an InteractionTerm even if it's mathematically equivalent for
46+
# ContinuousTerms
47+
# throws an error and is broken
48+
# https://github.com/JuliaStats/StatsModels.jl/issues/290
49+
# m = fit(DummyMod, @formula(y ~ (a + protect(c * d))^2), dat)
50+
# @test coefnames(m) == ["(Intercept)", "a", "c * d", "a & c *d "]
51+
# to remind us :this: is broken
52+
@test_broken false
3453
end
3554

3655
@testset "embedded interactions" begin
37-
m = fit(DummyMod, @formula(y ~ (a + b + c * d)^3), dat)
38-
cn = coefnames(m)
39-
@test_broken !("a & c & c & d" in cn)
40-
@test_broken cn == unique(cn)
56+
@test_throws ArgumentError fit(DummyMod, @formula(y ~ (a + b + c & d)^3), dat)
57+
@test_throws ArgumentError fit(DummyMod, @formula(y ~ (a + b + c * d)^3), dat)
4158
end

0 commit comments

Comments
 (0)