Skip to content

Commit 26d0e05

Browse files
authored
Default implementation of similar_type for FieldArray (#731)
This should free users from defining similar_type in many cases, though they'll still need to do so when their type is parametric on the eltype. (There's no general way for us to know how to reparameterize such user-defined types.)
1 parent dd361f8 commit 26d0e05

File tree

3 files changed

+62
-17
lines changed

3 files changed

+62
-17
lines changed

src/FieldArray.jl

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@ will automatically define `getindex` and `setindex!` appropriately. An immutable
66
`FieldArray` will be as performant as an `SArray` of similar length and element type,
77
while a mutable `FieldArray` will behave similarly to an `MArray`.
88
9-
For example:
9+
Note that you must define the fields of any `FieldArray` subtype in column major order. If you
10+
want to use an alternative ordering you will need to pay special attention in providing your
11+
own definitions of `getindex`, `setindex!` and tuple conversion.
12+
13+
If you define a `FieldArray` which is parametric on the element type you should
14+
consider defining `similar_type` as in the `FieldVector` example.
15+
16+
17+
# Example
1018
1119
struct Stiffness <: FieldArray{Tuple{2,2,2,2}, Float64, 4}
1220
xxxx::Float64
@@ -26,10 +34,6 @@ For example:
2634
xyyy::Float64
2735
yyyy::Float64
2836
end
29-
30-
Note that you must define the fields of any `FieldArray` subtype in column major order. If you
31-
want to use an alternative ordering you will need to pay special attention in providing your
32-
own definitions of `getindex`, `setindex!` and tuple conversion.
3337
"""
3438
abstract type FieldArray{N, T, D} <: StaticArray{N, T, D} end
3539

@@ -41,7 +45,13 @@ will automatically define `getindex` and `setindex!` appropriately. An immutable
4145
`FieldMatrix` will be as performant as an `SMatrix` of similar length and element type,
4246
while a mutable `FieldMatrix` will behave similarly to an `MMatrix`.
4347
44-
For example:
48+
Note that the fields of any subtype of `FieldMatrix` must be defined in column
49+
major order unless you are willing to implement your own `getindex`.
50+
51+
If you define a `FieldMatrix` which is parametric on the element type you
52+
should consider defining `similar_type` as in the `FieldVector` example.
53+
54+
# Example
4555
4656
struct Stress <: FieldMatrix{3, 3, Float64}
4757
xx::Float64
@@ -67,13 +77,12 @@ For example:
6777
2.0 5.0 8.0
6878
3.0 6.0 9.0
6979
70-
7180
will give you the transpose of what the multi-argument formatting suggests. For clarity,
7281
you may consider using the alternative
7382
74-
sigma = Stress(@SArray[1.0 2.0 3.0;
75-
4.0 5.0 6.0;
76-
7.0 8.0 9.0])
83+
sigma = Stress(SA[1.0 2.0 3.0;
84+
4.0 5.0 6.0;
85+
7.0 8.0 9.0])
7786
"""
7887
abstract type FieldMatrix{N1, N2, T} <: FieldArray{Tuple{N1, N2}, T, 2} end
7988

@@ -85,13 +94,19 @@ will automatically define `getindex` and `setindex!` appropriately. An immutable
8594
`FieldVector` will be as performant as an `SVector` of similar length and element type,
8695
while a mutable `FieldVector` will behave similarly to an `MVector`.
8796
88-
For example:
97+
If you define a `FieldVector` which is parametric on the element type you
98+
should consider defining `similar_type` to preserve your array type through
99+
array operations as in the example below.
100+
101+
# Example
89102
90-
struct Point3D <: FieldVector{3, Float64}
91-
x::Float64
92-
y::Float64
93-
z::Float64
103+
struct Vec3D{T} <: FieldVector{3, T}
104+
x::T
105+
y::T
106+
z::T
94107
end
108+
109+
StaticArrays.similar_type(::Type{<:Vec3D}, ::Type{T}, s::Size{(3,)}) where {T} = Vec3D{T}
95110
"""
96111
abstract type FieldVector{N, T} <: FieldArray{Tuple{N}, T, 1} end
97112

@@ -109,3 +124,15 @@ end
109124
Base.cconvert(::Type{<:Ptr}, a::FieldArray) = Base.RefValue(a)
110125
Base.unsafe_convert(::Type{Ptr{T}}, m::Base.RefValue{FA}) where {N,T,D,FA<:FieldArray{N,T,D}} =
111126
Ptr{T}(Base.unsafe_convert(Ptr{FA}, m))
127+
128+
# We can automatically preserve FieldArrays in array operations which do not
129+
# change their eltype or Size. This should cover all non-parametric FieldArray,
130+
# but for those which are parametric on the eltype the user will still need to
131+
# overload similar_type themselves.
132+
similar_type(::Type{A}, ::Type{T}, S::Size) where {N, T, A<:FieldArray{N, T}} =
133+
_fieldarray_similar_type(A, T, S, Size(A))
134+
135+
# Extra layer of dispatch to match NewSize and OldSize
136+
_fieldarray_similar_type(A, T, NewSize::S, OldSize::S) where {S} = A
137+
_fieldarray_similar_type(A, T, NewSize, OldSize) =
138+
default_similar_type(T, NewSize, length_val(NewSize))

test/FieldMatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
zz::Float64
1414
end
1515

16-
StaticArrays.similar_type(::Type{Tensor3x3}, ::Type{Float64}, s::Size{(3,3)}) = Tensor3x3
16+
# No need to define similar_type for non-parametric FieldMatrix (#792)
1717
end)
1818

1919
p = Tensor3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)

test/FieldVector.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
z::Float64
88
end
99

10-
StaticArrays.similar_type(::Type{Point3D}, ::Type{Float64}, s::Size{(3,)}) = Point3D
10+
# No need to define similar_type for non-parametric FieldVector (#792)
1111
end)
1212

1313
p = Point3D(1.0, 2.0, 3.0)
@@ -29,6 +29,8 @@
2929
0.0 0.0 2.0]
3030

3131
@test @inferred(m*p) === Point3D(2.0, 4.0, 6.0)
32+
@test @inferred(SA[2.0 0.0 0.0;
33+
0.0 2.0 0.0]*p) === SVector((2.0, 4.0))
3234

3335
@test @inferred(similar_type(Point3D)) == Point3D
3436
@test @inferred(similar_type(Point3D, Float64)) == Point3D
@@ -92,4 +94,20 @@
9294
@test length(x[1]) == 2
9395
@test x.x == (1, 2)
9496
end
97+
98+
@testset "FieldVector with parametric eltype and without similar_type" begin
99+
eval(quote
100+
struct FVT{T} <: FieldVector{2, T}
101+
x::T
102+
y::T
103+
end
104+
105+
# No similar_type defined - test fallback codepath
106+
end)
107+
108+
@test @inferred(similar_type(FVT{Float64}, Float32)) == SVector{2,Float32} # Fallback code path
109+
@test @inferred(similar_type(FVT{Float64}, Size(2))) == FVT{Float64}
110+
@test @inferred(similar_type(FVT{Float64}, Size(3))) == SVector{3,Float64}
111+
@test @inferred(similar_type(FVT{Float64}, Float32, Size(3))) == SVector{3,Float32}
112+
end
95113
end

0 commit comments

Comments
 (0)