1
+ using OrdinaryDiffEq, RecursiveArrayTools, LinearAlgebra
2
+
1
3
struct NoIndexArray{T, N} <: AbstractArray{T, N}
2
4
x:: Array{T, N}
3
5
end
@@ -46,11 +48,95 @@ function Base.show(io::IO, ::MIME"text/plain", x::NoIndexArray)
46
48
Base. print_array (io, x. x)
47
49
end
48
50
49
- using OrdinaryDiffEq
50
51
prob = ODEProblem ((du, u, p, t) -> copyto! (du, u), NoIndexArray (ones (10 , 10 )), (0.0 , 10.0 ))
51
52
algs = [Tsit5 (), BS3 (), Vern9 (), DP5 ()]
52
53
for alg in algs
53
54
sol = @test_nowarn solve (prob, alg)
54
55
@test_nowarn sol (0.1 )
55
56
@test_nowarn sol (similar (prob. u0), 0.1 )
56
57
end
58
+
59
+
60
+ struct CustomArray{T, N}
61
+ x:: Array{T, N}
62
+ end
63
+ Base. size (x:: CustomArray ) = size (x. x)
64
+ Base. axes (x:: CustomArray ) = axes (x. x)
65
+ Base. ndims (x:: CustomArray ) = ndims (x. x)
66
+ Base. ndims (:: Type{<:CustomArray{T,N}} ) where {T,N} = N
67
+ Base. zero (x:: CustomArray ) = CustomArray (zero (x. x))
68
+ Base. zero (:: Type{<:CustomArray{T,N}} ) where {T,N} = CustomArray (zero (Array{T,N}))
69
+ Base. similar (x:: CustomArray , dims:: Union{Integer, AbstractUnitRange} ...) = CustomArray (similar (x. x, dims... ))
70
+ Base. copyto! (x:: CustomArray , y:: CustomArray ) = CustomArray (copyto! (x. x, y. x))
71
+ Base. copy (x:: CustomArray ) = CustomArray (copy (x. x))
72
+ Base. length (x:: CustomArray ) = length (x. x)
73
+ Base. isempty (x:: CustomArray ) = isempty (x. x)
74
+ Base. eltype (x:: CustomArray ) = eltype (x. x)
75
+ Base. zero (x:: CustomArray ) = CustomArray (zero (x. x))
76
+ Base. fill! (x:: CustomArray , y) = CustomArray (fill! (x. x, y))
77
+ Base. getindex (x:: CustomArray , i) = getindex (x. x, i)
78
+ Base. setindex! (x:: CustomArray , v, idx) = setindex! (x. x, v, idx)
79
+ Base. eachindex (x:: CustomArray ) = eachindex (x. x)
80
+ Base. mapreduce (f, op, x:: CustomArray ; kwargs... ) = mapreduce (f, op, x. x; kwargs... )
81
+ Base. any (f:: Function , x:: CustomArray ; kwargs... ) = any (f, x. x; kwargs... )
82
+ Base. all (f:: Function , x:: CustomArray ; kwargs... ) = all (f, x. x; kwargs... )
83
+ Base. similar (x:: CustomArray , t) = CustomArray (similar (x. x, t))
84
+ Base.:(== )(x:: CustomArray , y:: CustomArray ) = x. x == y. x
85
+ Base.:(* )(x:: Number , y:: CustomArray ) = CustomArray (x* y. x)
86
+ Base.:(/ )(x:: CustomArray , y:: Number ) = CustomArray (x. x/ y)
87
+ LinearAlgebra. norm (x:: CustomArray ) = norm (x. x)
88
+
89
+ struct CustomStyle{N} <: Broadcast.BroadcastStyle where {N} end
90
+ CustomStyle (:: Val{N} ) where N = CustomStyle {N} ()
91
+ CustomStyle {M} (:: Val{N} ) where {N,M} = NoIndexStyle {N} ()
92
+ Base. BroadcastStyle (:: Type{<:CustomArray{T,N}} ) where {T,N} = CustomStyle {N} ()
93
+ Broadcast. BroadcastStyle (:: CustomStyle{N} , :: Broadcast.DefaultArrayStyle{0} ) where {N} = CustomStyle {N} ()
94
+ Base. similar (bc:: Base.Broadcast.Broadcasted{CustomStyle{N}} , :: Type{ElType} ) where {N, ElType} = CustomArray (similar (Array{ElType, N}, axes (bc)))
95
+ Base. Broadcast. _broadcast_getindex (x:: CustomArray , i) = x. x[i]
96
+ Base. Broadcast. extrude (x:: CustomArray ) = x
97
+ Base. Broadcast. broadcastable (x:: CustomArray ) = x
98
+
99
+ @inline function Base. copyto! (dest:: CustomArray , bc:: Base.Broadcast.Broadcasted{<:CustomStyle} )
100
+ axes (dest) == axes (bc) || throwdm (axes (dest), axes (bc))
101
+ bc′ = Base. Broadcast. preprocess (dest, bc)
102
+ dest′ = dest. x
103
+ @simd for I in 1 : length (dest′)
104
+ @inbounds dest′[I] = bc′[I]
105
+ end
106
+ return dest
107
+ end
108
+ @inline function Base. copy (bc:: Base.Broadcast.Broadcasted{<:CustomStyle} )
109
+ bcf = Broadcast. flatten (bc)
110
+ x = find_x (bcf)
111
+ data = zeros (eltype (x), size (x))
112
+ @inbounds @simd for I in 1 : length (x)
113
+ data[I] = bcf[I]
114
+ end
115
+ return CustomArray (data)
116
+ end
117
+ find_x (bc:: Broadcast.Broadcasted ) = find_x (bc. args)
118
+ find_x (args:: Tuple ) = find_x (find_x (args[1 ]), Base. tail (args))
119
+ find_x (x) = x
120
+ find_x (:: Any , rest) = find_x (rest)
121
+ find_x (x:: CustomArray , rest) = x. x
122
+
123
+ RecursiveArrayTools. recursive_unitless_bottom_eltype (x:: CustomArray ) = eltype (x)
124
+ RecursiveArrayTools. recursivecopy! (dest:: CustomArray , src:: CustomArray ) = copyto! (dest, src)
125
+ RecursiveArrayTools. recursivecopy (x:: CustomArray ) = copy (x)
126
+ RecursiveArrayTools. recursivefill! (x:: CustomArray , a) = fill! (x, a)
127
+
128
+ Base. show_vector (io:: IO , x:: CustomArray ) = Base. show_vector (io, x. x)
129
+
130
+ Base. show (io:: IO , x:: CustomArray ) = (print (io, " CustomArray" );show (io, x. x))
131
+ function Base. show (io:: IO , :: MIME"text/plain" , x:: CustomArray )
132
+ println (io, Base. summary (x), " :" )
133
+ Base. print_array (io, x. x)
134
+ end
135
+
136
+ prob = ODEProblem ((du, u, p, t) -> copyto! (du, u), CustomArray (ones (10 )), (0.0 , 10.0 ))
137
+ algs = [Tsit5 (), BS3 (), Vern9 (), DP5 ()]
138
+ for alg in algs
139
+ sol = @test_nowarn solve (prob, alg)
140
+ @test_nowarn sol (0.1 )
141
+ @test_nowarn sol (similar (prob. u0), 0.1 )
142
+ end
0 commit comments