Skip to content

Commit 38ff711

Browse files
Add dss_transform unit test
1 parent d9c8f4c commit 38ff711

File tree

3 files changed

+222
-0
lines changed

3 files changed

+222
-0
lines changed

test/Spaces/unit_dss.jl

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#=
2+
julia --project
3+
ENV["CLIMACOMMS_DEVICE"] = "CPU";
4+
using Revise; include(joinpath("test", "Spaces", "unit_dss.jl"))
5+
=#
6+
using Test
7+
using ClimaComms
8+
using Random
9+
ClimaComms.@import_required_backends
10+
11+
import ClimaCore:
12+
Domains,
13+
Fields,
14+
Geometry,
15+
Meshes,
16+
Operators,
17+
Spaces,
18+
Quadratures,
19+
Topologies,
20+
DataLayouts
21+
22+
function get_space_cs(::Type{FT}; context, R = 300.0) where {FT}
23+
domain = Domains.SphereDomain{FT}(300.0)
24+
mesh = Meshes.EquiangularCubedSphere(domain, 3)
25+
topology = Topologies.Topology2D(context, mesh)
26+
quad = Quadratures.GLL{4}()
27+
space = Spaces.SpectralElementSpace2D(topology, quad)
28+
return space
29+
end
30+
31+
function one_to_n(a::AbstractArray)
32+
_a = Array(a)
33+
Random.seed!(1234)
34+
for i in 1:length(_a)
35+
_a[i] = rand()
36+
end
37+
return typeof(a)(_a)
38+
end
39+
40+
function test_dss_count(f::Fields.Field, buff::Topologies.DSSBuffer, nc)
41+
parent(f) .= one_to_n(parent(f))
42+
@test allunique(parent(f))
43+
cf = copy(f)
44+
Spaces.weighted_dss!(f => buff)
45+
n_dss_unaffected = count(parent(f) .== parent(cf))
46+
n_dss_affected = length(parent(f)) - n_dss_unaffected
47+
return (; n_dss_affected)
48+
end
49+
50+
function get_space_and_buffers(::Type{FT}; context) where {FT}
51+
init_state_covariant12(local_geometry, p) =
52+
Geometry.Covariant12Vector(1.0, -1.0)
53+
init_state_covariant123(local_geometry, p) =
54+
Geometry.Covariant123Vector(1.0, -1.0, 1.0)
55+
init_state_covariant3(local_geometry, p) = Geometry.Covariant3Vector(1.0)
56+
57+
R = FT(6.371229e6)
58+
npoly = 2
59+
z_max = FT(30e3)
60+
z_elem = 3
61+
h_elem = 2
62+
device = ClimaComms.device(context)
63+
@info "running dss-Covariant123Vector test on $(device)" h_elem z_elem npoly R z_max FT
64+
# horizontal space
65+
domain = Domains.SphereDomain{FT}(R)
66+
horizontal_mesh = Meshes.EquiangularCubedSphere(domain, h_elem)
67+
horizontal_topology = Topologies.Topology2D(
68+
context,
69+
horizontal_mesh,
70+
Topologies.spacefillingcurve(horizontal_mesh),
71+
)
72+
quad = Quadratures.GLL{npoly + 1}()
73+
h_space = Spaces.SpectralElementSpace2D(horizontal_topology, quad)
74+
# vertical space
75+
z_domain = Domains.IntervalDomain(
76+
Geometry.ZPoint{FT}(zero(z_max)),
77+
Geometry.ZPoint{FT}(z_max);
78+
boundary_names = (:bottom, :top),
79+
)
80+
z_mesh = Meshes.IntervalMesh(z_domain, nelems = z_elem)
81+
z_topology = Topologies.IntervalTopology(context, z_mesh)
82+
z_center_space = Spaces.CenterFiniteDifferenceSpace(z_topology)
83+
space = Spaces.ExtrudedFiniteDifferenceSpace(h_space, z_center_space)
84+
args = (Fields.local_geometry_field(space), Ref(nothing))
85+
y12 = init_state_covariant12.(args...)
86+
y123 = init_state_covariant123.(args...)
87+
y3 = init_state_covariant3.(args...)
88+
dss_buffer = (;
89+
y12 = Spaces.create_dss_buffer(y12),
90+
y123 = Spaces.create_dss_buffer(y123),
91+
y3 = Spaces.create_dss_buffer(y3),
92+
)
93+
return (; space, y12, y123, y3, dss_buffer)
94+
end
95+
96+
@testset "DSS of AxisTensors on Cubed Sphere" begin
97+
FT = Float64
98+
device = ClimaComms.device()
99+
nt = get_space_and_buffers(FT; context = ClimaComms.context(device))
100+
101+
# test DSS for a Covariant12Vector
102+
# ensure physical velocity is continuous across SE boundary for initial state
103+
n_dss_affected_y12 =
104+
test_dss_count(nt.y12, nt.dss_buffer.y12, 2).n_dss_affected
105+
n_dss_affected_y123 =
106+
test_dss_count(nt.y123, nt.dss_buffer.y123, 3).n_dss_affected
107+
n_dss_affected_y3 =
108+
test_dss_count(nt.y3, nt.dss_buffer.y3, 1).n_dss_affected
109+
110+
@test n_dss_affected_y12 * 3 / 2 ==
111+
n_dss_affected_y123 ==
112+
n_dss_affected_y3 * 3
113+
114+
@test nt.dss_buffer.y12.scalarfidx == Int[]
115+
@test nt.dss_buffer.y12.covariant12fidx == Int[1]
116+
@test nt.dss_buffer.y12.contravariant12fidx == Int[]
117+
@test nt.dss_buffer.y12.covariant123fidx == Int[]
118+
@test nt.dss_buffer.y12.contravariant123fidx == Int[]
119+
120+
@test nt.dss_buffer.y123.scalarfidx == Int[]
121+
@test nt.dss_buffer.y123.covariant12fidx == Int[]
122+
@test nt.dss_buffer.y123.contravariant12fidx == Int[]
123+
@test nt.dss_buffer.y123.covariant123fidx == Int[1]
124+
@test nt.dss_buffer.y123.contravariant123fidx == Int[]
125+
126+
@test nt.dss_buffer.y3.scalarfidx == Int[1]
127+
@test nt.dss_buffer.y3.covariant12fidx == Int[]
128+
@test nt.dss_buffer.y3.contravariant12fidx == Int[]
129+
@test nt.dss_buffer.y3.covariant123fidx == Int[]
130+
@test nt.dss_buffer.y3.contravariant123fidx == Int[]
131+
end

test/Topologies/unit_dss_transform.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#=
2+
julia --project
3+
using Revise; include(joinpath("test", "Topologies", "unit_dss_transform.jl"))
4+
=#
5+
using Test
6+
using ClimaComms
7+
using ClimaCore.Topologies: dss_transform, dss_untransform
8+
using Random
9+
ClimaComms.@import_required_backends
10+
11+
import ClimaCore:
12+
Domains,
13+
Fields,
14+
Geometry,
15+
Meshes,
16+
Operators,
17+
Spaces,
18+
Quadratures,
19+
Topologies,
20+
DataLayouts
21+
22+
function get_space(::Type{FT}; context) where {FT}
23+
R = FT(6.371229e6)
24+
npoly = 2
25+
z_max = FT(30e3)
26+
z_elem = 3
27+
h_elem = 2
28+
device = ClimaComms.device(context)
29+
@info "running dss-Covariant123Vector test on $(device)" h_elem z_elem npoly R z_max FT
30+
# horizontal space
31+
domain = Domains.SphereDomain{FT}(R)
32+
horizontal_mesh = Meshes.EquiangularCubedSphere(domain, h_elem)
33+
horizontal_topology = Topologies.Topology2D(
34+
context,
35+
horizontal_mesh,
36+
Topologies.spacefillingcurve(horizontal_mesh),
37+
)
38+
quad = Quadratures.GLL{npoly + 1}()
39+
h_space = Spaces.SpectralElementSpace2D(horizontal_topology, quad)
40+
# vertical space
41+
z_domain = Domains.IntervalDomain(
42+
Geometry.ZPoint{FT}(zero(z_max)),
43+
Geometry.ZPoint{FT}(z_max);
44+
boundary_names = (:bottom, :top),
45+
)
46+
z_mesh = Meshes.IntervalMesh(z_domain, nelems = z_elem)
47+
z_topology = Topologies.IntervalTopology(context, z_mesh)
48+
z_center_space = Spaces.CenterFiniteDifferenceSpace(z_topology)
49+
space = Spaces.ExtrudedFiniteDifferenceSpace(h_space, z_center_space)
50+
return space
51+
end
52+
53+
@testset "dss_transform" begin
54+
device = ClimaComms.device()
55+
space = get_space(Float64; context = ClimaComms.context(device))
56+
57+
local_geometry = Fields.local_geometry_field(space)
58+
map(local_geometry) do lg
59+
FT = Geometry.undertype(typeof(lg))
60+
(; lat, long, z) = lg.coordinates
61+
# Test that vertical component is treated as a scalar:
62+
63+
arg = Geometry.Covariant123Vector(FT(lat), FT(long), FT(z))
64+
weight = 2
65+
dss_t = dss_transform(arg, lg, weight)
66+
dss_ut = dss_untransform(Geometry.Covariant123Vector{FT}, dss_t, lg)
67+
@test dss_t isa Geometry.UVWVector
68+
@test typeof(arg) == typeof(dss_ut)
69+
@test arg dss_ut / weight
70+
71+
arg = Geometry.Covariant12Vector(FT(lat), FT(long))
72+
weight = 2
73+
dss_t = dss_transform(arg, lg, weight)
74+
dss_ut = dss_untransform(Geometry.Covariant12Vector{FT}, dss_t, lg)
75+
@test dss_t isa Geometry.UVWVector
76+
@test typeof(arg) == typeof(dss_ut)
77+
@test arg dss_ut / weight
78+
79+
arg = Geometry.Covariant3Vector(FT(z))
80+
weight = 2
81+
dss_t = dss_transform(arg, lg, weight)
82+
dss_ut = dss_untransform(Geometry.Covariant3Vector{FT}, dss_t, lg)
83+
@test typeof(arg) == typeof(dss_ut)
84+
@test dss_t isa Geometry.Covariant3Vector
85+
@test dss_t === arg * weight
86+
@test arg == dss_ut / weight
87+
FT(1)
88+
end
89+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ UnitTest("Rectangle topology" ,"Topologies/rectangle.jl"),
3030
UnitTest("Rectangle surface topology" ,"Topologies/rectangle_sfc.jl"),
3131
UnitTest("Cubedsphere topology" ,"Topologies/cubedsphere.jl"),
3232
UnitTest("Cubedsphere surface topology" ,"Topologies/cubedsphere_sfc.jl"),
33+
UnitTest("dss_transform" ,"Topologies/unit_dss_transform.jl"),
3334
UnitTest("Quadratures" ,"Quadratures/Quadratures.jl"),
3435
UnitTest("Spaces" ,"Spaces/unit_spaces.jl"),
36+
UnitTest("dss" ,"Spaces/unit_dss.jl"),
3537
UnitTest("Spaces - serial CPU DSS" ,"Spaces/ddss1.jl"),
3638
UnitTest("Spaces - DSS cubed sphere" ,"Spaces/ddss1_cs.jl"),
3739
UnitTest("Sphere spaces" ,"Spaces/sphere.jl"),

0 commit comments

Comments
 (0)