Skip to content

Commit 59fe268

Browse files
wip
1 parent 6fa5586 commit 59fe268

File tree

8 files changed

+40
-17
lines changed

8 files changed

+40
-17
lines changed

src/Geometry/axistensors.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ struct AxisTensor{
137137
N,
138138
A <: NTuple{N, AbstractAxis},
139139
S <: Union{
140-
SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
140+
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
141+
SimpleSymmetric{N, T},
141142
StaticArray{<:Tuple, T, N},
142143
},
143144
} <: AbstractArray{T, N}
@@ -151,7 +152,8 @@ AxisTensor(
151152
) where {
152153
A <: Tuple{Vararg{AbstractAxis}},
153154
S <: Union{
154-
SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
155+
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
156+
SimpleSymmetric{N, T},
155157
StaticArray{<:Tuple, T, N},
156158
},
157159
} where {T, N} = AxisTensor{T, N, A, S}(axes, components)

src/Geometry/localgeometry.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818
1919
The necessary local metric information defined at each node.
2020
"""
21-
struct LocalGeometry{I, C <: AbstractPoint, FT, S, L}
21+
struct LocalGeometry{I, C <: AbstractPoint, FT, S, N, L}
2222
"Coordinates of the current point"
2323
coordinates::C
2424
"Jacobian determinant of the transformation `ξ` to `x`"
@@ -35,15 +35,13 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S, L}
3535
gⁱʲ::Axis2Tensor{
3636
FT,
3737
Tuple{ContravariantAxis{I}, ContravariantAxis{I}},
38-
# SimpleSymmetric{FT, S},
39-
SimpleSymmetric{2, FT, L},
38+
SimpleSymmetric{N, FT, L},
4039
}
4140
"Covariant metric tensor (gᵢⱼ), transforms contravariant to covariant vector components"
4241
gᵢⱼ::Axis2Tensor{
4342
FT,
4443
Tuple{CovariantAxis{I}, CovariantAxis{I}},
45-
# SimpleSymmetric{FT, S},
46-
SimpleSymmetric{2, FT, L},
44+
SimpleSymmetric{N, FT, L},
4745
}
4846
@inline function LocalGeometry(
4947
coordinates,
@@ -61,7 +59,8 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S, L}
6159
gⁱʲ = SimpleSymmetric(gⁱʲ₀)
6260
gᵢⱼ = SimpleSymmetric(gᵢⱼ₀)
6361
L = triangular_nonzeros(S)
64-
return new{I, C, FT, S, L}(coordinates, J, WJ, Jinv, ∂x∂ξ, ∂ξ∂x, gⁱʲ, gᵢⱼ)
62+
N = size(components(gⁱʲ₀), 1)
63+
return new{I, C, FT, S, N, L}(coordinates, J, WJ, Jinv, ∂x∂ξ, ∂ξ∂x, gⁱʲ, gᵢⱼ)
6564
end
6665
end
6766

@@ -77,7 +76,7 @@ struct SurfaceGeometry{FT, N}
7776
normal::N
7877
end
7978

80-
undertype(::Type{LocalGeometry{I, C, FT, S}}) where {I, C, FT, S} = FT
79+
undertype(::Type{<:LocalGeometry{I, C, FT}}) where {I, C, FT} = FT
8180
undertype(::Type{SurfaceGeometry{FT, N}}) where {FT, N} = FT
8281

8382
"""

src/Geometry/simple_symmetric.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,14 @@ StaticArrays.check_parameters(
5353

5454
triangular_nonzeros(::SMatrix{N}) where {N} = Int(N * (N + 1) / 2)
5555
triangular_nonzeros(::Type{<:SMatrix{N}}) where {N} = Int(N * (N + 1) / 2)
56+
tail_params(::Type{S}) where {N,T, S<:SMatrix{N,N,T}} = (T, S, N, triangular_nonzeros(S))
5657

5758
function SimpleSymmetric(A::SMatrix)
5859
@assert size(A, 1) == size(A, 2)
59-
M = size(A, 1)
60-
N = ndims(A)
60+
N = size(A, 1)
61+
nd = ndims(A)
6162
T = eltype(A)
62-
ci = ntuple(i -> 1:size(A, i), N)
63+
ci = ntuple(i -> 1:size(A, i), nd)
6364
upper_inds = filter(I -> I.I[2] I.I[1], CartesianIndices(ci))
6465
L = triangular_nonzeros(A)
6566
upper_triang = SVector{L}(map(I -> A[I], upper_inds))

src/Grids/finitedifference.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ function fd_geometry_data(
7979
) where {FT, periodic}
8080
CT = Geometry.ZPoint{FT}
8181
AIdx = (3,)
82-
LG = Geometry.LocalGeometry{AIdx, CT, FT, SMatrix{1, 1, FT, 1}}
82+
S = SMatrix{1, 1, FT, 1}
83+
LG = Geometry.LocalGeometry{AIdx, CT, Geometry.tail_params(S)...}
8384
(Ni, Nj, Nk, Nv, Nh) = size(face_coordinates)
8485
Nv_face = Nv - periodic
8586
Nv_cent = Nv - 1

src/Grids/spectralelement.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ function _SpectralElementGrid1D(
5050
nelements = Topologies.nlocalelems(topology)
5151
Nq = Quadratures.degrees_of_freedom(quadrature_style)
5252

53-
LG = Geometry.LocalGeometry{AIdx, CoordType, FT, SMatrix{1, 1, FT, 1}}
53+
S = SMatrix{1, 1, FT, 1}
54+
LG = Geometry.LocalGeometry{AIdx, CoordType, Geometry.tail_params(S)...}
5455
local_geometry = DataLayouts.IFH{LG, Nq}(Array{FT}, nelements)
5556
quad_points, quad_weights =
5657
Quadratures.quadrature_points(FT, quadrature_style)
@@ -218,7 +219,8 @@ function _SpectralElementGrid2D(
218219
high_order_quadrature_style = Quadratures.GLL{Nq * 2}()
219220
high_order_Nq = Quadratures.degrees_of_freedom(high_order_quadrature_style)
220221

221-
LG = Geometry.LocalGeometry{AIdx, CoordType2D, FT, SMatrix{2, 2, FT, 4}}
222+
S = SMatrix{2, 2, FT, 4}
223+
LG = Geometry.LocalGeometry{AIdx, CoordType2D, Geometry.tail_params(S)...}
222224

223225
local_geometry = DataLayouts.IJFH{LG, Nq}(Array{FT}, nlelems)
224226

test/DataLayouts/opt_similar.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function test_similar!(data)
1616
FT = eltype(parent(data))
1717
CT = Geometry.ZPoint{FT}
1818
AIdx = (3,)
19-
LG = Geometry.LocalGeometry{AIdx, CT, FT, SMatrix{1, 1, FT, 1}}
19+
S = SMatrix{1, 1, FT, 1}
20+
LG = Geometry.LocalGeometry{AIdx, CT, Geometry.tail_params(S)...}
2021
(_, _, _, Nv, _) = size(data)
2122
similar(data, LG, Val(Nv))
2223
@test_opt similar(data, LG, Val(Nv))

test/Geometry/axistensor_conversion_benchmarks.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ function benchmark_func(args, key, f, flops, ::Type{FT}; print_method_info) wher
113113
print("Time (opt, ref): ($(opt.t_pretty), $(ref.t_pretty)). Key: $key_str\n")
114114
# end
115115
end
116+
correctness = compare(components(opt.result), components(ref.result)) # test correctness
117+
# @show correctness
118+
@show components(opt.result)
119+
@show components(ref.result)
120+
@show correctness
116121
bm = (;
117122
opt,
118123
ref,
@@ -121,7 +126,7 @@ function benchmark_func(args, key, f, flops, ::Type{FT}; print_method_info) wher
121126
flops, # current flops
122127
computed_flops,
123128
reduced_flops,
124-
correctness = compare(opt.result, ref.result), # test correctness
129+
correctness, # test correctness
125130
perf_pass = (opt.time - ref.time)/ref.time*100 < -100, # test performance
126131
)
127132
return bm
@@ -149,6 +154,7 @@ components(x::T) where {T <: Real} = x
149154
components(x) = Geometry.components(x)
150155
compare(x::T, y::T) where {T<: Real} = x y || (x < eps(T)/100 && y < eps(T)/100)
151156
compare(x::T, y::T) where {T <: SMatrix} = all(compare.(x, y))
157+
compare(x::T, y::T) where {T <: Geometry.SimpleSymmetric} = all(compare.(x.lowertriangle, y.lowertriangle))
152158
compare(x::T, y::T) where {T <: SVector} = all(compare.(x, y))
153159
compare(x::T, y::T) where {T <: AxisTensor} = compare(components(x), components(y))
154160

@@ -168,6 +174,9 @@ function test_optimized_functions(::Type{FT}; print_method_info=false) where {FT
168174
end
169175

170176
for key in keys(benchmarks)
177+
if !(benchmarks[key].correctness)
178+
@show key
179+
end
171180
@test benchmarks[key].correctness # test correctness
172181
@test benchmarks[key].Δflops 0 # Don't regress
173182
# @test_broken benchmarks[key].Δflops < 0 # Error on improvements. TODO: fix, this is somehow flakey

test/Geometry/unit_simple_symmetric.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Revise; include(joinpath("test", "Geometry", "unit_simple_symmetric.jl"))
55
using Test
66
using StaticArrays
77
using ClimaCore.Geometry: SimpleSymmetric
8+
import ClimaCore.Geometry
89
using JET
910
simple_symmetric(A::Matrix) = SimpleSymmetric(SMatrix{size(A)..., eltype(A)}(A))
1011

@@ -27,4 +28,11 @@ simple_symmetric(A::Matrix) = SimpleSymmetric(SMatrix{size(A)..., eltype(A)}(A))
2728
A = @SMatrix [1 2; 2 4]
2829
@test SimpleSymmetric(A) / 2 === SimpleSymmetric(A / 2)
2930
@test_opt SimpleSymmetric(A)
31+
@test Geometry.tail_params(typeof(@SMatrix Float32[1 2; 2 4])) == (Float32, SMatrix{2, 2, Float32, 4}, 2, 3)
32+
end
33+
34+
@testset "sizs" begin
35+
for N in (1,2,3,5,8,10)
36+
simple_symmetric(rand(N,N)) # pass in non-symmetric matrix
37+
end
3038
end

0 commit comments

Comments
 (0)