@@ -50,6 +50,7 @@ identical memory layout to a Julia `Array` of the same size.
50
50
`st` should be the stride(s) *in bytes* between elements in each dimension
51
51
"""
52
52
function f_contiguous (:: Type{T} , sz:: NTuple{N,Int} , st:: NTuple{N,Int} ) where {T,N}
53
+ N == 0 && return true # 0-dimensional arrays have 1 element, always contiguous
53
54
if st[1 ] != sizeof (T)
54
55
# not contiguous
55
56
return false
@@ -153,77 +154,72 @@ function copy(a::PyArray{T,N}) where {T,N}
153
154
return A
154
155
end
155
156
156
- # TODO : need to do bounds-checking of these indices!
157
- # TODO : need to GC root these `a`s to guard against the PyArray getting gc'd,
158
- # e.g. if it's a temporary in a function:
159
- # `two_rands() = pycall(np.rand, PyArray, 10)[1:2]`
160
-
161
-
162
- getindex (a:: PyArray{T,0} ) where {T} = unsafe_load (a. data)
163
- getindex (a:: PyArray{T,1} , i:: Integer ) where {T} = unsafe_load (a. data, 1 + (i- 1 )* a. st[1 ])
157
+ unsafe_data_load (a:: PyArray , i:: Integer ) = GC. @preserve a unsafe_load (a. data, i)
158
+
159
+ @inline data_index (a:: PyArray{<:Any,N} , i:: CartesianIndex{N} ) where {N} =
160
+ 1 + sum (ntuple (dim -> (i[dim]- 1 ) * a. st[dim], Val {N} ())) # Val lets julia unroll/inline
161
+ data_index (a:: PyArray{<:Any,0} , i:: CartesianIndex{0} ) = 1
162
+
163
+ # handle passing fewer/more indices than dimensions by canonicalizing to M==N
164
+ @inline function fixindex (a:: PyArray{<:Any,N} , i:: CartesianIndex{M} ) where {M,N}
165
+ if M == N
166
+ return i
167
+ elseif M < N
168
+ @boundscheck (all (ntuple (k -> size (a,k+ M)== 1 , Val {N-M} ())) ||
169
+ throw (BoundsError (a, i))) # trailing sizes must == 1
170
+ return CartesianIndex (Tuple (i)... , ntuple (k -> 1 , Val {N-M} ())... )
171
+ else # M > N
172
+ @boundscheck (all (ntuple (k -> i[k+ N]== 1 , Val {M-N} ())) ||
173
+ throw (BoundsError (a, i))) # trailing indices must == 1
174
+ return CartesianIndex (ntuple (k -> i[k], Val {N} ()))
175
+ end
176
+ end
164
177
165
- getindex (a:: PyArray{T,2} , i:: Integer , j:: Integer ) where {T} =
166
- unsafe_load (a. data, 1 + (i- 1 )* a. st[1 ] + (j- 1 )* a. st[2 ])
178
+ @inline function getindex (a:: PyArray , i:: CartesianIndex )
179
+ j = fixindex (a, i)
180
+ @boundscheck checkbounds (a, j)
181
+ unsafe_data_load (a, data_index (a, j))
182
+ end
183
+ @inline getindex (a:: PyArray , i:: Integer... ) = a[CartesianIndex (i)]
184
+ @inline getindex (a:: PyArray{<:Any,1} , i:: Integer ) = a[CartesianIndex (i)]
167
185
186
+ # linear indexing
168
187
function getindex (a:: PyArray , i:: Integer )
188
+ @boundscheck checkbounds (a, i)
169
189
if a. f_contig
170
- return unsafe_load (a . data , i)
190
+ return unsafe_data_load (a , i)
171
191
else
172
- return a[ind2sub (a . dims, i) ... ]
192
+ @inbounds return a[CartesianIndices (a)[i] ]
173
193
end
174
194
end
175
195
176
- function getindex (a:: PyArray , is:: Integer... )
177
- index = 1
178
- n = min (length (is),length (a. st))
179
- for i = 1 : n
180
- index += (is[i]- 1 )* a. st[i]
181
- end
182
- for i = n+ 1 : length (is)
183
- if is[i] != 1
184
- throw (BoundsError ())
185
- end
186
- end
187
- unsafe_load (a. data, index)
188
- end
189
-
190
196
function writeok_assign (a:: PyArray , v, i:: Integer )
191
197
if a. info. readonly
192
198
throw (ArgumentError (" read-only PyArray" ))
193
199
else
194
- unsafe_store! (a. data, v, i)
200
+ GC . @preserve a unsafe_store! (a. data, v, i)
195
201
end
196
- return a
202
+ return v
197
203
end
198
204
199
- setindex! (a:: PyArray{T,0} , v) where {T} = writeok_assign (a, v, 1 )
200
- setindex! (a:: PyArray{T,1} , v, i:: Integer ) where {T} = writeok_assign (a, v, 1 + (i- 1 )* a. st[1 ])
201
-
202
- setindex! (a:: PyArray{T,2} , v, i:: Integer , j:: Integer ) where {T} =
203
- writeok_assign (a, v, 1 + (i- 1 )* a. st[1 ] + (j- 1 )* a. st[2 ])
205
+ @inline function setindex! (a:: PyArray , v, i:: CartesianIndex )
206
+ j = fixindex (a, i)
207
+ @boundscheck checkbounds (a, j)
208
+ writeok_assign (a, v, data_index (a, j))
209
+ end
210
+ @inline setindex! (a:: PyArray , v, i:: Integer... ) = setindex! (a, v, CartesianIndex (i))
211
+ @inline setindex! (a:: PyArray{<:Any,1} , v, i:: Integer ) = setindex! (a, v, CartesianIndex (i))
204
212
213
+ # linear indexing
205
214
function setindex! (a:: PyArray , v, i:: Integer )
215
+ @boundscheck checkbounds (a, i)
206
216
if a. f_contig
207
217
return writeok_assign (a, v, i)
208
218
else
209
- return setindex! (a, v, ind2sub (a . dims, i) ... )
219
+ @inbounds return setindex! (a, v, CartesianIndices (a)[i] )
210
220
end
211
221
end
212
222
213
- function setindex! (a:: PyArray , v, is:: Integer... )
214
- index = 1
215
- n = min (length (is),length (a. st))
216
- for i = 1 : n
217
- index += (is[i]- 1 )* a. st[i]
218
- end
219
- for i = n+ 1 : length (is)
220
- if is[i] != 1
221
- throw (BoundsError ())
222
- end
223
- end
224
- writeok_assign (a, v, index)
225
- end
226
-
227
223
stride (a:: PyArray , i:: Integer ) = a. st[i]
228
224
229
225
Base. unsafe_convert (:: Type{Ptr{T}} , a:: PyArray{T} ) where {T} = a. data
@@ -244,68 +240,56 @@ summary(a::PyArray{T}) where {T} = string(Base.dims2string(size(a)), " ",
244
240
# ########################################################################
245
241
# PyArray <-> PyObject conversions
246
242
247
- const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr}
243
+ const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr,PyObject }
248
244
249
245
PyObject (a:: PyArray ) = a. o
250
246
251
247
convert (:: Type{PyArray} , o:: PyObject ) = PyArray (o)
252
248
249
+ # PyObject arrays are created by taking a NumPy array of PyPtr and converting
250
+ pyo2ptr (T:: Type ) = T
251
+ pyo2ptr (:: Type{PyObject} ) = PyPtr
252
+ pyocopy (a) = copy (a)
253
+ pyocopy (a:: AbstractArray{PyPtr} ) = GC. @preserve a map (pyincref, a)
254
+
253
255
function convert (:: Type{Array{T, 1}} , o:: PyObject ) where T<: PYARR_TYPES
254
256
try
255
- copy (PyArray {T , 1} (o, PyArray_Info (o))) # will check T and N vs. info
257
+ return pyocopy (PyArray {pyo2ptr(T) , 1} (o, PyArray_Info (o))) # will check T and N vs. info
256
258
catch
257
- len = @pycheckz ccall ((@pysym :PySequence_Size ), Int, (PyPtr,), o)
258
- A = Array {pyany_toany(T)} (undef, len)
259
- py2array (T, A, o, 1 , 1 )
259
+ return py2vector (T, o)
260
260
end
261
261
end
262
262
263
263
function convert (:: Type{Array{T}} , o:: PyObject ) where T<: PYARR_TYPES
264
264
try
265
265
info = PyArray_Info (o)
266
266
try
267
- copy (PyArray {T , length(info.sz)} (o, info)) # will check T == eltype(info)
267
+ return pyocopy (PyArray {pyo2ptr(T) , length(info.sz)} (o, info)) # will check T == eltype(info)
268
268
catch
269
- return py2array (T, Array {pyany_toany(T) } (undef, info. sz... ), o, 1 , 1 )
269
+ return py2array (T, Array {T } (undef, info. sz... ), o, 1 , 1 )
270
270
end
271
271
catch
272
- py2array (T, o)
272
+ return py2array (T, o)
273
273
end
274
274
end
275
275
276
276
function convert (:: Type{Array{T,N}} , o:: PyObject ) where {T<: PYARR_TYPES ,N}
277
277
try
278
278
info = PyArray_Info (o)
279
279
try
280
- copy (PyArray {T ,N} (o, info)) # will check T,N == eltype(info),ndims(info)
280
+ pyocopy (PyArray {pyo2ptr(T) ,N} (o, info)) # will check T,N == eltype(info),ndims(info)
281
281
catch
282
282
nd = length (info. sz)
283
- if nd != N
284
- throw (ArgumentError (" cannot convert $(nd) d array to $(N) d" ))
285
- end
286
- return py2array (T, Array {pyany_toany(T)} (undef, info. sz... ), o, 1 , 1 )
283
+ nd == N || throw (ArgumentError (" cannot convert $(nd) d array to $(N) d" ))
284
+ return py2array (T, Array {T} (undef, info. sz... ), o, 1 , 1 )
287
285
end
288
286
catch
289
287
A = py2array (T, o)
290
- if ndims (A) != N
291
- throw (ArgumentError (" cannot convert $(ndims (A)) d array to $(N) d" ))
292
- end
293
- A
288
+ ndims (A) == N || throw (ArgumentError (" cannot convert $(ndims (A)) d array to $(N) d" ))
289
+ return A
294
290
end
295
291
end
296
292
297
- function convert (:: Type{Array{PyObject}} , o:: PyObject )
298
- map (pyincref, convert (Array{PyPtr}, o))
299
- end
300
-
301
- function convert (:: Type{Array{PyObject,1}} , o:: PyObject )
302
- map (pyincref, convert (Array{PyPtr, 1 }, o))
303
- end
304
-
305
- function convert (:: Type{Array{PyObject,N}} , o:: PyObject ) where N
306
- map (pyincref, convert (Array{PyPtr, N}, o))
307
- end
308
-
309
293
array_format (o:: PyObject ) = array_format (PyBuffer (o, PyBUF_ND_STRIDED))
310
294
311
295
"""
0 commit comments