|
| 1 | +import ctypes |
| 2 | +from functools import lru_cache |
| 3 | +from typing import Tuple |
| 4 | + |
| 5 | +import numba |
| 6 | +import numpy as np |
| 7 | +from numba import types |
| 8 | +from numba.core import cgutils |
| 9 | +from numba.core.extending import ( |
| 10 | + models, |
| 11 | + register_model, |
| 12 | +) |
| 13 | +from numba.core.typing.templates import AttributeTemplate |
| 14 | +from numba.cuda.cudadecl import registry as cuda_registry |
| 15 | +from numba.cuda.cudaimpl import registry as cuda_lower_registry |
| 16 | + |
| 17 | +from . import _iterators |
| 18 | + |
| 19 | + |
@lru_cache
def make_iterator_struct_class(ndim):
    """Return the ctypes structure class for a strided view with ``ndim`` axes.

    The structure holds, in order: ``linear_id`` (int64), ``ptr`` (void
    pointer), ``shape`` and ``strides`` (each an array of ``ndim`` int64
    values) and ``ndim`` (int32).

    The result is memoized by ``lru_cache``, so repeated calls with the same
    ``ndim`` return the very same class object.
    """
    # ``shape`` and ``strides`` share one array type: ``ndim`` int64 entries.
    extent_array_t = ctypes.c_int64 * ndim

    class StridedArrayView(ctypes.Structure):
        _fields_ = [
            ("linear_id", ctypes.c_int64),
            ("ptr", ctypes.c_void_p),
            ("shape", extent_array_t),
            ("strides", extent_array_t),
            ("ndim", ctypes.c_int32),
        ]

    return StridedArrayView
| 35 | + |
| 36 | + |
def iterator_struct_ctype(
    ptr: int, ndim: int, shape: Tuple[int, ...], strides: Tuple[int, ...]
):
    """Build the host-side ctypes struct holding the iterator state.

    Args:
        ptr: Address of the first element, stored in the ``ptr`` field.
        ndim: Number of dimensions of the view.
        shape: Extent of every dimension; must have exactly ``ndim`` entries.
        strides: Stride of every dimension; must have exactly ``ndim`` entries.

    Returns:
        An instance of the ``StridedArrayView`` structure produced by
        ``make_iterator_struct_class(ndim)`` with ``linear_id`` set to 0.

    Raises:
        ValueError: If ``shape`` or ``strides`` does not have ``ndim`` entries.
    """
    shape = tuple(shape)
    strides = tuple(strides)
    # A ctypes array constructor silently zero-fills missing initializers,
    # so a too-short shape/strides would yield a corrupt view instead of an
    # error. Check the lengths explicitly.
    if len(shape) != ndim or len(strides) != ndim:
        raise ValueError(
            f"shape and strides must each have ndim={ndim} entries, "
            f"got {len(shape)} and {len(strides)}"
        )

    StridedArrayView_cls = make_iterator_struct_class(ndim)

    extent_array_t = ctypes.c_int64 * ndim
    c_shape = extent_array_t(*shape)
    c_strides = extent_array_t(*strides)
    return StridedArrayView_cls(0, ptr, c_shape, c_strides, ndim)
| 43 | + |
| 44 | + |
@lru_cache
def strided_view_iterator_numba_type(value_type: types.Type, ndim: int):
    """Create the numba struct type describing a strided view of an nd-array.

    Each uncached call defines a new ``NdArrayViewType`` class and returns its
    single instance. The struct has the members ``linear_id`` (int64), ``ptr``
    (``CPointer(value_type)``), ``shape`` and ``strides`` (``ndim``-long int64
    tuples) and ``ndim`` (int32), in the same order as the fields of the
    ``StridedArrayView`` ctypes class built by ``make_iterator_struct_class``,
    which is what the host uses to access the struct.

    Besides returning the type, the call registers with the CUDA typing and
    lowering registries:

    * attribute typing for the struct and for a pointer to the struct,
    * a ``StructModel`` data model for the struct,
    * lowering of attribute reads on the struct, and of reading/writing
      ``linear_id`` through a pointer to the struct.

    ``lru_cache`` keeps these registrations from being repeated for a
    ``(value_type, ndim)`` pair that is still in the cache.
    """
    # ------
    # Typing
    # ------

    # View into strided device ndarray
    # NOTE(review): the type name is the constant "NdArrayView" for every
    # (value_type, ndim); distinct calls are told apart only by being distinct
    # classes. Confirm nothing downstream relies on the name being unique.
    class NdArrayViewType(types.Type):
        def __init__(self):
            super(NdArrayViewType, self).__init__(name="NdArrayView")

    ndarray_view_type = NdArrayViewType()
    ptr_type = types.CPointer(ndarray_view_type)

    int64_numba_t = numba.from_dtype(np.int64)
    shape_arr_numba_t = types.UniTuple(int64_numba_t, ndim)
    strides_arr_numba_t = types.UniTuple(int64_numba_t, ndim)
    # Member order mirrors the ``_fields_`` of the host-side ctypes struct.
    ndarray_view_members = [
        ("linear_id", int64_numba_t),
        ("ptr", types.CPointer(value_type)),
        ("shape", shape_arr_numba_t),
        ("strides", strides_arr_numba_t),
        ("ndim", numba.from_dtype(np.int32)),
    ]

    # Typing for accessing attributes of the struct members
    class NdArrayViewAttrsTemplate(AttributeTemplate):
        pass

    def make_attr_resolver(ty):
        """
        Function to capture a copy of **ty** argument in resolve function
        """

        # A factory is used so each resolver closes over its own ``ty``
        # instead of the loop variable of the ``for`` loop below.
        def resolve_fn(self, pp):
            return ty

        return resolve_fn

    # Attach one ``resolve_<member>`` method per struct member.
    for name, typ in ndarray_view_members:
        setattr(NdArrayViewAttrsTemplate, f"resolve_{name}", make_attr_resolver(typ))

    @cuda_registry.register_attr
    class NdArrayViewAttrs(NdArrayViewAttrsTemplate):
        key = ndarray_view_type

    # Only ``linear_id`` is typed for access through a pointer to the struct.
    @cuda_registry.register_attr
    class PtrAttrs(AttributeTemplate):
        key = ptr_type

        def resolve_linear_id(self, pp):
            return int64_numba_t

    # -----------
    # Data models
    # -----------

    @register_model(NdArrayViewType)
    class NdArrayViewModel(models.StructModel):
        def __init__(self, dmm, fe_type):
            super().__init__(dmm, fe_type, ndarray_view_members)

    # --------
    # Lowering
    # --------

    # Generic attribute read on a struct value: wrap the value in a struct
    # proxy and load the member named ``attr``.
    @cuda_lower_registry.lower_getattr_generic(ndarray_view_type)
    def ndarray_view_getattr_lowering_fn(context, builder, sig, arg, attr):
        struct_values = cgutils.create_struct_proxy(ndarray_view_type)(
            context, builder, value=arg
        )
        attr_ptr = struct_values._get_ptr_by_name(attr)
        attr_val = builder.load(attr_ptr)
        return attr_val

    # ``ptr.linear_id = value``: load the whole struct through the pointer
    # (args[0]), replace ``linear_id`` with args[1], and store the struct back.
    @cuda_lower_registry.lower_setattr(ptr_type, "linear_id")
    def ndarray_view_pointer_set_linear_id(context, builder, sig, args):
        data = builder.load(args[0])
        values = cgutils.create_struct_proxy(ndarray_view_type)(
            context, builder, value=data
        )
        setattr(values, "linear_id", args[1])
        return builder.store(values._getvalue(), args[0])

    # ``ptr.linear_id`` read: load the struct through the pointer and return
    # its ``linear_id`` member.
    @cuda_lower_registry.lower_getattr(ptr_type, "linear_id")
    def ndarray_view_pointer_get_linear_id(context, builder, sig, arg):
        data = builder.load(arg)
        values = cgutils.create_struct_proxy(ndarray_view_type)(
            context, builder, value=data
        )
        attr_ptr = values._get_ptr_by_name("linear_id")
        attr_val = builder.load(attr_ptr)
        return attr_val

    return ndarray_view_type
| 147 | + |
| 148 | + |
class NdArrayIteratorKind(_iterators.IteratorKind):
    """Iterator kind used as ``NdArrayIterator.iterator_kind_type``.

    Adds nothing to ``_iterators.IteratorKind``; it exists only as a distinct
    subclass.
    """
| 151 | + |
| 152 | + |
class NdArrayIterator(_iterators.IteratorBase):
    """Iterator over the elements of a strided n-dimensional array view.

    The iterator state is a ``StridedArrayView`` ctypes struct holding a
    running ``linear_id`` together with the base pointer, shape and strides
    of the view. ``dereference`` converts ``linear_id`` into an element
    offset using the shape and strides.
    """

    iterator_kind_type = NdArrayIteratorKind

    def __init__(
        self,
        ptr: int,
        value_type: types.Type,
        shape: Tuple[int, ...],
        strides: Tuple[int, ...],
    ):
        """Create the iterator state and pass it to the base class.

        Args:
            ptr: Address of the first element of the array.
            value_type: Numba type of one array element.
            shape: Extent of every dimension.
            strides: Stride of every dimension, counted in elements because
                ``dereference`` uses them to index a typed pointer.

        Raises:
            ValueError: If ``shape`` and ``strides`` differ in length.
        """
        ndim = len(shape)
        if len(strides) != ndim:
            raise ValueError(
                "shape and strides must have the same length, "
                f"got {ndim} and {len(strides)}"
            )

        state_numba_type = strided_view_iterator_numba_type(value_type, ndim)
        numba_type = types.CPointer(state_numba_type)
        # build ctypes struct for state of iterator
        host_sav_cvalue = iterator_struct_ctype(ptr, ndim, shape, strides)
        super().__init__(
            cvalue=host_sav_cvalue,
            numba_type=numba_type,
            state_type=state_numba_type,
            value_type=value_type,
        )

    @staticmethod
    def advance(state_ref, distance):
        """Move the iterator by ``distance`` positions (updates ``linear_id``)."""
        state_ref.linear_id = state_ref.linear_id + distance

    @staticmethod
    def dereference(state_ref):
        """Return the element addressed by the current ``linear_id``.

        ``linear_id`` is split into per-axis indices with the last axis
        varying fastest, and each index is multiplied by the stride of its
        axis to form the offset from ``ptr``.
        """
        # NOTE(review): presumably compiled by numba with ``state_ref`` being
        # a pointer to the state struct; keep to constructs numba can type.
        state = state_ref[0]
        id_ = state.linear_id
        # init offset_ to zero of the same type as id_
        offset_ = id_ - id_
        ndim_ = state.ndim
        if ndim_ > 0:
            shape_ = state.shape
            strides_ = state.strides
            one_i32 = numba.int32(1)
            # Peel off axes ndim-1, ndim-2, ..., 1; axis 0 is handled after
            # the loop because it takes whatever quotient remains.
            for i in range(one_i32, ndim_):
                bi_ = ndim_ - i
                sh_i = shape_[bi_]
                if sh_i > 0:
                    q_ = id_ // sh_i
                    r_ = id_ - q_ * sh_i
                else:
                    # Non-positive extent: skip the division, index 0 on
                    # this axis.
                    q_ = id_
                    r_ = id_ - id_  # make zero of the right type
                offset_ = offset_ + r_ * strides_[bi_]
                id_ = q_
            zero_i32 = one_i32 - one_i32
            offset_ = offset_ + id_ * strides_[zero_i32]
        val = (state.ptr)[offset_]
        return val
| 204 | + |
| 205 | + |
def make_ndarray_iterator(array_like, perm):
    """Create an ``NdArrayIterator`` over ``array_like`` with permuted axes.

    Args:
        array_like: Object exposing ``data.ptr``, ``dtype``, ``shape``,
            ``strides`` (in bytes) and ``itemsize``. These look like the
            attributes of a CuPy array; confirm against the callers.
        perm: Iterable of axis indices. Entry ``k`` selects which axis of
            ``array_like`` becomes axis ``k`` of the iterator.

    Returns:
        An ``NdArrayIterator`` whose shape and element strides are those of
        ``array_like`` reordered by ``perm``.

    Raises:
        ValueError: If ``perm`` is empty, or if a selected byte stride is not
            a whole multiple of ``itemsize``.
    """
    ptr = array_like.data.ptr
    dt = array_like.dtype
    shape_ = array_like.shape
    strides_ = array_like.strides
    itemsize = array_like.itemsize

    axes = tuple(perm)
    if not axes:
        raise ValueError("perm must contain at least one axis index")

    perm_shape = tuple(shape_[idx] for idx in axes)
    byte_strides = tuple(strides_[idx] for idx in axes)
    # The iterator indexes a typed pointer, so strides must convert exactly
    # from bytes to elements. This is raised explicitly rather than asserted
    # so the check survives ``python -O``.
    if any(stride % itemsize != 0 for stride in byte_strides):
        raise ValueError(
            f"strides {byte_strides} must be multiples of itemsize {itemsize}"
        )
    perm_strides = tuple(stride // itemsize for stride in byte_strides)

    return NdArrayIterator(ptr, numba.from_dtype(dt), perm_shape, perm_strides)
0 commit comments