-
Notifications
You must be signed in to change notification settings - Fork 42
Generalize StructArray
's broadcast.
#215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
fef311e
a78cab2
a6fe8a5
c165c61
e711ebe
8c83220
10e6442
960e1c7
14c7a84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings | |
using TypedTables: Table | ||
using DataAPI: refarray, refvalue | ||
using Adapt: adapt, Adapt | ||
using JLArrays | ||
using Test | ||
|
||
using Documenter: doctest | ||
|
@@ -1100,17 +1101,39 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) | |
@test t.b.d isa Array | ||
end | ||
|
||
struct MyArray{T,N} <: AbstractArray{T,N} | ||
A::Array{T,N} | ||
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s. | ||
# 1. `MyArray1` and `MyArray1` have `similar` defined. | ||
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`. | ||
# 2. `MyArray3` has no `similar` defined. | ||
# We use it to simulate `BroadcastStyle` overloading `Base.copy`. | ||
# 3. Their resolved style could be summaryized as (`-` means conflict) | ||
# | MyArray1 | MyArray2 | MyArray3 | Array | ||
# ------------------------------------------------------------- | ||
# MyArray1 | MyArray1 | - | MyArray1 | MyArray1 | ||
# MyArray2 | - | MyArray2 | - | MyArray2 | ||
# MyArray3 | MyArray1 | - | MyArray3 | MyArray3 | ||
# Array | MyArray1 | Array | MyArray3 | Array | ||
|
||
for S in (1, 2, 3) | ||
N5N3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
MyArray = Symbol(:MyArray, S) | ||
@eval begin | ||
struct $MyArray{T,N} <: AbstractArray{T,N} | ||
A::Array{T,N} | ||
end | ||
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz)) | ||
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear() | ||
Base.getindex(A::$MyArray, i::Int) = A.A[i] | ||
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val | ||
Base.size(A::$MyArray) = Base.size(A.A) | ||
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}() | ||
end | ||
end | ||
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz)) | ||
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear() | ||
Base.getindex(A::MyArray, i::Int) = A.A[i] | ||
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val | ||
Base.size(A::MyArray) = Base.size(A.A) | ||
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}() | ||
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType = | ||
MyArray{ElType}(undef, size(bc)) | ||
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType = | ||
MyArray1{ElType}(undef, size(bc)) | ||
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType = | ||
MyArray2{ElType}(undef, size(bc)) | ||
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}() | ||
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S | ||
|
||
@testset "broadcast" begin | ||
s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) | ||
|
@@ -1128,19 +1151,34 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El | |
# used inside of broadcast but we also test it here explicitly | ||
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) | ||
|
||
s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2)))) | ||
@test_throws MethodError s .+ s | ||
# Make sure we can handle style with similar defined | ||
# And we can handle most conflicts | ||
# `s1` and `s2` have similar defined, but `s3` does not | ||
# `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle` | ||
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) | ||
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) | ||
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) | ||
s4 = StructArray{ComplexF64}((rand(2), rand(2))) | ||
|
||
function _test_similar(a, b, c) | ||
try | ||
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im)) | ||
@test typeof(a .+ b .- c) == typeof(d) | ||
catch | ||
@test_throws MethodError a .+ b .- c | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This had escaped me before, but I'm wondering: could it be possible to be explicit here on which it is (correct result or method error) based on the input types? Ideally one would want to explicitly test what one is getting, so I would suggest to remove the helper function if s2 in (s, s′, s″) && (s1 in (s, s′, s″) || s3 in (s, s′, s″))
# test method error
else
# test correct result
end in the loop body. (I'm not sure whether that's the correct criterion.) |
||
end | ||
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4) | ||
_test_similar(s, s′, s″) | ||
end | ||
|
||
# test for dimensionality track | ||
s = s1 | ||
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} | ||
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} | ||
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} | ||
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} | ||
|
||
a = StructArray([1;2+im]) | ||
b = StructArray([1;;2+im]) | ||
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b) | ||
@test a .+ Any[1] isa StructArray | ||
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} | ||
|
||
# issue #185 | ||
A = StructArray(randn(ComplexF64, 3, 3)) | ||
|
@@ -1155,6 +1193,53 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El | |
|
||
@test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)] | ||
@test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3] | ||
@test identity.(StructArray(x=StructArray(x=StructArray(a=1:3))))::StructArray == [(x=(x=(a=1,),),), (x=(x=(a=2,),),), (x=(x=(a=3,),),)] | ||
@test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3] | ||
|
||
@testset "ambiguity check" begin | ||
function _test(a, b, c) | ||
N5N3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if a isa StructArray || b isa StructArray || c isa StructArray | ||
d = @inferred a .+ b .- c | ||
@test d == collect(a) .+ collect(b) .- collect(c) | ||
@test d isa StructArray | ||
end | ||
end | ||
testset = Any[StructArray([1;2+im]), | ||
1:2, | ||
(1,2), | ||
StructArray(@SArray [1 1+2im]), | ||
(@SArray [1 2]) | ||
] | ||
for aa in testset, bb in testset, cc in testset | ||
N5N3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_test(aa, bb, cc) | ||
end | ||
end | ||
|
||
@testset "StructStaticArray" begin | ||
bclog(s) = log.(s) | ||
test_allocated(f, s) = @test (@allocated f(s)) == 0 | ||
a = @SMatrix [float(i) for i in 1:10, j in 1:10] | ||
b = @SMatrix [0. for i in 1:10, j in 1:10] | ||
s = StructArray{ComplexF64}((a , b)) | ||
@test (@inferred bclog(s)) isa typeof(s) | ||
test_allocated(bclog, s) | ||
N5N3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix | ||
bc = Base.broadcasted(+, s, s); | ||
bc = Base.broadcasted(+, bc, bc, s); | ||
@test @inferred(Broadcast.axes(bc)) === axes(s) | ||
end | ||
|
||
@testset "StructJLArray" begin | ||
bcabs(a) = abs.(a) | ||
bcmul2(a) = 2 .* a | ||
a = StructArray(randn(ComplexF32, 10, 10)) | ||
sa = jl(a) | ||
backend = StructArrays.GPUArraysCore.backend | ||
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im) | ||
@test collect(@inferred(bcabs(sa))) == bcabs(a) | ||
@test @inferred(bcmul2(sa)) isa StructArray | ||
@test (sa .+= 1) isa StructArray | ||
end | ||
end | ||
|
||
@testset "map" begin | ||
|
Uh oh!
There was an error while loading. Please reload this page.