Skip to content

Commit 33f07eb

Browse files
authored
various fixes to array conversions (#668)
* 0-dimensional PyArray
* gc safety and CartesianIndex support in PyArray
* PyArray bounds checking
* rm redundant methods
* another simplification
* no sum for empty tuples
* fix docstring
* consolidation of Array{PyObject} conversion, add a missing GC root, pysequence check fix
1 parent 0cb5c45 commit 33f07eb

File tree

6 files changed

+121
-109
lines changed

6 files changed

+121
-109
lines changed

src/conversions.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ function pyarray_dims(o::PyObject, forcelist=true)
369369
return () # too many non-List types can pretend to be sequences
370370
end
371371
len = ccall((@pysym :PySequence_Size), Int, (PyPtr,), o)
372+
len < 0 && error("not a PySequence object")
372373
if len == 0
373374
return (0,)
374375
end
@@ -392,12 +393,18 @@ function pyarray_dims(o::PyObject, forcelist=true)
392393
end
393394

394395
function py2array(T, o::PyObject)
395-
dims = pyarray_dims(o)
396+
b = PyBuffer()
397+
if isbuftype!(o, b)
398+
dims = size(b)
399+
else
400+
dims = pyarray_dims(o)
401+
end
402+
pydecref(b) # safe for immediate release
396403
A = Array{pyany_toany(T)}(undef, dims)
397-
py2array(T, A, o, 1, 1)
404+
py2array(T, A, o, 1, 1) # fixme: faster conversion for supported buffer types?
398405
end
399406

400-
function convert(::Type{Vector{T}}, o::PyObject) where T
407+
function py2vector(T, o::PyObject)
401408
len = ccall((@pysym :PySequence_Size), Int, (PyPtr,), o)
402409
if len < 0 || # not a sequence
403410
len+1 < 0 # object pretending to be a sequence of infinite length
@@ -406,6 +413,7 @@ function convert(::Type{Vector{T}}, o::PyObject) where T
406413
end
407414
py2array(T, Array{pyany_toany(T)}(undef, len), o, 1, 1)
408415
end
416+
convert(::Type{Vector{T}}, o::PyObject) where T = py2vector(T, o)
409417

410418
convert(::Type{Array}, o::PyObject) = map(identity, py2array(PyAny, o))
411419
convert(::Type{Array{T}}, o::PyObject) where {T} = py2array(T, o)
@@ -800,8 +808,8 @@ function pytype_query(o::PyObject, default::TypeTuple=PyObject)
800808
@return_not_None pyfunction_query(o)
801809
@return_not_None pydate_query(o)
802810
@return_not_None pydict_query(o)
803-
@return_not_None pysequence_query(o)
804811
@return_not_None pyptr_query(o)
812+
@return_not_None pysequence_query(o)
805813
@return_not_None pynothing_query(o)
806814
@return_not_None pymp_query(o)
807815
for (py,jl) in pytype_queries

src/pyarray.jl

Lines changed: 61 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ identical memory layout to a Julia `Array` of the same size.
5050
`st` should be the stride(s) *in bytes* between elements in each dimension
5151
"""
5252
function f_contiguous(::Type{T}, sz::NTuple{N,Int}, st::NTuple{N,Int}) where {T,N}
53+
N == 0 && return true # 0-dimensional arrays have 1 element, always contiguous
5354
if st[1] != sizeof(T)
5455
# not contiguous
5556
return false
@@ -153,77 +154,72 @@ function copy(a::PyArray{T,N}) where {T,N}
153154
return A
154155
end
155156

156-
# TODO: need to do bounds-checking of these indices!
157-
# TODO: need to GC root these `a`s to guard against the PyArray getting gc'd,
158-
# e.g. if it's a temporary in a function:
159-
# `two_rands() = pycall(np.rand, PyArray, 10)[1:2]`
160-
161-
162-
getindex(a::PyArray{T,0}) where {T} = unsafe_load(a.data)
163-
getindex(a::PyArray{T,1}, i::Integer) where {T} = unsafe_load(a.data, 1 + (i-1)*a.st[1])
157+
unsafe_data_load(a::PyArray, i::Integer) = GC.@preserve a unsafe_load(a.data, i)
158+
159+
@inline data_index(a::PyArray{<:Any,N}, i::CartesianIndex{N}) where {N} =
160+
1 + sum(ntuple(dim -> (i[dim]-1) * a.st[dim], Val{N}())) # Val lets julia unroll/inline
161+
data_index(a::PyArray{<:Any,0}, i::CartesianIndex{0}) = 1
162+
163+
# handle passing fewer/more indices than dimensions by canonicalizing to M==N
164+
@inline function fixindex(a::PyArray{<:Any,N}, i::CartesianIndex{M}) where {M,N}
165+
if M == N
166+
return i
167+
elseif M < N
168+
@boundscheck(all(ntuple(k -> size(a,k+M)==1, Val{N-M}())) ||
169+
throw(BoundsError(a, i))) # trailing sizes must == 1
170+
return CartesianIndex(Tuple(i)..., ntuple(k -> 1, Val{N-M}())...)
171+
else # M > N
172+
@boundscheck(all(ntuple(k -> i[k+N]==1, Val{M-N}())) ||
173+
throw(BoundsError(a, i))) # trailing indices must == 1
174+
return CartesianIndex(ntuple(k -> i[k], Val{N}()))
175+
end
176+
end
164177

165-
getindex(a::PyArray{T,2}, i::Integer, j::Integer) where {T} =
166-
unsafe_load(a.data, 1 + (i-1)*a.st[1] + (j-1)*a.st[2])
178+
@inline function getindex(a::PyArray, i::CartesianIndex)
179+
j = fixindex(a, i)
180+
@boundscheck checkbounds(a, j)
181+
unsafe_data_load(a, data_index(a, j))
182+
end
183+
@inline getindex(a::PyArray, i::Integer...) = a[CartesianIndex(i)]
184+
@inline getindex(a::PyArray{<:Any,1}, i::Integer) = a[CartesianIndex(i)]
167185

186+
# linear indexing
168187
function getindex(a::PyArray, i::Integer)
188+
@boundscheck checkbounds(a, i)
169189
if a.f_contig
170-
return unsafe_load(a.data, i)
190+
return unsafe_data_load(a, i)
171191
else
172-
return a[ind2sub(a.dims, i)...]
192+
@inbounds return a[CartesianIndices(a)[i]]
173193
end
174194
end
175195

176-
function getindex(a::PyArray, is::Integer...)
177-
index = 1
178-
n = min(length(is),length(a.st))
179-
for i = 1:n
180-
index += (is[i]-1)*a.st[i]
181-
end
182-
for i = n+1:length(is)
183-
if is[i] != 1
184-
throw(BoundsError())
185-
end
186-
end
187-
unsafe_load(a.data, index)
188-
end
189-
190196
function writeok_assign(a::PyArray, v, i::Integer)
191197
if a.info.readonly
192198
throw(ArgumentError("read-only PyArray"))
193199
else
194-
unsafe_store!(a.data, v, i)
200+
GC.@preserve a unsafe_store!(a.data, v, i)
195201
end
196-
return a
202+
return v
197203
end
198204

199-
setindex!(a::PyArray{T,0}, v) where {T} = writeok_assign(a, v, 1)
200-
setindex!(a::PyArray{T,1}, v, i::Integer) where {T} = writeok_assign(a, v, 1 + (i-1)*a.st[1])
201-
202-
setindex!(a::PyArray{T,2}, v, i::Integer, j::Integer) where {T} =
203-
writeok_assign(a, v, 1 + (i-1)*a.st[1] + (j-1)*a.st[2])
205+
@inline function setindex!(a::PyArray, v, i::CartesianIndex)
206+
j = fixindex(a, i)
207+
@boundscheck checkbounds(a, j)
208+
writeok_assign(a, v, data_index(a, j))
209+
end
210+
@inline setindex!(a::PyArray, v, i::Integer...) = setindex!(a, v, CartesianIndex(i))
211+
@inline setindex!(a::PyArray{<:Any,1}, v, i::Integer) = setindex!(a, v, CartesianIndex(i))
204212

213+
# linear indexing
205214
function setindex!(a::PyArray, v, i::Integer)
215+
@boundscheck checkbounds(a, i)
206216
if a.f_contig
207217
return writeok_assign(a, v, i)
208218
else
209-
return setindex!(a, v, ind2sub(a.dims, i)...)
219+
@inbounds return setindex!(a, v, CartesianIndices(a)[i])
210220
end
211221
end
212222

213-
function setindex!(a::PyArray, v, is::Integer...)
214-
index = 1
215-
n = min(length(is),length(a.st))
216-
for i = 1:n
217-
index += (is[i]-1)*a.st[i]
218-
end
219-
for i = n+1:length(is)
220-
if is[i] != 1
221-
throw(BoundsError())
222-
end
223-
end
224-
writeok_assign(a, v, index)
225-
end
226-
227223
stride(a::PyArray, i::Integer) = a.st[i]
228224

229225
Base.unsafe_convert(::Type{Ptr{T}}, a::PyArray{T}) where {T} = a.data
@@ -244,68 +240,56 @@ summary(a::PyArray{T}) where {T} = string(Base.dims2string(size(a)), " ",
244240
#########################################################################
245241
# PyArray <-> PyObject conversions
246242

247-
const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr}
243+
const PYARR_TYPES = Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float16,Float32,Float64,ComplexF32,ComplexF64,PyPtr,PyObject}
248244

249245
PyObject(a::PyArray) = a.o
250246

251247
convert(::Type{PyArray}, o::PyObject) = PyArray(o)
252248

249+
# PyObject arrays are created by taking a NumPy array of PyPtr and converting each element with `pyincref`
250+
pyo2ptr(T::Type) = T
251+
pyo2ptr(::Type{PyObject}) = PyPtr
252+
pyocopy(a) = copy(a)
253+
pyocopy(a::AbstractArray{PyPtr}) = GC.@preserve a map(pyincref, a)
254+
253255
function convert(::Type{Array{T, 1}}, o::PyObject) where T<:PYARR_TYPES
254256
try
255-
copy(PyArray{T, 1}(o, PyArray_Info(o))) # will check T and N vs. info
257+
return pyocopy(PyArray{pyo2ptr(T), 1}(o, PyArray_Info(o))) # will check T and N vs. info
256258
catch
257-
len = @pycheckz ccall((@pysym :PySequence_Size), Int, (PyPtr,), o)
258-
A = Array{pyany_toany(T)}(undef, len)
259-
py2array(T, A, o, 1, 1)
259+
return py2vector(T, o)
260260
end
261261
end
262262

263263
function convert(::Type{Array{T}}, o::PyObject) where T<:PYARR_TYPES
264264
try
265265
info = PyArray_Info(o)
266266
try
267-
copy(PyArray{T, length(info.sz)}(o, info)) # will check T == eltype(info)
267+
return pyocopy(PyArray{pyo2ptr(T), length(info.sz)}(o, info)) # will check T == eltype(info)
268268
catch
269-
return py2array(T, Array{pyany_toany(T)}(undef, info.sz...), o, 1, 1)
269+
return py2array(T, Array{T}(undef, info.sz...), o, 1, 1)
270270
end
271271
catch
272-
py2array(T, o)
272+
return py2array(T, o)
273273
end
274274
end
275275

276276
function convert(::Type{Array{T,N}}, o::PyObject) where {T<:PYARR_TYPES,N}
277277
try
278278
info = PyArray_Info(o)
279279
try
280-
copy(PyArray{T,N}(o, info)) # will check T,N == eltype(info),ndims(info)
280+
pyocopy(PyArray{pyo2ptr(T),N}(o, info)) # will check T,N == eltype(info),ndims(info)
281281
catch
282282
nd = length(info.sz)
283-
if nd != N
284-
throw(ArgumentError("cannot convert $(nd)d array to $(N)d"))
285-
end
286-
return py2array(T, Array{pyany_toany(T)}(undef, info.sz...), o, 1, 1)
283+
nd == N || throw(ArgumentError("cannot convert $(nd)d array to $(N)d"))
284+
return py2array(T, Array{T}(undef, info.sz...), o, 1, 1)
287285
end
288286
catch
289287
A = py2array(T, o)
290-
if ndims(A) != N
291-
throw(ArgumentError("cannot convert $(ndims(A))d array to $(N)d"))
292-
end
293-
A
288+
ndims(A) == N || throw(ArgumentError("cannot convert $(ndims(A))d array to $(N)d"))
289+
return A
294290
end
295291
end
296292

297-
function convert(::Type{Array{PyObject}}, o::PyObject)
298-
map(pyincref, convert(Array{PyPtr}, o))
299-
end
300-
301-
function convert(::Type{Array{PyObject,1}}, o::PyObject)
302-
map(pyincref, convert(Array{PyPtr, 1}, o))
303-
end
304-
305-
function convert(::Type{Array{PyObject,N}}, o::PyObject) where N
306-
map(pyincref, convert(Array{PyPtr, N}, o))
307-
end
308-
309293
array_format(o::PyObject) = array_format(PyBuffer(o, PyBUF_ND_STRIDED))
310294

311295
"""

src/pybuffer.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ the python c-api function `PyObject_GetBuffer()`, unless o.obj is a PyPtr(C_NULL
4747
function pydecref(o::PyBuffer)
4848
# note that PyBuffer_Release sets o.obj to NULL, and
4949
# is a no-op if o.obj is already NULL
50-
# TODO change to `Ref{PyBuffer}` when 0.6 is dropped.
51-
_finalized[] || ccall(@pysym(:PyBuffer_Release), Cvoid, (Any,), o)
50+
_finalized[] || ccall(@pysym(:PyBuffer_Release), Cvoid, (Ref{PyBuffer},), o)
5251
o
5352
end
5453

@@ -96,10 +95,9 @@ end
9695
# Strides in bytes
9796
Base.strides(b::PyBuffer) = ((stride(b,i) for i in 1:b.buf.ndim)...,)
9897

99-
# TODO change to `Ref{PyBuffer}` when 0.6 is dropped.
10098
iscontiguous(b::PyBuffer) =
10199
1 == ccall((@pysym :PyBuffer_IsContiguous), Cint,
102-
(Any, Cchar), b, 'A')
100+
(Ref{PyBuffer}, Cchar), b, 'A')
103101

104102
#############################################################################
105103
# pybuffer constant values from Include/object.h
@@ -122,35 +120,33 @@ function PyBuffer(o::Union{PyObject,PyPtr}, flags=PyBUF_SIMPLE)
122120
end
123121

124122
function PyBuffer!(b::PyBuffer, o::Union{PyObject,PyPtr}, flags=PyBUF_SIMPLE)
125-
# TODO change to `Ref{PyBuffer}` when 0.6 is dropped.
126123
pydecref(b) # ensure b is properly released
127124
@pycheckz ccall((@pysym :PyObject_GetBuffer), Cint,
128-
(PyPtr, Any, Cint), o, b, flags)
125+
(PyPtr, Ref{PyBuffer}, Cint), o, b, flags)
129126
return b
130127
end
131128

132-
"""
133-
`isbuftype(o::Union{PyObject,PyPtr})`
134-
Returns true if the python object `o` supports the buffer protocol as a strided
135-
array. False if not.
136-
"""
137-
function isbuftype(o::Union{PyObject,PyPtr})
129+
# like isbuftype, but modifies caller's PyBuffer
130+
function isbuftype!(o::Union{PyObject,PyPtr}, b::PyBuffer)
138131
# PyObject_CheckBuffer is defined in a header file here: https://github.com/python/cpython/blob/ef5ce884a41c8553a7eff66ebace908c1dcc1f89/Include/abstract.h#L510
139132
# so we can't access it easily. It basically just checks if PyObject_GetBuffer exists
140133
# So we'll just try call PyObject_GetBuffer and check for success/failure
141-
b = PyBuffer()
142134
ret = ccall((@pysym :PyObject_GetBuffer), Cint,
143135
(PyPtr, Any, Cint), o, b, PyBUF_ND_STRIDED)
144136
if ret != 0
145137
pyerr_clear()
146-
else
147-
# handle pointer types
148-
T, native_byteorder = array_format(b)
149-
T <: Ptr && (ret = 1)
150138
end
151139
return ret == 0
152140
end
153141

142+
"""
143+
isbuftype(o::Union{PyObject,PyPtr})
144+
145+
Returns `true` if the python object `o` supports the buffer protocol as a strided
146+
array. `false` if not.
147+
"""
148+
isbuftype(o::Union{PyObject,PyPtr}) = isbuftype!(o, PyBuffer())
149+
154150
#############################################################################
155151

156152
# recursive function to write buffer dimension by dimension, starting at
@@ -195,7 +191,8 @@ end
195191
# ref: https://github.com/numpy/numpy/blob/v1.14.2/numpy/core/src/multiarray/buffer.c#L966
196192

197193
const standard_typestrs = Dict{String,DataType}(
198-
"?"=>Bool, "P"=>Ptr{Cvoid},
194+
"?"=>Bool,
195+
"P"=>Ptr{Cvoid}, "O"=>PyPtr,
199196
"b"=>Int8, "B"=>UInt8,
200197
"h"=>Int16, "H"=>UInt16,
201198
"i"=>Int32, "I"=>UInt32,
@@ -208,7 +205,8 @@ const standard_typestrs = Dict{String,DataType}(
208205
"Zf"=>ComplexF32, "Zd"=>ComplexF64)
209206

210207
const native_typestrs = Dict{String,DataType}(
211-
"?"=>Bool, "P"=>Ptr{Cvoid},
208+
"?"=>Bool,
209+
"P"=>Ptr{Cvoid}, "O"=>PyPtr,
212210
"b"=>Int8, "B"=>UInt8,
213211
"h"=>Cshort, "H"=>Cushort,
214212
"i"=>Cint, "I"=>Cuint,

src/pytype.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,7 @@ function PyTypeObject!(init::Function, t::PyTypeObject, name::AbstractString, ba
312312
if t.tp_new == C_NULL
313313
t.tp_new = @pyglobal :PyType_GenericNew
314314
end
315-
# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
316-
@pycheckz ccall((@pysym :PyType_Ready), Cint, (Any,), t)
315+
@pycheckz ccall((@pysym :PyType_Ready), Cint, (Ref{PyTypeObject},), t)
317316
ccall((@pysym :Py_IncRef), Cvoid, (Any,), t)
318317
return t
319318
end
@@ -414,8 +413,7 @@ const Py_TPFLAGS_HAVE_STACKLESS_EXTENSION = Ref(0x00000000)
414413
function pyjlwrap_type!(init::Function, to::PyTypeObject, name::AbstractString)
415414
sz = sizeof(Py_jlWrap) + sizeof(PyPtr) # must be > base type
416415
PyTypeObject!(to, name, sz) do t::PyTypeObject
417-
# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
418-
t.tp_base = ccall(:jl_value_ptr, Ptr{Cvoid}, (Any,), jlWrapType)
416+
t.tp_base = ccall(:jl_value_ptr, Ptr{Cvoid}, (Ref{PyTypeObject},), jlWrapType)
419417
ccall((@pysym :Py_IncRef), Cvoid, (Any,), jlWrapType)
420418
init(t)
421419
end
@@ -426,9 +424,8 @@ pyjlwrap_type(init::Function, name::AbstractString) =
426424

427425
# Given a jlwrap type, create a new instance (and save value for gc)
428426
function pyjlwrap_new(pyT::PyTypeObject, value::Any)
429-
# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
430427
o = PyObject(@pycheckn ccall((@pysym :_PyObject_New),
431-
PyPtr, (Any,), pyT))
428+
PyPtr, (Ref{PyTypeObject},), pyT))
432429
p = convert(Ptr{Ptr{Cvoid}}, PyPtr(o))
433430
if isimmutable(value)
434431
# It is undefined to call `pointer_from_objref` on immutable objects.
@@ -452,8 +449,7 @@ function pyjlwrap_new(x::Any)
452449
pyjlwrap_new(jlWrapType, x)
453450
end
454451

455-
# TODO change to `Ref{PyTypeObject}` when 0.6 is dropped.
456-
is_pyjlwrap(o::PyObject) = jlWrapType.tp_new != C_NULL && ccall((@pysym :PyObject_IsInstance), Cint, (PyPtr, Any), o, jlWrapType) == 1
452+
is_pyjlwrap(o::PyObject) = jlWrapType.tp_new != C_NULL && ccall((@pysym :PyObject_IsInstance), Cint, (PyPtr, Ref{PyTypeObject}), o, jlWrapType) == 1
457453

458454
################################################################
459455
# Fallback conversion: if we don't have a better conversion function,

0 commit comments

Comments
 (0)