Skip to content

Commit 5b80258

Browse files
pieverKenolcw
authored
implement three-argument similar method (from #94) (#218)
* Implement three-argument `similar` methods Array wrappers such as `OffsetArray` assume that the three argument version exists for its interior arrays. * Apply suggestions from code review Co-authored-by: Lucas C Wilcox <lucas@swirlee.com> * simplify similar signatures * support and test similar and reshape with offsets Co-authored-by: Keno Fischer <keno@alumni.harvard.edu> Co-authored-by: Lucas C Wilcox <lucas@swirlee.com>
1 parent 1581d70 commit 5b80258

File tree

2 files changed

+100
-17
lines changed

2 files changed

+100
-17
lines changed

src/structarray.jl

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -167,22 +167,26 @@ function Base.IndexStyle(::Type{S}) where {S<:StructArray}
167167
index_type(S) === Int ? IndexLinear() : IndexCartesian()
168168
end
169169

170-
function _undef_array(::Type{T}, sz; unwrap::F = alwaysfalse) where {T, F}
170+
function undef_array(::Type{T}, sz; unwrap::F = alwaysfalse) where {T, F}
171171
if unwrap(T)
172172
return StructArray{T}(undef, sz; unwrap = unwrap)
173173
else
174174
return Array{T}(undef, sz)
175175
end
176176
end
177177

178-
function _similar(v::AbstractArray, ::Type{Z}; unwrap::F = alwaysfalse) where {Z, F}
178+
function similar_array(v::AbstractArray, ::Type{Z}; unwrap::F = alwaysfalse) where {Z, F}
179179
if unwrap(Z)
180-
return buildfromschema(typ -> _similar(v, typ; unwrap = unwrap), Z)
180+
return buildfromschema(typ -> similar_array(v, typ; unwrap = unwrap), Z)
181181
else
182182
return similar(v, Z)
183183
end
184184
end
185185

186+
function similar_structarray(v::AbstractArray, ::Type{Z}; unwrap::F = alwaysfalse) where {Z, F}
187+
buildfromschema(typ -> similar_array(v, typ; unwrap = unwrap), Z)
188+
end
189+
186190
"""
187191
StructArray{T}(undef, dims; unwrap=T->false)
188192
@@ -204,14 +208,10 @@ julia> StructArray{ComplexF64}(undef, (2,3))
204208
StructArray(::Base.UndefInitializer, sz::Dims)
205209

206210
function StructArray{T}(::Base.UndefInitializer, sz::Dims; unwrap::F = alwaysfalse) where {T, F}
207-
buildfromschema(typ -> _undef_array(typ, sz; unwrap = unwrap), T)
211+
buildfromschema(typ -> undef_array(typ, sz; unwrap = unwrap), T)
208212
end
209213
StructArray{T}(u::Base.UndefInitializer, d::Integer...; unwrap::F = alwaysfalse) where {T, F} = StructArray{T}(u, convert(Dims, d); unwrap = unwrap)
210214

211-
function similar_structarray(v::AbstractArray, ::Type{Z}; unwrap::F = alwaysfalse) where {Z, F}
212-
buildfromschema(typ -> _similar(v, typ; unwrap = unwrap), Z)
213-
end
214-
215215
"""
216216
StructArray(A; unwrap = T->false)
217217
@@ -276,14 +276,34 @@ Base.convert(::Type{StructArray}, v::StructArray) = v
276276
Base.convert(::Type{StructVector}, v::AbstractVector) = StructVector(v)
277277
Base.convert(::Type{StructVector}, v::StructVector) = v
278278

279-
function Base.similar(::Type{<:StructArray{T, <:Any, C}}, sz::Dims) where {T, C}
280-
buildfromschema(typ -> similar(typ, sz), T, C)
279+
# Mimic OffsetArrays signatures
280+
const OffsetAxisKnownLength = Union{Integer, AbstractUnitRange}
281+
const OffsetAxis = Union{OffsetAxisKnownLength, Colon}
282+
283+
const OffsetShapeKnownLength = Tuple{OffsetAxisKnownLength,Vararg{OffsetAxisKnownLength}}
284+
const OffsetShape = Tuple{OffsetAxis,Vararg{OffsetAxis}}
285+
286+
# Helper function to avoid adding too many dispatches to `Base.similar`
287+
function _similar(s::StructArray{T}, ::Type{T}, sz) where {T}
288+
return StructArray{T}(map(typ -> similar(typ, sz), components(s)))
281289
end
282290

283-
Base.similar(s::StructArray, sz::Base.DimOrInd...) = similar(s, Base.to_shape(sz))
284-
Base.similar(s::StructArray) = similar(s, Base.to_shape(axes(s)))
285-
function Base.similar(s::StructArray{T}, sz::Tuple) where {T}
286-
StructArray{T}(map(typ -> similar(typ, sz), components(s)))
291+
function _similar(s::StructArray{T}, S::Type, sz) where {T}
292+
# If not specified, we don't really know what kind of array to use for each
293+
# interior type, so we just pick the first one arbitrarily. If users need
294+
# something else, they need to be more specific.
295+
c1 = first(components(s))
296+
return isnonemptystructtype(S) ? buildfromschema(typ -> similar(c1, typ, sz), S) : similar(c1, S, sz)
297+
end
298+
299+
for type in (:Dims, :OffsetShapeKnownLength)
300+
@eval function Base.similar(::Type{<:StructArray{T, N, C}}, sz::$(type)) where {T, N, C}
301+
return buildfromschema(typ -> similar(typ, sz), T, C)
302+
end
303+
304+
@eval function Base.similar(s::StructArray, S::Type, sz::$(type))
305+
return _similar(s, S, sz)
306+
end
287307
end
288308

289309
@deprecate fieldarrays(x) StructArrays.components(x)
@@ -437,8 +457,10 @@ end
437457

438458
Base.copy(s::StructArray{T}) where {T} = StructArray{T}(map(copy, components(s)))
439459

440-
function Base.reshape(s::StructArray{T}, d::Dims) where {T}
441-
StructArray{T}(map(x -> reshape(x, d), components(s)))
460+
for type in (:Dims, :OffsetShape)
461+
@eval function Base.reshape(s::StructArray{T}, d::$(type)) where {T}
462+
StructArray{T}(map(x -> reshape(x, d), components(s)))
463+
end
442464
end
443465

444466
function showfields(io::IO, fields::NTuple{N, Any}) where N

test/runtests.jl

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using StructArrays
22
using StructArrays: staticschema, iscompatible, _promote_typejoin, append!!
3-
using OffsetArrays: OffsetArray
3+
using OffsetArrays: OffsetArray, OffsetVector, OffsetMatrix
44
using StaticArrays
55
import Tables, PooledArrays, WeakRefStrings
66
using TypedTables: Table
@@ -318,13 +318,57 @@ end
318318
s = similar(t)
319319
@test eltype(s) == NamedTuple{(:a, :b), Tuple{Float64, Bool}}
320320
@test size(s) == (10,)
321+
@test s isa StructArray
322+
321323
t = StructArray(a = rand(10, 2), b = rand(Bool, 10, 2))
322324
s = similar(t, 3, 5)
323325
@test eltype(s) == NamedTuple{(:a, :b), Tuple{Float64, Bool}}
324326
@test size(s) == (3, 5)
327+
@test s isa StructArray
328+
325329
s = similar(t, (3, 5))
326330
@test eltype(s) == NamedTuple{(:a, :b), Tuple{Float64, Bool}}
327331
@test size(s) == (3, 5)
332+
@test s isa StructArray
333+
334+
s = similar(t, (0:2, 5))
335+
@test eltype(s) == NamedTuple{(:a, :b), Tuple{Float64, Bool}}
336+
@test axes(s) == (0:2, 1:5)
337+
@test s isa StructArray
338+
@test s.a isa OffsetArray
339+
@test s.b isa OffsetArray
340+
341+
s = similar(t, ComplexF64, 10)
342+
@test s isa StructArray{ComplexF64, 1, NamedTuple{(:re, :im), Tuple{Vector{Float64}, Vector{Float64}}}}
343+
@test size(s) == (10,)
344+
345+
s = similar(t, ComplexF64, 0:9)
346+
VectorType = OffsetVector{Float64, Vector{Float64}}
347+
@test s isa StructArray{ComplexF64, 1, NamedTuple{(:re, :im), Tuple{VectorType, VectorType}}}
348+
@test axes(s) == (0:9,)
349+
350+
s = similar(t, Float32, 2, 2)
351+
@test s isa Matrix{Float32}
352+
@test size(s) == (2, 2)
353+
354+
s = similar(t, Float32, 0:1, 2)
355+
@test s isa OffsetMatrix{Float32, Matrix{Float32}}
356+
@test axes(s) == (0:1, 1:2)
357+
end
358+
359+
@testset "similar type" begin
360+
t = StructArray(a = rand(10), b = rand(10))
361+
T = typeof(t)
362+
s = similar(T, 3)
363+
@test typeof(s) == typeof(t)
364+
@test size(s) == (3,)
365+
366+
s = similar(T, 0:2)
367+
@test axes(s) == (0:2,)
368+
@test s isa StructArray{NamedTuple{(:a, :b), Tuple{Float64, Float64}}}
369+
VectorType = OffsetVector{Float64, Vector{Float64}}
370+
@test s.a isa VectorType
371+
@test s.b isa VectorType
328372
end
329373

330374
@testset "empty" begin
@@ -803,6 +847,10 @@ end
803847
rs = reshape(s, (2, 2))
804848
@test rs.a == [1 3; 2 4]
805849
@test rs.b == ["a" "c"; "b" "d"]
850+
851+
rs = reshape(s, (0:1, :))
852+
@test rs.a == OffsetArray([1 3; 2 4], (-1, 0))
853+
@test rs.b == OffsetArray(["a" "c"; "b" "d"], (-1, 0))
806854
end
807855

808856
@testset "lazy" begin
@@ -1091,3 +1139,16 @@ end
10911139
C = map(zero, NamedTuple{(:a, :b, :c)}(map(zero, fieldtypes(types))))
10921140
@test A === C
10931141
end
1142+
1143+
@testset "OffsetArray zero" begin
1144+
s = StructArray{ComplexF64}((rand(2), rand(2)))
1145+
soff = OffsetArray(s, 0:1)
1146+
@test isa(parent(zero(soff)), StructArray)
1147+
end
1148+
1149+
# issue #230
1150+
@testset "StaticArray zero" begin
1151+
u = StructArray([SVector(1.0)])
1152+
@test zero(u) == StructArray([SVector(0.0)])
1153+
@test typeof(zero(u)) == typeof(StructArray([SVector(0.0)]))
1154+
end

0 commit comments

Comments
 (0)