Skip to content

Commit 1aa5c36

Browse files
Merge pull request #1891 from CliMA/ck/struct_tests
Add `get_struct` unit tests
2 parents f248a6e + a021ea0 commit 1aa5c36

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

test/DataLayouts/unit_struct.jl

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#=
2+
julia --check-bounds=yes --project
3+
using Revise; include(joinpath("test", "DataLayouts", "unit_struct.jl"))
4+
=#
5+
using Test
6+
using ClimaCore.DataLayouts
7+
using StaticArrays
8+
9+
function one_to_n(a::Array)
10+
for i in 1:length(a)
11+
a[i] = i
12+
end
13+
return a
14+
end
15+
one_to_n(s::Tuple, ::Type{FT}) where {FT} = one_to_n(zeros(FT, s...))
16+
ncomponents(::Type{FT}, ::Type{S}) where {FT, S} = div(sizeof(S), sizeof(FT))
17+
18+
function test_get_struct(::Type{FT}, ::Type{S}) where {FT, S}
19+
s = (2,)
20+
a = one_to_n(s, FT)
21+
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
22+
for (i, ci) in enumerate(CI)
23+
for j in 1:length(s)
24+
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
25+
end
26+
end
27+
28+
s = (2, 3)
29+
a = one_to_n(s, FT)
30+
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
31+
for (i, ci) in enumerate(CI)
32+
for j in 1:length(s)
33+
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
34+
end
35+
end
36+
37+
s = (2, 3, 4)
38+
a = one_to_n(s, FT)
39+
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
40+
for (i, ci) in enumerate(CI)
41+
for j in 1:length(s)
42+
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
43+
end
44+
end
45+
46+
s = (2, 3, 4, 5)
47+
a = one_to_n(s, FT)
48+
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
49+
for (i, ci) in enumerate(CI)
50+
for j in 1:length(s)
51+
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
52+
end
53+
end
54+
end
55+
56+
@testset "get_struct - Float" begin
57+
test_get_struct(Float64, Float64)
58+
test_get_struct(Float32, Float32)
59+
end
60+
61+
struct Foo{T}
62+
x::T
63+
y::T
64+
end
65+
66+
Base.zero(::Type{Foo{T}}) where {T} = Foo{T}(0, 0)
67+
68+
@testset "get_struct - flat struct 2-fields 1-dim" begin
69+
FT = Float64
70+
S = Foo{FT}
71+
s = (4,)
72+
a = one_to_n(s, FT)
73+
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
74+
@test ncomponents(FT, S) == 2
75+
@test DataLayouts.get_struct(a, S, Val(1), CI[1]) == Foo{FT}(1.0, 2.0)
76+
@test DataLayouts.get_struct(a, S, Val(1), CI[2]) == Foo{FT}(2.0, 3.0)
77+
@test DataLayouts.get_struct(a, S, Val(1), CI[3]) == Foo{FT}(3.0, 4.0)
78+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[4])
79+
end
80+
81+
@testset "get_struct - flat struct 2-fields 3-dims" begin
82+
FT = Float64
83+
S = Foo{FT}
84+
s = (2, 3, 4)
85+
a = one_to_n(s, FT)
86+
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
87+
@test ncomponents(FT, S) == 2
88+
89+
# Call get_struct, and span `a` (access elements to 24.0):
90+
@test DataLayouts.get_struct(a, S, Val(1), CI[1]) == Foo{FT}(1.0, 2.0)
91+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[2])
92+
93+
@test DataLayouts.get_struct(a, S, Val(2), CI[1]) == Foo{FT}(1.0, 3.0)
94+
@test DataLayouts.get_struct(a, S, Val(2), CI[2]) == Foo{FT}(2.0, 4.0)
95+
@test DataLayouts.get_struct(a, S, Val(2), CI[3]) == Foo{FT}(3.0, 5.0)
96+
@test DataLayouts.get_struct(a, S, Val(2), CI[4]) == Foo{FT}(4.0, 6.0)
97+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[5])
98+
99+
@test DataLayouts.get_struct(a, S, Val(3), CI[1]) == Foo{FT}(1.0, 7.0)
100+
@test DataLayouts.get_struct(a, S, Val(3), CI[2]) == Foo{FT}(2.0, 8.0)
101+
@test DataLayouts.get_struct(a, S, Val(3), CI[3]) == Foo{FT}(3.0, 9.0)
102+
@test DataLayouts.get_struct(a, S, Val(3), CI[4]) == Foo{FT}(4.0, 10.0)
103+
@test DataLayouts.get_struct(a, S, Val(3), CI[5]) == Foo{FT}(5.0, 11.0)
104+
@test DataLayouts.get_struct(a, S, Val(3), CI[6]) == Foo{FT}(6.0, 12.0)
105+
@test DataLayouts.get_struct(a, S, Val(3), CI[7]) == Foo{FT}(7.0, 13.0)
106+
@test DataLayouts.get_struct(a, S, Val(3), CI[8]) == Foo{FT}(8.0, 14.0)
107+
@test DataLayouts.get_struct(a, S, Val(3), CI[9]) == Foo{FT}(9.0, 15.0)
108+
@test DataLayouts.get_struct(a, S, Val(3), CI[10]) == Foo{FT}(10.0, 16.0)
109+
@test DataLayouts.get_struct(a, S, Val(3), CI[11]) == Foo{FT}(11.0, 17.0)
110+
@test DataLayouts.get_struct(a, S, Val(3), CI[12]) == Foo{FT}(12.0, 18.0)
111+
@test DataLayouts.get_struct(a, S, Val(3), CI[13]) == Foo{FT}(13.0, 19.0)
112+
@test DataLayouts.get_struct(a, S, Val(3), CI[14]) == Foo{FT}(14.0, 20.0)
113+
@test DataLayouts.get_struct(a, S, Val(3), CI[15]) == Foo{FT}(15.0, 21.0)
114+
@test DataLayouts.get_struct(a, S, Val(3), CI[16]) == Foo{FT}(16.0, 22.0)
115+
@test DataLayouts.get_struct(a, S, Val(3), CI[17]) == Foo{FT}(17.0, 23.0)
116+
@test DataLayouts.get_struct(a, S, Val(3), CI[18]) == Foo{FT}(18.0, 24.0)
117+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(3), CI[19])
118+
end
119+
120+
@testset "get_struct - flat struct 2-fields 5-dims" begin
121+
FT = Float64
122+
S = Foo{FT}
123+
s = (2, 2, 2, 2, 2)
124+
a = one_to_n(s, FT)
125+
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
126+
@test ncomponents(FT, S) == 2
127+
128+
# Call get_struct, and span `a` (access elements to 2^5 = 32.0):
129+
@test DataLayouts.get_struct(a, S, Val(1), CI[1]) == Foo{FT}(1.0, 2.0)
130+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[2])
131+
132+
@test DataLayouts.get_struct(a, S, Val(2), CI[1]) == Foo{FT}(1.0, 3.0)
133+
@test DataLayouts.get_struct(a, S, Val(2), CI[2]) == Foo{FT}(2.0, 4.0)
134+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[3])
135+
136+
@test DataLayouts.get_struct(a, S, Val(3), CI[1]) == Foo{FT}(1.0, 5.0)
137+
@test DataLayouts.get_struct(a, S, Val(3), CI[2]) == Foo{FT}(2.0, 6.0)
138+
@test DataLayouts.get_struct(a, S, Val(3), CI[3]) == Foo{FT}(3.0, 7.0)
139+
@test DataLayouts.get_struct(a, S, Val(3), CI[4]) == Foo{FT}(4.0, 8.0)
140+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(3), CI[5])
141+
142+
# VIJFH
143+
@test DataLayouts.get_struct(a, S, Val(4), CI[1]) == Foo{FT}(1.0, 9.0)
144+
@test DataLayouts.get_struct(a, S, Val(4), CI[2]) == Foo{FT}(2.0, 10.0)
145+
@test DataLayouts.get_struct(a, S, Val(4), CI[3]) == Foo{FT}(3.0, 11.0)
146+
@test DataLayouts.get_struct(a, S, Val(4), CI[4]) == Foo{FT}(4.0, 12.0)
147+
@test DataLayouts.get_struct(a, S, Val(4), CI[5]) == Foo{FT}(5.0, 13.0)
148+
@test DataLayouts.get_struct(a, S, Val(4), CI[6]) == Foo{FT}(6.0, 14.0)
149+
@test DataLayouts.get_struct(a, S, Val(4), CI[7]) == Foo{FT}(7.0, 15.0)
150+
@test DataLayouts.get_struct(a, S, Val(4), CI[8]) == Foo{FT}(8.0, 16.0)
151+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[9])
152+
153+
@test DataLayouts.get_struct(a, S, Val(5), CI[1]) == Foo{FT}(1.0, 17.0)
154+
@test DataLayouts.get_struct(a, S, Val(5), CI[2]) == Foo{FT}(2.0, 18.0)
155+
@test DataLayouts.get_struct(a, S, Val(5), CI[3]) == Foo{FT}(3.0, 19.0)
156+
@test DataLayouts.get_struct(a, S, Val(5), CI[4]) == Foo{FT}(4.0, 20.0)
157+
@test DataLayouts.get_struct(a, S, Val(5), CI[5]) == Foo{FT}(5.0, 21.0)
158+
@test DataLayouts.get_struct(a, S, Val(5), CI[6]) == Foo{FT}(6.0, 22.0)
159+
@test DataLayouts.get_struct(a, S, Val(5), CI[7]) == Foo{FT}(7.0, 23.0)
160+
@test DataLayouts.get_struct(a, S, Val(5), CI[8]) == Foo{FT}(8.0, 24.0)
161+
@test DataLayouts.get_struct(a, S, Val(5), CI[9]) == Foo{FT}(9.0, 25.0)
162+
@test DataLayouts.get_struct(a, S, Val(5), CI[10]) == Foo{FT}(10.0, 26.0)
163+
@test DataLayouts.get_struct(a, S, Val(5), CI[11]) == Foo{FT}(11.0, 27.0)
164+
@test DataLayouts.get_struct(a, S, Val(5), CI[12]) == Foo{FT}(12.0, 28.0)
165+
@test DataLayouts.get_struct(a, S, Val(5), CI[13]) == Foo{FT}(13.0, 29.0)
166+
@test DataLayouts.get_struct(a, S, Val(5), CI[14]) == Foo{FT}(14.0, 30.0)
167+
@test DataLayouts.get_struct(a, S, Val(5), CI[15]) == Foo{FT}(15.0, 31.0)
168+
@test DataLayouts.get_struct(a, S, Val(5), CI[16]) == Foo{FT}(16.0, 32.0)
169+
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(5), CI[17])
170+
end
171+
172+
# TODO: add set_struct!

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ include("tabulated_tests.jl")
99
unit_tests = [
1010
UnitTest("DataLayouts fill" ,"DataLayouts/unit_fill.jl"),
1111
UnitTest("DataLayouts ndims" ,"DataLayouts/unit_ndims.jl"),
12+
UnitTest("DataLayouts get_struct" ,"DataLayouts/unit_struct.jl"),
1213
UnitTest("Recursive" ,"RecursiveApply/unit_recursive_apply.jl"),
1314
UnitTest("PlusHalf" ,"Utilities/unit_plushalf.jl"),
1415
UnitTest("DataLayouts 0D" ,"DataLayouts/data0d.jl"),

0 commit comments

Comments
 (0)