Skip to content

Commit fa21f52

Browse files
gpu fixes
1 parent 59fe268 commit fa21f52

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

src/Geometry/axistensors.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ struct AxisTensor{
137137
N,
138138
A <: NTuple{N, AbstractAxis},
139139
S <: Union{
140-
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
141140
SimpleSymmetric{N, T},
142141
StaticArray{<:Tuple, T, N},
143142
},
@@ -152,7 +151,6 @@ AxisTensor(
152151
) where {
153152
A <: Tuple{Vararg{AbstractAxis}},
154153
S <: Union{
155-
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
156154
SimpleSymmetric{N, T},
157155
StaticArray{<:Tuple, T, N},
158156
},

src/Geometry/simple_symmetric.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,36 @@ triangular_nonzeros(::SMatrix{N}) where {N} = Int(N * (N + 1) / 2)
5555
triangular_nonzeros(::Type{<:SMatrix{N}}) where {N} = Int(N * (N + 1) / 2)
5656
tail_params(::Type{S}) where {N,T, S<:SMatrix{N,N,T}} = (T, S, N, triangular_nonzeros(S))
5757

58-
function SimpleSymmetric(A::SMatrix)
59-
@assert size(A, 1) == size(A, 2)
60-
N = size(A, 1)
61-
nd = ndims(A)
62-
T = eltype(A)
63-
ci = ntuple(i -> 1:size(A, i), nd)
64-
upper_inds = filter(I -> I.I[2] I.I[1], CartesianIndices(ci))
65-
L = triangular_nonzeros(A)
66-
upper_triang = SVector{L}(map(I -> A[I], upper_inds))
67-
SimpleSymmetric{N, T, L}(upper_triang)
58+
# function SimpleSymmetric(A::SMatrix)
59+
# @assert size(A, 1) == size(A, 2)
60+
# N = size(A, 1)
61+
# nd = ndims(A)
62+
# T = eltype(A)
63+
# ci = ntuple(i -> 1:size(A, i), nd)
64+
# upper_inds = filter(I -> I.I[2] ≥ I.I[1], CartesianIndices(ci))
65+
# L = triangular_nonzeros(A)
66+
# upper_triang = SVector{L}(map(I -> A[I], upper_inds))
67+
# SimpleSymmetric{N, T, L}(upper_triang)
68+
# end
69+
70+
@generated function SimpleSymmetric(A::S) where {S <: SMatrix}
71+
N = size(S, 1)
72+
L = triangular_nonzeros(S)
73+
_check_simple_symmetric_parameters(Val(N), Val(L))
74+
expr = Vector{Expr}(undef, L)
75+
T = eltype(S)
76+
i = 0
77+
for col in 1:N, row in 1:N
78+
if col row
79+
expr[i += 1] = :(A[$row, $col])
80+
end
81+
end
82+
quote
83+
Base.@_inline_meta
84+
@inbounds return SimpleSymmetric{$N, $T, $L}(
85+
SVector{$L, $T}(tuple($(expr...))),
86+
)
87+
end
6888
end
6989

7090
@inline function _check_simple_symmetric_parameters(

0 commit comments

Comments
 (0)