Skip to content

Commit d60fb10

Browse files
authored
Fix implementation of Tables row interface (#279)
* Create valid rows iterator
* Test that matrices via rows and cols are the same
* Remove now-unnecessary optional overloads
* Use egal for maybe better const prop
* Better denote sections with comments
* Increment version number
* Organize tests into sections
* Add row interface tests
1 parent 756b7ff commit d60fb10

File tree

3 files changed

+167
-103
lines changed

3 files changed

+167
-103
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "Chain types and utility functions for MCMC simulations."
6-
version = "4.7.1"
6+
version = "4.7.2"
77

88
[deps]
99
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/tables.jl

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Tables and TableTraits interface
22

3-
## Chains
3+
####
4+
#### Chains
5+
####
46

57
function _check_columnnames(chn::Chains)
68
for name in names(chn)
@@ -11,8 +13,12 @@ function _check_columnnames(chn::Chains)
1113
end
1214
end
1315

16+
#### Tables interface
17+
1418
Tables.istable(::Type{<:Chains}) = true
1519

20+
# AbstractColumns interface
21+
1622
Tables.columnaccess(::Type{<:Chains}) = true
1723

1824
function Tables.columns(chn::Chains)
@@ -26,11 +32,11 @@ function Tables.getcolumn(chn::Chains, i::Int)
2632
return Tables.getcolumn(chn, Tables.columnnames(chn)[i])
2733
end
2834
function Tables.getcolumn(chn::Chains, nm::Symbol)
29-
if nm == :iteration
35+
if nm === :iteration
3036
iterations = range(chn)
3137
nchains = size(chn, 3)
3238
return repeat(iterations, nchains)
33-
elseif nm == :chain
39+
elseif nm === :chain
3440
chainids = chains(chn)
3541
niter = size(chn, 1)
3642
return repeat(chainids; inner = niter)
@@ -39,18 +45,13 @@ function Tables.getcolumn(chn::Chains, nm::Symbol)
3945
end
4046
end
4147

42-
Tables.rowaccess(::Type{<:Chains}) = true
48+
# row access
4349

44-
function Tables.rows(chn::Chains)
45-
_check_columnnames(chn)
46-
return chn
47-
end
50+
Tables.rowaccess(::Type{<:Chains}) = true
4851

49-
Tables.rowtable(chn::Chains) = Tables.rowtable(Tables.columntable(chn))
52+
Tables.rows(chn::Chains) = Tables.rows(Tables.columntable(chn))
5053

51-
function Tables.namedtupleiterator(chn::Chains)
52-
return Tables.namedtupleiterator(Tables.columntable(chn))
53-
end
54+
# optional Tables overloads
5455

5556
function Tables.schema(chn::Chains)
5657
_check_columnnames(chn)
@@ -60,17 +61,25 @@ function Tables.schema(chn::Chains)
6061
return Tables.Schema(nms, types)
6162
end
6263

64+
#### TableTraits interface
65+
6366
IteratorInterfaceExtensions.isiterable(::Chains) = true
6467
function IteratorInterfaceExtensions.getiterator(chn::Chains)
6568
return Tables.datavaluerows(Tables.columntable(chn))
6669
end
6770

6871
TableTraits.isiterabletable(::Chains) = true
6972

70-
## ChainDataFrame
73+
####
74+
#### ChainDataFrame
75+
####
76+
77+
#### Tables interface
7178

7279
Tables.istable(::Type{<:ChainDataFrame}) = true
7380

81+
# AbstractColumns interface
82+
7483
Tables.columnaccess(::Type{<:ChainDataFrame}) = true
7584

7685
Tables.columns(cdf::ChainDataFrame) = cdf
@@ -80,21 +89,19 @@ Tables.columnnames(::ChainDataFrame{<:NamedTuple{names}}) where {names} = names
8089
Tables.getcolumn(cdf::ChainDataFrame, i::Int) = cdf.nt[i]
8190
Tables.getcolumn(cdf::ChainDataFrame, nm::Symbol) = cdf.nt[nm]
8291

83-
Tables.rowaccess(::Type{<:ChainDataFrame}) = true
84-
85-
Tables.rows(cdf::ChainDataFrame) = cdf
92+
# row access
8693

87-
Tables.rowtable(cdf::ChainDataFrame) = Tables.rowtable(Tables.columntable(cdf))
94+
Tables.rowaccess(::Type{<:ChainDataFrame}) = true
8895

89-
function Tables.namedtupleiterator(cdf::ChainDataFrame)
90-
return Tables.namedtupleiterator(Tables.columntable(cdf))
91-
end
96+
Tables.rows(cdf::ChainDataFrame) = Tables.rows(Tables.columntable(cdf))
9297

9398
function Tables.schema(::ChainDataFrame{NamedTuple{names,T}}) where {names,T}
9499
types = ntuple(i -> eltype(fieldtype(T, i)), fieldcount(T))
95100
return Tables.Schema(names, types)
96101
end
97102

103+
#### TableTraits interface
104+
98105
IteratorInterfaceExtensions.isiterable(::ChainDataFrame) = true
99106
function IteratorInterfaceExtensions.getiterator(cdf::ChainDataFrame)
100107
return Tables.datavaluerows(Tables.columntable(cdf))

test/tables_tests.jl

Lines changed: 139 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -13,63 +13,95 @@ using DataFrames
1313

1414
@testset "Tables interface" begin
1515
@test Tables.istable(typeof(chn))
16-
@test Tables.columnaccess(typeof(chn))
17-
@test Tables.columns(chn) === chn
18-
@test Tables.columnnames(chn) ==
19-
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
20-
@test Tables.getcolumn(chn, :iteration) == [1:1000; 1:1000; 1:1000; 1:1000]
21-
@test Tables.getcolumn(chn, :chain) ==
22-
[fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)]
23-
@test Tables.getcolumn(chn, :a) == [
24-
vec(chn[:, :a, 1])
25-
vec(chn[:, :a, 2])
26-
vec(chn[:, :a, 3])
27-
vec(chn[:, :a, 4])
28-
]
29-
@test_throws Exception Tables.getcolumn(chn, :j)
30-
@test Tables.getcolumn(chn, 1) == Tables.getcolumn(chn, :iteration)
31-
@test Tables.getcolumn(chn, 2) == Tables.getcolumn(chn, :chain)
32-
@test Tables.getcolumn(chn, 3) == Tables.getcolumn(chn, :a)
33-
@test_throws Exception Tables.getcolumn(chn, :i)
34-
@test_throws Exception Tables.getcolumn(chn, 11)
35-
@test Tables.rowaccess(typeof(chn))
36-
@test Tables.rows(chn) === chn
37-
@test length(Tables.rowtable(chn)) == 4000
38-
nt = Tables.rowtable(chn)[1]
39-
@test nt ==
40-
(; (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))...)
41-
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 1))[1]
42-
nt = Tables.rowtable(chn)[2]
43-
@test nt ==
44-
(; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...)
45-
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 2))[2]
46-
@test Tables.schema(chn) isa Tables.Schema
47-
@test Tables.schema(chn).names ===
48-
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
49-
@test Tables.schema(chn).types === (
50-
Int,
51-
Int,
52-
Float64,
53-
Float64,
54-
Float64,
55-
Float64,
56-
Float64,
57-
Float64,
58-
Float64,
59-
Float64,
60-
)
61-
@test Tables.matrix(chn[:, :, 1])[:, 3:end] ≈ chn[:, :, 1].value
62-
@test Tables.matrix(chn[:, :, 2])[:, 3:end] ≈ chn[:, :, 2].value
63-
64-
val = rand(1000, 2, 4)
65-
chn2 = Chains(val, ["iteration", "a"])
66-
@test_throws Exception Tables.columns(chn2)
67-
@test_throws Exception Tables.rows(chn2)
68-
@test_throws Exception Tables.schema(chn2)
69-
chn3 = Chains(val, ["chain", "a"])
70-
@test_throws Exception Tables.columns(chn3)
71-
@test_throws Exception Tables.rows(chn3)
72-
@test_throws Exception Tables.schema(chn3)
16+
17+
@testset "column access" begin
18+
@test Tables.columnaccess(typeof(chn))
19+
@test Tables.columns(chn) === chn
20+
@test Tables.columnnames(chn) ==
21+
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
22+
@test Tables.getcolumn(chn, :iteration) == [1:1000; 1:1000; 1:1000; 1:1000]
23+
@test Tables.getcolumn(chn, :chain) ==
24+
[fill(1, 1000); fill(2, 1000); fill(3, 1000); fill(4, 1000)]
25+
@test Tables.getcolumn(chn, :a) == [
26+
vec(chn[:, :a, 1])
27+
vec(chn[:, :a, 2])
28+
vec(chn[:, :a, 3])
29+
vec(chn[:, :a, 4])
30+
]
31+
@test_throws Exception Tables.getcolumn(chn, :j)
32+
@test Tables.getcolumn(chn, 1) == Tables.getcolumn(chn, :iteration)
33+
@test Tables.getcolumn(chn, 2) == Tables.getcolumn(chn, :chain)
34+
@test Tables.getcolumn(chn, 3) == Tables.getcolumn(chn, :a)
35+
@test_throws Exception Tables.getcolumn(chn, :i)
36+
@test_throws Exception Tables.getcolumn(chn, 11)
37+
end
38+
39+
@testset "row access" begin
40+
@test Tables.rowaccess(typeof(chn))
41+
@test Tables.rows(chn) isa Tables.RowIterator
42+
@test eltype(Tables.rows(chn)) <: Tables.AbstractRow
43+
rows = collect(Tables.rows(chn))
44+
@test eltype(rows) <: Tables.AbstractRow
45+
@test size(rows) === (4000,)
46+
for chainid in 1:4, iterid in 1:1000
47+
row = rows[(chainid - 1) * 1000 + iterid]
48+
@test Tables.columnnames(row) ==
49+
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
50+
@test Tables.getcolumn(row, 1) == iterid
51+
@test Tables.getcolumn(row, 2) == chainid
52+
@test Tables.getcolumn(row, 3) == chn[iterid, :a, chainid]
53+
@test Tables.getcolumn(row, 10) == chn[iterid, :h, chainid]
54+
@test Tables.getcolumn(row, :iteration) == iterid
55+
@test Tables.getcolumn(row, :chain) == chainid
56+
@test Tables.getcolumn(row, :a) == chn[iterid, :a, chainid]
57+
@test Tables.getcolumn(row, :h) == chn[iterid, :h, chainid]
58+
end
59+
end
60+
61+
@testset "integration tests" begin
62+
@test length(Tables.rowtable(chn)) == 4000
63+
nt = Tables.rowtable(chn)[1]
64+
@test nt ==
65+
(; (k => Tables.getcolumn(chn, k)[1] for k in Tables.columnnames(chn))...)
66+
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 1))[1]
67+
nt = Tables.rowtable(chn)[2]
68+
@test nt ==
69+
(; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...)
70+
@test nt == collect(Iterators.take(Tables.namedtupleiterator(chn), 2))[2]
71+
@test Tables.matrix(chn[:, :, 1])[:, 3:end] ≈ chn[:, :, 1].value
72+
@test Tables.matrix(chn[:, :, 2])[:, 3:end] ≈ chn[:, :, 2].value
73+
@test Tables.matrix(Tables.rowtable(chn)) == Tables.matrix(Tables.columntable(chn))
74+
end
75+
76+
@testset "schema" begin
77+
@test Tables.schema(chn) isa Tables.Schema
78+
@test Tables.schema(chn).names ===
79+
(:iteration, :chain, :a, :b, :c, :d, :e, :f, :g, :h)
80+
@test Tables.schema(chn).types === (
81+
Int,
82+
Int,
83+
Float64,
84+
Float64,
85+
Float64,
86+
Float64,
87+
Float64,
88+
Float64,
89+
Float64,
90+
Float64,
91+
)
92+
end
93+
94+
@testset "exceptions raised if reserved colname used" begin
95+
val2 = rand(1000, 2, 4)
96+
chn2 = Chains(val2, ["iteration", "a"])
97+
@test_throws Exception Tables.columns(chn2)
98+
@test_throws Exception Tables.rows(chn2)
99+
@test_throws Exception Tables.schema(chn2)
100+
chn3 = Chains(val2, ["chain", "a"])
101+
@test_throws Exception Tables.columns(chn3)
102+
@test_throws Exception Tables.rows(chn3)
103+
@test_throws Exception Tables.schema(chn3)
104+
end
73105
end
74106

75107
@testset "TableTraits interface" begin
@@ -82,10 +114,10 @@ using DataFrames
82114
@test nt ==
83115
(; (k => Tables.getcolumn(chn, k)[2] for k in Tables.columnnames(chn))...)
84116

85-
val = rand(1000, 2, 4)
86-
chn2 = Chains(val, ["iteration", "a"])
117+
val2 = rand(1000, 2, 4)
118+
chn2 = Chains(val2, ["iteration", "a"])
87119
@test_throws Exception IteratorInterfaceExtensions.getiterator(chn2)
88-
chn3 = Chains(val, ["chain", "a"])
120+
chn3 = Chains(val2, ["chain", "a"])
89121
@test_throws Exception IteratorInterfaceExtensions.getiterator(chn3)
90122
end
91123

@@ -106,29 +138,54 @@ using DataFrames
106138

107139
@testset "Tables interface" begin
108140
@test Tables.istable(typeof(cdf))
109-
@test Tables.columnaccess(typeof(cdf))
110-
@test Tables.columns(cdf) === cdf
111-
@test Tables.columnnames(cdf) == keys(cdf.nt)
112-
for (k, v) in pairs(cdf.nt)
113-
@test Tables.getcolumn(cdf, k) == v
141+
142+
@testset "column access" begin
143+
@test Tables.columnaccess(typeof(cdf))
144+
@test Tables.columns(cdf) === cdf
145+
@test Tables.columnnames(cdf) == keys(cdf.nt)
146+
for (k, v) in pairs(cdf.nt)
147+
@test Tables.getcolumn(cdf, k) == v
148+
end
149+
@test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1])
150+
@test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2])
151+
@test_throws Exception Tables.getcolumn(cdf, :blah)
152+
@test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1)
153+
end
154+
155+
@testset "row access" begin
156+
@test Tables.rowaccess(typeof(cdf))
157+
@test Tables.rows(cdf) isa Tables.RowIterator
158+
@test eltype(Tables.rows(cdf)) <: Tables.AbstractRow
159+
rows = collect(Tables.rows(cdf))
160+
@test eltype(rows) <: Tables.AbstractRow
161+
@test size(rows) === (2,)
162+
@testset for i in 1:2
163+
row = rows[i]
164+
@test Tables.columnnames(row) == keys(cdf.nt)
165+
for j in 1:length(cdf.nt)
166+
@test Tables.getcolumn(row, j) == cdf.nt[j][i]
167+
@test Tables.getcolumn(row, keys(cdf.nt)[j]) == cdf.nt[j][i]
168+
end
169+
end
170+
end
171+
172+
@testset "integration tests" begin
173+
@test length(Tables.rowtable(cdf)) == length(cdf.nt[1])
174+
@test Tables.columntable(cdf) == cdf.nt
175+
nt = Tables.rowtable(cdf)[1]
176+
@test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...)
177+
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1]
178+
nt = Tables.rowtable(cdf)[2]
179+
@test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...)
180+
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2]
181+
@test Tables.matrix(Tables.rowtable(cdf)) == Tables.matrix(Tables.columntable(cdf))
182+
end
183+
184+
@testset "schema" begin
185+
@test Tables.schema(cdf) isa Tables.Schema
186+
@test Tables.schema(cdf).names === keys(cdf.nt)
187+
@test Tables.schema(cdf).types === eltype.(values(cdf.nt))
114188
end
115-
@test Tables.getcolumn(cdf, 1) == Tables.getcolumn(cdf, keys(cdf.nt)[1])
116-
@test Tables.getcolumn(cdf, 2) == Tables.getcolumn(cdf, keys(cdf.nt)[2])
117-
@test_throws Exception Tables.getcolumn(cdf, :blah)
118-
@test_throws Exception Tables.getcolumn(cdf, length(cdf.nt) + 1)
119-
@test Tables.rowaccess(typeof(cdf))
120-
@test Tables.rows(cdf) === cdf
121-
@test length(Tables.rowtable(cdf)) == length(cdf.nt[1])
122-
@test Tables.columntable(cdf) == cdf.nt
123-
nt = Tables.rowtable(cdf)[1]
124-
@test nt == (; (k => v[1] for (k, v) in pairs(cdf.nt))...)
125-
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 1))[1]
126-
nt = Tables.rowtable(cdf)[2]
127-
@test nt == (; (k => v[2] for (k, v) in pairs(cdf.nt))...)
128-
@test nt == collect(Iterators.take(Tables.namedtupleiterator(cdf), 2))[2]
129-
@test Tables.schema(cdf) isa Tables.Schema
130-
@test Tables.schema(cdf).names === keys(cdf.nt)
131-
@test Tables.schema(cdf).types === eltype.(values(cdf.nt))
132189
end
133190

134191
@testset "TableTraits interface" begin

0 commit comments

Comments
 (0)