Skip to content

Commit ce8f61b

Browse files
alxmrstomwhite
andauthored
Using default_dtypes instead of hard-coding dtypes. (#666)
* First pass for using `default-dtypes` instead of hard-coding. Locally, most of the tests pass. This adds default dtypes to newly created arrays. * Updated several creation functions. All tests passing. * Fixed comment/added type update to creation_functions.py. * Great suggestion from @tomwhite, removed unecessary complexity from a creation function and fixed something broken in my code. * Making a utility function private, moving it to a better place. * Oops: didn't delete import. * Passing lint (via pre-commit). * Using default_dtypes in statistical_functions.py. Have a few questions. * Intermediate values use default dtypes in statistical_functions.py. * Move private function to the bottom of the file. * Revert "Intermediate values use default dtypes in statistical_functions.py." This reverts commit 6a2f60f. * More accurate, non-jax specific comment. * Better default dtype for unsigned ints, thanks Claude. * Omitting unnecessary function. * Extracted dtype validate method into dtypes.py. Cross applied to nan_functions.py. * Copy/editing Co-authored-by: Tom White <tom.e.white@gmail.com> * Implemented feedback from @tomwhite. --------- Co-authored-by: Tom White <tom.e.white@gmail.com>
1 parent 72a05b4 commit ce8f61b

File tree

4 files changed

+71
-69
lines changed

4 files changed

+71
-69
lines changed

cubed/array_api/creation_functions.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22
from typing import TYPE_CHECKING, Iterable, List
33

4+
from cubed.array_api import __array_namespace_info__
45
from cubed.backend_array_api import namespace as nxp
56
from cubed.core import Plan, gensym
67
from cubed.core.ops import map_blocks
@@ -25,6 +26,7 @@ def arange(
2526
num = int(max(math.ceil((stop - start) / step), 0))
2627
if dtype is None:
2728
dtype = nxp.arange(start, stop, step * num if num else step).dtype
29+
2830
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
2931
chunksize = chunks[0][0]
3032

@@ -62,8 +64,8 @@ def asarray(
6264
): # pragma: no cover
6365
return asarray(a.data)
6466
elif not isinstance(getattr(a, "shape", None), Iterable):
65-
# ensure blocks are arrays
6667
a = nxp.asarray(a, dtype=dtype)
68+
6769
if dtype is None:
6870
dtype = a.dtype
6971

@@ -89,8 +91,9 @@ def empty_like(x, /, *, dtype=None, device=None, chunks=None, spec=None) -> "Arr
8991
def empty_virtual_array(
9092
shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True
9193
) -> "Array":
94+
dtypes = __array_namespace_info__().default_dtypes(device=device)
9295
if dtype is None:
93-
dtype = nxp.float64
96+
dtype = dtypes["real floating"]
9497

9598
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
9699
name = gensym()
@@ -105,10 +108,11 @@ def empty_virtual_array(
105108
def eye(
106109
n_rows, n_cols=None, /, *, k=0, dtype=None, device=None, chunks="auto", spec=None
107110
) -> "Array":
111+
dtypes = __array_namespace_info__().default_dtypes(device=device)
108112
if n_cols is None:
109113
n_cols = n_rows
110114
if dtype is None:
111-
dtype = nxp.float64
115+
dtype = dtypes["real floating"]
112116

113117
shape = (n_rows, n_cols)
114118
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
@@ -136,17 +140,18 @@ def _eye(x, k=None, chunksize=None, block_id=None):
136140
def full(
137141
shape, fill_value, *, dtype=None, device=None, chunks="auto", spec=None
138142
) -> "Array":
143+
dtypes = __array_namespace_info__().default_dtypes(device=device)
139144
shape = normalize_shape(shape)
140145
if dtype is None:
141146
# check bool first since True/False are instances of int and float
142147
if isinstance(fill_value, bool):
143148
dtype = nxp.bool
144149
elif isinstance(fill_value, int):
145-
dtype = nxp.int64
150+
dtype = dtypes["integral"]
146151
elif isinstance(fill_value, float):
147-
dtype = nxp.float64
152+
dtype = dtypes["real floating"]
148153
elif isinstance(fill_value, complex):
149-
dtype = nxp.complex128
154+
dtype = dtypes["complex floating"]
150155
else:
151156
raise TypeError("Invalid input to full")
152157
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
@@ -187,13 +192,15 @@ def linspace(
187192
chunks="auto",
188193
spec=None,
189194
) -> "Array":
195+
dtypes = __array_namespace_info__().default_dtypes(device=device)
196+
190197
range_ = stop - start
191198
div = (num - 1) if endpoint else num
192199
if div == 0:
193200
div = 1
194201
step = float(range_) / div
195202
if dtype is None:
196-
dtype = nxp.float64
203+
dtype = dtypes["real floating"]
197204
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
198205
chunksize = chunks[0][0]
199206

@@ -210,15 +217,23 @@ def linspace(
210217
step=step,
211218
endpoint=endpoint,
212219
linspace_dtype=dtype,
220+
device=device,
213221
)
214222

215223

216-
def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None):
224+
def _linspace(
225+
x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None
226+
):
227+
dtypes = __array_namespace_info__().default_dtypes(device=device)
228+
217229
bs = x.shape[0]
218230
i = block_id[0]
219231
adjusted_bs = bs - 1 if endpoint else bs
220-
blockstart = start + (i * size * step)
221-
blockstop = blockstart + (adjusted_bs * step)
232+
233+
# float_ is a type casting function.
234+
float_ = dtypes["real floating"].type
235+
blockstart = float_(start + (i * size * step))
236+
blockstop = float_(blockstart + float_(adjusted_bs * step))
222237
return nxp.linspace(
223238
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
224239
)
@@ -256,8 +271,10 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]:
256271

257272

258273
def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
274+
dtypes = __array_namespace_info__().default_dtypes(device=device)
275+
259276
if dtype is None:
260-
dtype = nxp.float64
277+
dtype = dtypes["real floating"]
261278
return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec)
262279

263280

@@ -302,8 +319,10 @@ def _tri_mask(N, M, k, chunks, spec):
302319

303320

304321
def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
322+
dtypes = __array_namespace_info__().default_dtypes(device=device)
323+
305324
if dtype is None:
306-
dtype = nxp.float64
325+
dtype = dtypes["real floating"]
307326
return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec)
308327

309328

cubed/array_api/dtypes.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copied from numpy.array_api
2+
from cubed.array_api.inspection import __array_namespace_info__
23
from cubed.backend_array_api import namespace as nxp
34

45
int8 = nxp.int8
@@ -86,3 +87,32 @@
8687
"complex floating-point": _complex_floating_dtypes,
8788
"floating-point": _floating_dtypes,
8889
}
90+
91+
92+
# A Cubed-specific utility.
93+
def _upcast_integral_dtypes(x, dtype=None, *, allowed_dtypes=("numeric",), fname=None, device=None):
94+
"""Ensure the input dtype is allowed. If it's None, provide a good default dtype."""
95+
dtypes = __array_namespace_info__().default_dtypes(device=device)
96+
97+
# Validate.
98+
is_invalid = all(x.dtype not in _dtype_categories[a] for a in allowed_dtypes)
99+
if is_invalid:
100+
errmsg = f"Only {' or '.join(allowed_dtypes)} dtypes are allowed"
101+
if fname:
102+
errmsg += f" in {fname}"
103+
raise TypeError(errmsg)
104+
105+
# Choose a good default dtype, when None
106+
if dtype is None:
107+
if x.dtype in _boolean_dtypes:
108+
dtype = dtypes["integral"]
109+
elif x.dtype in _signed_integer_dtypes:
110+
dtype = dtypes["integral"]
111+
elif x.dtype in _unsigned_integer_dtypes:
112+
# Type arithmetic to produce an unsigned integer dtype at the same default precision.
113+
default_bits = nxp.iinfo(dtypes["integral"]).bits
114+
dtype = nxp.dtype(f"u{default_bits // 8}")
115+
else:
116+
dtype = x.dtype
117+
118+
return dtype

cubed/array_api/statistical_functions.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
import math
22

33
from cubed.array_api.dtypes import (
4-
_boolean_dtypes,
5-
_numeric_dtypes,
64
_real_floating_dtypes,
75
_real_numeric_dtypes,
8-
_signed_integer_dtypes,
9-
_unsigned_integer_dtypes,
10-
int64,
11-
uint64,
6+
_upcast_integral_dtypes,
127
)
138
from cubed.array_api.elementwise_functions import sqrt
149
from cubed.backend_array_api import namespace as nxp
@@ -35,6 +30,7 @@ def mean(x, /, *, axis=None, keepdims=False, split_every=None):
3530
# pair of fields needed to keep per-chunk counts and totals for computing
3631
# the mean.
3732
dtype = x.dtype
33+
#TODO(#658): Should these be default dtypes?
3834
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
3935
extra_func_kwargs = dict(dtype=intermediate_dtype)
4036
return reduction(
@@ -113,19 +109,8 @@ def min(x, /, *, axis=None, keepdims=False, split_every=None):
113109
)
114110

115111

116-
def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
117-
# boolean is allowed by numpy
118-
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
119-
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
120-
if dtype is None:
121-
if x.dtype in _boolean_dtypes:
122-
dtype = int64
123-
elif x.dtype in _signed_integer_dtypes:
124-
dtype = int64
125-
elif x.dtype in _unsigned_integer_dtypes:
126-
dtype = uint64
127-
else:
128-
dtype = x.dtype
112+
def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
113+
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="prod", device=device)
129114
extra_func_kwargs = dict(dtype=dtype)
130115
return reduction(
131116
x,
@@ -150,19 +135,8 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):
150135
)
151136

152137

153-
def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
154-
# boolean is allowed by numpy
155-
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
156-
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
157-
if dtype is None:
158-
if x.dtype in _boolean_dtypes:
159-
dtype = int64
160-
elif x.dtype in _signed_integer_dtypes:
161-
dtype = int64
162-
elif x.dtype in _unsigned_integer_dtypes:
163-
dtype = uint64
164-
else:
165-
dtype = x.dtype
138+
def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
139+
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="sum", device=device)
166140
extra_func_kwargs = dict(dtype=dtype)
167141
return reduction(
168142
x,
@@ -189,6 +163,7 @@ def var(
189163
if x.dtype not in _real_floating_dtypes:
190164
raise TypeError("Only real floating-point dtypes are allowed in var")
191165
dtype = x.dtype
166+
#TODO(#658): Should these be default dtypes?
192167
intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
193168
extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction)
194169
return reduction(

cubed/nan_functions.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,6 @@
11
import numpy as np
22

3-
from cubed.array_api.dtypes import (
4-
_numeric_dtypes,
5-
_signed_integer_dtypes,
6-
_unsigned_integer_dtypes,
7-
complex64,
8-
complex128,
9-
float32,
10-
float64,
11-
int64,
12-
uint64,
13-
)
3+
from cubed.array_api.dtypes import _upcast_integral_dtypes
144
from cubed.backend_array_api import namespace as nxp
155
from cubed.core import reduction
166

@@ -60,21 +50,9 @@ def _nannumel(x, **kwargs):
6050
return nxp.sum(~(nxp.isnan(x)), **kwargs)
6151

6252

63-
def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
53+
def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
6454
"""Return the sum of array elements over a given axis treating NaNs as zero."""
65-
if x.dtype not in _numeric_dtypes:
66-
raise TypeError("Only numeric dtypes are allowed in nansum")
67-
if dtype is None:
68-
if x.dtype in _signed_integer_dtypes:
69-
dtype = int64
70-
elif x.dtype in _unsigned_integer_dtypes:
71-
dtype = uint64
72-
elif x.dtype == float32:
73-
dtype = float64
74-
elif x.dtype == complex64:
75-
dtype = complex128
76-
else:
77-
dtype = x.dtype
55+
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric",), fname="nansum", device=device)
7856
return reduction(
7957
x,
8058
nxp.nansum,

0 commit comments

Comments
 (0)