Skip to content

Commit 4d1138b

Browse files
kleinschmidt, palday, ararslan
authored
relax some checks on eltypes for contrasts and add tests (#242)
* relax some checks on eltypes for contrasts and add tests
* actually use CSV
* remove test for different eltype (wrong on 1.0)
* compat bounds for CSV
* drop 1.0, use WeakRefStrings in tests

Co-authored-by: Phillip Alday <me@phillipalday.com>
Co-authored-by: Alex Arslan <ararslan@comcast.net>
Co-authored-by: Phillip Alday <palday@users.noreply.github.com>
1 parent 61de82a commit 4d1138b

File tree

3 files changed

+32
-9
lines changed

3 files changed

+32
-9
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,13 +23,15 @@ ShiftedArrays = "1"
2323
StatsBase = "0.33.5"
2424
StatsFuns = "0.9"
2525
Tables = "0.2, 1"
26+
WeakRefStrings = "1"
2627
julia = "1.6"
2728

2829
[extras]
2930
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
3031
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
3132
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3233
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
34+
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
3335

3436
[targets]
35-
test = ["CategoricalArrays", "DataFrames", "Statistics", "Test"]
37+
test = ["CategoricalArrays", "DataFrames", "Statistics", "Test", "WeakRefStrings"]

src/contrasts.jl

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ end
120120
# only check equality of matrix, termnames, and levels, and that the type is the
121121
# same for the contrasts (values are irrelevant). This ensures that the two
122122
# will behave identically in creating modelmatrix columns
123-
Base.:(==)(a::ContrastsMatrix{C,T}, b::ContrastsMatrix{C,T}) where {C<:AbstractContrasts,T} =
123+
Base.:(==)(a::ContrastsMatrix{C}, b::ContrastsMatrix{C}) where {C<:AbstractContrasts} =
124124
a.matrix == b.matrix &&
125125
a.termnames == b.termnames &&
126126
a.levels == b.levels
@@ -166,18 +166,19 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst
166166
# 3. contrast levels missing from data: would have empty columns, generate a
167167
# rank-deficient model matrix.
168168
c_levels = something(DataAPI.levels(contrasts), levels)
169-
if eltype(c_levels) != eltype(levels)
170-
throw(ArgumentError("mismatching levels types: got $(eltype(levels)), expected " *
171-
"$(eltype(c_levels)) based on contrasts levels."))
172-
end
169+
173170
mismatched_levels = symdiff(c_levels, levels)
174171
if !isempty(mismatched_levels)
175172
throw(ArgumentError("contrasts levels not found in data or vice-versa: " *
176173
"$mismatched_levels." *
177-
"\n Data levels: $levels." *
178-
"\n Contrast levels: $c_levels"))
174+
"\n Data levels ($(eltype(levels))): $levels." *
175+
"\n Contrast levels ($(eltype(c_levels))): $c_levels"))
179176
end
180177

178+
# do conversion AFTER checking for levels so users get a nice error message
179+
# when they've made a mistake with the level types
180+
c_levels = convert(Vector{T}, c_levels)
181+
181182
n = length(c_levels)
182183
if n == 0
183184
throw(ArgumentError("empty set of levels found (need at least two to compute " *
@@ -187,7 +188,7 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst
187188
"compute contrasts)."))
188189
end
189190

190-
# find index of base level. use contrasts.base, then default (1).
191+
# find index of base level. use baselevel(contrasts), then default (1).
191192
base_level = baselevel(contrasts)
192193
baseind = base_level === nothing ?
193194
1 :

test/contrasts.jl

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -355,5 +355,25 @@
355355
@testset "Non-unique levels" begin
356356
@test_throws ArgumentError ContrastsMatrix(DummyCoding(), ["a", "a", "b"])
357357
end
358+
359+
@testset "other string types" begin
360+
using WeakRefStrings
361+
362+
using StatsModels: ContrastsMatrix
363+
using DataAPI: levels
364+
365+
x = ["a", "b", "c", "a", "a", "b"]
366+
x1 = WeakRefStrings.String1.(x)
367+
x1_levs = levels(x1)
368+
369+
@test issetequal(x, x1)
370+
371+
c1 = ContrastsMatrix(DummyCoding(), x1_levs)
372+
c = ContrastsMatrix(DummyCoding(levels=["a", "b", "c"]), x1_levs)
373+
@test c == c1
374+
@test eltype(c.levels) == eltype(c1.levels)
375+
376+
@test_throws ArgumentError ContrastsMatrix(DummyCoding(levels=[1, 2, 3]), x1_levs)
377+
end
358378

359379
end

0 commit comments

Comments
 (0)