Skip to content

Commit 5e67179

Browse files
Merge pull request #2368 from apkille/master
Solving on custom array types that are not `AbstractArray`s
2 parents c5c0150 + 61a207c commit 5e67179

File tree

1 file changed

+87
-1
lines changed

1 file changed

+87
-1
lines changed

test/interface/noindex_tests.jl

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using OrdinaryDiffEq, RecursiveArrayTools, LinearAlgebra
2+
13
struct NoIndexArray{T, N} <: AbstractArray{T, N}
24
x::Array{T, N}
35
end
@@ -46,11 +48,95 @@ function Base.show(io::IO, ::MIME"text/plain", x::NoIndexArray)
4648
Base.print_array(io, x.x)
4749
end
4850

49-
using OrdinaryDiffEq
5051
prob = ODEProblem((du, u, p, t) -> copyto!(du, u), NoIndexArray(ones(10, 10)), (0.0, 10.0))
5152
algs = [Tsit5(), BS3(), Vern9(), DP5()]
5253
for alg in algs
5354
sol = @test_nowarn solve(prob, alg)
5455
@test_nowarn sol(0.1)
5556
@test_nowarn sol(similar(prob.u0), 0.1)
5657
end
58+
59+
60+
struct CustomArray{T, N}
61+
x::Array{T, N}
62+
end
63+
Base.size(x::CustomArray) = size(x.x)
64+
Base.axes(x::CustomArray) = axes(x.x)
65+
Base.ndims(x::CustomArray) = ndims(x.x)
66+
Base.ndims(::Type{<:CustomArray{T,N}}) where {T,N} = N
67+
Base.zero(x::CustomArray) = CustomArray(zero(x.x))
68+
Base.zero(::Type{<:CustomArray{T,N}}) where {T,N} = CustomArray(zero(Array{T,N}))
69+
Base.similar(x::CustomArray, dims::Union{Integer, AbstractUnitRange}...) = CustomArray(similar(x.x, dims...))
70+
Base.copyto!(x::CustomArray, y::CustomArray) = CustomArray(copyto!(x.x, y.x))
71+
Base.copy(x::CustomArray) = CustomArray(copy(x.x))
72+
Base.length(x::CustomArray) = length(x.x)
73+
Base.isempty(x::CustomArray) = isempty(x.x)
74+
Base.eltype(x::CustomArray) = eltype(x.x)
75+
Base.zero(x::CustomArray) = CustomArray(zero(x.x))
76+
Base.fill!(x::CustomArray, y) = CustomArray(fill!(x.x, y))
77+
Base.getindex(x::CustomArray, i) = getindex(x.x, i)
78+
Base.setindex!(x::CustomArray, v, idx) = setindex!(x.x, v, idx)
79+
Base.eachindex(x::CustomArray) = eachindex(x.x)
80+
Base.mapreduce(f, op, x::CustomArray; kwargs...) = mapreduce(f, op, x.x; kwargs...)
81+
Base.any(f::Function, x::CustomArray; kwargs...) = any(f, x.x; kwargs...)
82+
Base.all(f::Function, x::CustomArray; kwargs...) = all(f, x.x; kwargs...)
83+
Base.similar(x::CustomArray, t) = CustomArray(similar(x.x, t))
84+
Base.:(==)(x::CustomArray, y::CustomArray) = x.x == y.x
85+
Base.:(*)(x::Number, y::CustomArray) = CustomArray(x*y.x)
86+
Base.:(/)(x::CustomArray, y::Number) = CustomArray(x.x/y)
87+
LinearAlgebra.norm(x::CustomArray) = norm(x.x)
88+
89+
struct CustomStyle{N} <: Broadcast.BroadcastStyle where {N} end
90+
CustomStyle(::Val{N}) where N = CustomStyle{N}()
91+
CustomStyle{M}(::Val{N}) where {N,M} = NoIndexStyle{N}()
92+
Base.BroadcastStyle(::Type{<:CustomArray{T,N}}) where {T,N} = CustomStyle{N}()
93+
Broadcast.BroadcastStyle(::CustomStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where {N} = CustomStyle{N}()
94+
Base.similar(bc::Base.Broadcast.Broadcasted{CustomStyle{N}}, ::Type{ElType}) where {N, ElType} = CustomArray(similar(Array{ElType, N}, axes(bc)))
95+
Base.Broadcast._broadcast_getindex(x::CustomArray, i) = x.x[i]
96+
Base.Broadcast.extrude(x::CustomArray) = x
97+
Base.Broadcast.broadcastable(x::CustomArray) = x
98+
99+
@inline function Base.copyto!(dest::CustomArray, bc::Base.Broadcast.Broadcasted{<:CustomStyle})
100+
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
101+
bc′ = Base.Broadcast.preprocess(dest, bc)
102+
dest′ = dest.x
103+
@simd for I in 1:length(dest′)
104+
@inbounds dest′[I] = bc′[I]
105+
end
106+
return dest
107+
end
108+
@inline function Base.copy(bc::Base.Broadcast.Broadcasted{<:CustomStyle})
109+
bcf = Broadcast.flatten(bc)
110+
x = find_x(bcf)
111+
data = zeros(eltype(x), size(x))
112+
@inbounds @simd for I in 1:length(x)
113+
data[I] = bcf[I]
114+
end
115+
return CustomArray(data)
116+
end
117+
find_x(bc::Broadcast.Broadcasted) = find_x(bc.args)
118+
find_x(args::Tuple) = find_x(find_x(args[1]), Base.tail(args))
119+
find_x(x) = x
120+
find_x(::Any, rest) = find_x(rest)
121+
find_x(x::CustomArray, rest) = x.x
122+
123+
RecursiveArrayTools.recursive_unitless_bottom_eltype(x::CustomArray) = eltype(x)
124+
RecursiveArrayTools.recursivecopy!(dest::CustomArray, src::CustomArray) = copyto!(dest, src)
125+
RecursiveArrayTools.recursivecopy(x::CustomArray) = copy(x)
126+
RecursiveArrayTools.recursivefill!(x::CustomArray, a) = fill!(x, a)
127+
128+
Base.show_vector(io::IO, x::CustomArray) = Base.show_vector(io, x.x)
129+
130+
Base.show(io::IO, x::CustomArray) = (print(io, "CustomArray");show(io, x.x))
131+
function Base.show(io::IO, ::MIME"text/plain", x::CustomArray)
132+
println(io, Base.summary(x), ":")
133+
Base.print_array(io, x.x)
134+
end
135+
136+
prob = ODEProblem((du, u, p, t) -> copyto!(du, u), CustomArray(ones(10)), (0.0, 10.0))
137+
algs = [Tsit5(), BS3(), Vern9(), DP5()]
138+
for alg in algs
139+
sol = @test_nowarn solve(prob, alg)
140+
@test_nowarn sol(0.1)
141+
@test_nowarn sol(similar(prob.u0), 0.1)
142+
end

0 commit comments

Comments
 (0)