@@ -898,17 +898,25 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
898
898
@test t. b. d isa Array
899
899
end
900
900
901
- struct MyArray{T,N} <: AbstractArray{T,N}
902
- A:: Array{T,N}
903
- end
904
- MyArray {T} (:: UndefInitializer , sz:: Dims ) where T = MyArray (Array {T} (undef, sz))
905
- Base. IndexStyle (:: Type{<:MyArray} ) = IndexLinear ()
906
- Base. getindex (A:: MyArray , i:: Int ) = A. A[i]
907
- Base. setindex! (A:: MyArray , val, i:: Int ) = A. A[i] = val
908
- Base. size (A:: MyArray ) = Base. size (A. A)
909
- Base. BroadcastStyle (:: Type{<:MyArray} ) = Broadcast. ArrayStyle {MyArray} ()
910
- Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}} , :: Type{ElType} ) where ElType =
911
- MyArray {ElType} (undef, size (bc))
901
+ for S in (1 , 2 , 3 )
902
+ MyArray = Symbol (:MyArray , S)
903
+ @eval begin
904
+ struct $ MyArray{T,N} <: AbstractArray{T,N}
905
+ A:: Array{T,N}
906
+ end
907
+ $ MyArray {T} (:: UndefInitializer , sz:: Dims ) where T = $ MyArray (Array {T} (undef, sz))
908
+ Base. IndexStyle (:: Type{<:$MyArray} ) = IndexLinear ()
909
+ Base. getindex (A:: $MyArray , i:: Int ) = A. A[i]
910
+ Base. setindex! (A:: $MyArray , val, i:: Int ) = A. A[i] = val
911
+ Base. size (A:: $MyArray ) = Base. size (A. A)
912
+ Base. BroadcastStyle (:: Type{<:$MyArray} ) = Broadcast. ArrayStyle {$MyArray} ()
913
+ end
914
+ end
915
+ Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}} , :: Type{ElType} ) where ElType =
916
+ MyArray1 {ElType} (undef, size (bc))
917
+ Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}} , :: Type{ElType} ) where ElType =
918
+ MyArray2 {ElType} (undef, size (bc))
919
+ Base. BroadcastStyle (:: Broadcast.ArrayStyle{MyArray1} , :: Broadcast.ArrayStyle{MyArray3} ) = Broadcast. ArrayStyle {MyArray1} ()
912
920
913
921
@testset " broadcast" begin
914
922
s = StructArray {ComplexF64} ((rand (2 ,2 ), rand (2 ,2 )))
@@ -926,24 +934,59 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
926
934
# used inside of broadcast but we also test it here explicitly
927
935
@test isa (@inferred (Base. dataids (s)), NTuple{N, UInt} where {N})
928
936
929
- s = StructArray {ComplexF64} ((MyArray (rand (2 )), MyArray (rand (2 ))))
930
- @test_throws MethodError s .+ s
937
+ # Make sure we can handle style with similar defined
938
+ # s1 and s2 has similar defined, but s3 not
939
+ # s2 are conflict with s1 and s3.
940
+ s1 = StructArray {ComplexF64} ((MyArray1 (rand (2 )), MyArray1 (rand (2 ))))
941
+ s2 = StructArray {ComplexF64} ((MyArray2 (rand (2 )), MyArray2 (rand (2 ))))
942
+ s3 = StructArray {ComplexF64} ((MyArray3 (rand (2 )), MyArray3 (rand (2 ))))
943
+
944
+ function _test_similar (a, b)
945
+ flag = false
946
+ try
947
+ c = StructArray {ComplexF64} ((a. re .+ b. re, a. im .+ b. im))
948
+ flag = true
949
+ catch
950
+ end
951
+ if flag
952
+ @test typeof (@inferred (a .+ b)) == typeof (c)
953
+ else
954
+ @test_throws MethodError a .+ b
955
+ end
956
+ end
957
+ for s in (s1,s2,s3), s′ in (s1,s2,s3)
958
+ _test_similar (s, s′)
959
+ end
931
960
932
961
# test for dimensionality track
962
+ s = s1
933
963
@test Base. broadcasted (+ , s, s) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{1} }
934
- @test Base. broadcasted (+ , s, [1 ,2 ]) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{1} }
964
+ @test Base. broadcasted (+ , s, [1 ,2 ]) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{1} }
935
965
@test Base. broadcasted (+ , s, [1 ;;2 ]) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{2} }
936
966
@test Base. broadcasted (+ , [1 ;;;2 ], s) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{3} }
937
-
938
- a = StructArray ([1 ;2 + im])
939
- b = StructArray ([1 ;;2 + im])
940
- @test a .+ b == a .+ collect (b) == collect (a) .+ b == collect (a) .+ collect (b)
967
+ @test Base. broadcasted (+ , s, MyArray1 (rand (2 ))) isa Broadcast. Broadcasted{<: Broadcast.AbstractArrayStyle{Any} }
941
968
942
969
# issue #185
943
970
A = StructArray (randn (ComplexF64, 3 , 3 ))
944
971
B = randn (ComplexF64, 3 , 3 )
945
972
c = StructArray (randn (ComplexF64, 3 ))
946
973
@test (A .= B .* c) === A
974
+
975
+ # ambiguity check (can we do this better?)
976
+ function _test (a, b)
977
+ if a isa StructArray || b isa StructArray
978
+ d = @inferred a .+ b
979
+ @test d == collect (a) .+ collect (b)
980
+ @test d isa StructArray
981
+ end
982
+ end
983
+ testset = StructArray ([1 ;2 + im]), StructArray ([1 2 + im]), 1 : 2 , (1 ,2 ), (@SArray [1 2 ])
984
+ for aa in testset, bb in testset
985
+ _test (aa, bb)
986
+ end
987
+ a = StructArray ([1 ;2 + im])
988
+ b = StructArray ([1 2 + im])
989
+ @test @inferred (a .+ b .+ a .* a' .+ (1 ,2 ) .+ (1 : 2 ) .- b' ) isa StructArray
947
990
end
948
991
949
992
@testset " staticarrays" begin
0 commit comments