Skip to content

Commit 7ca5ae9

Browse files
authored
Remove old reduction implementation (#589)
1 parent 2065794 commit 7ca5ae9

File tree

9 files changed: 23 additions and 173 deletions

cubed/array_api/linear_algebra_functions.py

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from cubed.core import blockwise, reduction, squeeze
1313

1414

15-
def matmul(x1, x2, /, use_new_impl=True, split_every=None):
15+
def matmul(x1, x2, /, split_every=None):
1616
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
1717
raise TypeError("Only numeric dtypes are allowed in matmul")
1818

@@ -51,9 +51,7 @@ def matmul(x1, x2, /, use_new_impl=True, split_every=None):
5151
dtype=dtype,
5252
)
5353

54-
out = _sum_wo_cat(
55-
out, axis=-2, dtype=dtype, use_new_impl=use_new_impl, split_every=split_every
56-
)
54+
out = _sum_wo_cat(out, axis=-2, dtype=dtype, split_every=split_every)
5755

5856
if x1_is_1d:
5957
out = squeeze(out, -2)
@@ -68,7 +66,7 @@ def _matmul(a, b):
6866
return chunk[..., nxp.newaxis, :]
6967

7068

71-
def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None):
69+
def _sum_wo_cat(a, axis=None, dtype=None, split_every=None):
7270
if a.shape[axis] == 1:
7371
return squeeze(a, axis)
7472

@@ -78,7 +76,6 @@ def _sum_wo_cat(a, axis=None, dtype=None, use_new_impl=True, split_every=None):
7876
_chunk_sum,
7977
axis=axis,
8078
dtype=dtype,
81-
use_new_impl=use_new_impl,
8279
split_every=split_every,
8380
extra_func_kwargs=extra_func_kwargs,
8481
)
@@ -99,7 +96,7 @@ def matrix_transpose(x, /):
9996
return permute_dims(x, axes)
10097

10198

102-
def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
99+
def tensordot(x1, x2, /, *, axes=2, split_every=None):
103100
from cubed.array_api.statistical_functions import sum
104101

105102
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
@@ -147,7 +144,6 @@ def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None):
147144
out,
148145
axis=x1_axes,
149146
dtype=dtype,
150-
use_new_impl=use_new_impl,
151147
split_every=split_every,
152148
)
153149

@@ -161,7 +157,7 @@ def _tensordot(a, b, axes):
161157
return x
162158

163159

164-
def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
160+
def vecdot(x1, x2, /, *, axis=-1, split_every=None):
165161
# based on the implementation in array-api-compat
166162
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
167163
raise TypeError("Only numeric dtypes are allowed in vecdot")
@@ -176,7 +172,6 @@ def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
176172
res = matmul(
177173
x1_[..., None, :],
178174
x2_[..., None],
179-
use_new_impl=use_new_impl,
180175
split_every=split_every,
181176
)
182177
return res[..., 0, 0]

cubed/array_api/searching_functions.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from cubed.core.ops import arg_reduction, elemwise
66

77

8-
def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
8+
def argmax(x, /, *, axis=None, keepdims=False, split_every=None):
99
if x.dtype not in _real_numeric_dtypes:
1010
raise TypeError("Only real numeric dtypes are allowed in argmax")
1111
if axis is None:
@@ -17,12 +17,11 @@ def argmax(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No
1717
nxp.argmax,
1818
axis=axis,
1919
keepdims=keepdims,
20-
use_new_impl=use_new_impl,
2120
split_every=split_every,
2221
)
2322

2423

25-
def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
24+
def argmin(x, /, *, axis=None, keepdims=False, split_every=None):
2625
if x.dtype not in _real_numeric_dtypes:
2726
raise TypeError("Only real numeric dtypes are allowed in argmin")
2827
if axis is None:
@@ -34,7 +33,6 @@ def argmin(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=No
3433
nxp.argmin,
3534
axis=axis,
3635
keepdims=keepdims,
37-
use_new_impl=use_new_impl,
3836
split_every=split_every,
3937
)
4038

cubed/array_api/statistical_functions.py

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -18,21 +18,20 @@
1818
from cubed.core import reduction
1919

2020

21-
def max(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
21+
def max(x, /, *, axis=None, keepdims=False, split_every=None):
2222
if x.dtype not in _real_numeric_dtypes:
2323
raise TypeError("Only real numeric dtypes are allowed in max")
2424
return reduction(
2525
x,
2626
nxp.max,
2727
axis=axis,
2828
dtype=x.dtype,
29-
use_new_impl=use_new_impl,
3029
split_every=split_every,
3130
keepdims=keepdims,
3231
)
3332

3433

35-
def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
34+
def mean(x, /, *, axis=None, keepdims=False, split_every=None):
3635
if x.dtype not in _real_floating_dtypes:
3736
raise TypeError("Only real floating-point dtypes are allowed in mean")
3837
# This implementation uses NumPy and Zarr's structured arrays to store a
@@ -53,7 +52,6 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None
5352
intermediate_dtype=intermediate_dtype,
5453
dtype=dtype,
5554
keepdims=keepdims,
56-
use_new_impl=use_new_impl,
5755
split_every=split_every,
5856
extra_func_kwargs=extra_func_kwargs,
5957
)
@@ -108,23 +106,20 @@ def _numel(x, **kwargs):
108106
return nxp.broadcast_to(nxp.asarray(prod, dtype=dtype), new_shape)
109107

110108

111-
def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
109+
def min(x, /, *, axis=None, keepdims=False, split_every=None):
112110
if x.dtype not in _real_numeric_dtypes:
113111
raise TypeError("Only real numeric dtypes are allowed in min")
114112
return reduction(
115113
x,
116114
nxp.min,
117115
axis=axis,
118116
dtype=x.dtype,
119-
use_new_impl=use_new_impl,
120117
split_every=split_every,
121118
keepdims=keepdims,
122119
)
123120

124121

125-
def prod(
126-
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
127-
):
122+
def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
128123
# boolean is allowed by numpy
129124
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
130125
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
@@ -148,15 +143,12 @@ def prod(
148143
axis=axis,
149144
dtype=dtype,
150145
keepdims=keepdims,
151-
use_new_impl=use_new_impl,
152146
split_every=split_every,
153147
extra_func_kwargs=extra_func_kwargs,
154148
)
155149

156150

157-
def sum(
158-
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
159-
):
151+
def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
160152
# boolean is allowed by numpy
161153
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
162154
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
@@ -180,7 +172,6 @@ def sum(
180172
axis=axis,
181173
dtype=dtype,
182174
keepdims=keepdims,
183-
use_new_impl=use_new_impl,
184175
split_every=split_every,
185176
extra_func_kwargs=extra_func_kwargs,
186177
)

cubed/array_api/utility_functions.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from cubed.core import reduction
44

55

6-
def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
6+
def all(x, /, *, axis=None, keepdims=False, split_every=None):
77
if x.size == 0:
88
return asarray(True, dtype=x.dtype)
99
return reduction(
@@ -12,12 +12,11 @@ def all(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
1212
axis=axis,
1313
dtype=bool,
1414
keepdims=keepdims,
15-
use_new_impl=use_new_impl,
1615
split_every=split_every,
1716
)
1817

1918

20-
def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None):
19+
def any(x, /, *, axis=None, keepdims=False, split_every=None):
2120
if x.size == 0:
2221
return asarray(False, dtype=x.dtype)
2322
return reduction(
@@ -26,6 +25,5 @@ def any(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
2625
axis=axis,
2726
dtype=bool,
2827
keepdims=keepdims,
29-
use_new_impl=use_new_impl,
3028
split_every=split_every,
3129
)

cubed/core/groupby.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
from cubed.array_api.manipulation_functions import broadcast_to, expand_dims
44
from cubed.backend_array_api import namespace as nxp
5-
from cubed.core.ops import map_blocks, map_direct, reduction_new
5+
from cubed.core.ops import map_blocks, map_direct, reduction
66
from cubed.utils import array_memory, get_item
77
from cubed.vendor.dask.array.core import normalize_chunks
88

@@ -105,7 +105,7 @@ def wrapper(a, by, **kwargs):
105105
out = expand_dims(out, axis=dummy_axis)
106106

107107
# then reduce across blocks
108-
return reduction_new(
108+
return reduction(
109109
out,
110110
func=None,
111111
combine_func=combine_func,

cubed/core/ops.py

Lines changed: 1 addition & 120 deletions
Original file line number | Diff line number | Diff line change
@@ -1056,122 +1056,6 @@ def key_function(out_key):
10561056

10571057

10581058
def reduction(
1059-
x: "Array",
1060-
func,
1061-
combine_func=None,
1062-
aggregate_func=None,
1063-
axis=None,
1064-
intermediate_dtype=None,
1065-
dtype=None,
1066-
keepdims=False,
1067-
use_new_impl=True,
1068-
split_every=None,
1069-
extra_func_kwargs=None,
1070-
) -> "Array":
1071-
"""Apply a function to reduce an array along one or more axes."""
1072-
if use_new_impl:
1073-
return reduction_new(
1074-
x,
1075-
func,
1076-
combine_func=combine_func,
1077-
aggregate_func=aggregate_func,
1078-
axis=axis,
1079-
intermediate_dtype=intermediate_dtype,
1080-
dtype=dtype,
1081-
keepdims=keepdims,
1082-
split_every=split_every,
1083-
extra_func_kwargs=extra_func_kwargs,
1084-
)
1085-
if combine_func is None:
1086-
combine_func = func
1087-
if axis is None:
1088-
axis = tuple(range(x.ndim))
1089-
if isinstance(axis, Integral):
1090-
axis = (axis,)
1091-
axis = validate_axis(axis, x.ndim)
1092-
if intermediate_dtype is None:
1093-
intermediate_dtype = dtype
1094-
1095-
inds = tuple(range(x.ndim))
1096-
1097-
result = x
1098-
allowed_mem = x.spec.allowed_mem
1099-
max_mem = allowed_mem - x.spec.reserved_mem
1100-
1101-
# reduce initial chunks
1102-
args = (result, inds)
1103-
adjust_chunks = {
1104-
i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks)
1105-
}
1106-
result = blockwise(
1107-
func,
1108-
inds,
1109-
*args,
1110-
axis=axis,
1111-
keepdims=True,
1112-
dtype=intermediate_dtype,
1113-
adjust_chunks=adjust_chunks,
1114-
extra_func_kwargs=extra_func_kwargs,
1115-
)
1116-
1117-
# merge/reduce along axis in multiple rounds until there's a single block in each reduction axis
1118-
while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis):
1119-
# merge along axis
1120-
target_chunks = list(result.chunksize)
1121-
chunk_mem = array_memory(intermediate_dtype, result.chunksize)
1122-
for i, s in enumerate(result.shape):
1123-
if i in axis:
1124-
assert result.chunksize[i] == 1 # result of reduction
1125-
if len(axis) > 1:
1126-
# multi-axis: don't exceed original chunksize in any reduction axis
1127-
# TODO: improve to use up to max_mem
1128-
target_chunks[i] = min(s, x.chunksize[i])
1129-
else:
1130-
# single axis: see how many result chunks fit in max_mem
1131-
# factor of 4 is memory for {compressed, uncompressed} x {input, output}
1132-
target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4)
1133-
if target_chunk_size <= 1:
1134-
raise ValueError(
1135-
f"Not enough memory for reduction. Increase allowed_mem ({allowed_mem}) or decrease chunk size"
1136-
)
1137-
target_chunks[i] = min(s, target_chunk_size)
1138-
_target_chunks = tuple(target_chunks)
1139-
result = merge_chunks(result, _target_chunks)
1140-
1141-
# reduce chunks (if any axis chunksize is > 1)
1142-
if any(s > 1 for i, s in enumerate(result.chunksize) if i in axis):
1143-
args = (result, inds)
1144-
adjust_chunks = {
1145-
i: (1,) * len(c) if i in axis else c
1146-
for i, c in enumerate(result.chunks)
1147-
}
1148-
result = blockwise(
1149-
combine_func,
1150-
inds,
1151-
*args,
1152-
axis=axis,
1153-
keepdims=True,
1154-
dtype=intermediate_dtype,
1155-
adjust_chunks=adjust_chunks,
1156-
extra_func_kwargs=extra_func_kwargs,
1157-
)
1158-
1159-
if aggregate_func is not None:
1160-
result = map_blocks(aggregate_func, result, dtype=dtype)
1161-
1162-
if not keepdims:
1163-
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
1164-
if len(axis_to_squeeze) > 0:
1165-
result = squeeze(result, axis_to_squeeze)
1166-
1167-
from cubed.array_api import astype
1168-
1169-
result = astype(result, dtype, copy=False)
1170-
1171-
return result
1172-
1173-
1174-
def reduction_new(
11751059
x: "Array",
11761060
func,
11771061
combine_func=None,
@@ -1426,9 +1310,7 @@ def _partial_reduce(arrays, reduce_func=None, initial_func=None, axis=None):
14261310
return result
14271311

14281312

1429-
def arg_reduction(
1430-
x, /, arg_func, axis=None, *, keepdims=False, use_new_impl=True, split_every=None
1431-
):
1313+
def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False, split_every=None):
14321314
"""A reduction that returns the array indexes, not the values."""
14331315
dtype = nxp.int64 # index data type
14341316
intermediate_dtype = [("i", dtype), ("v", x.dtype)]
@@ -1454,7 +1336,6 @@ def arg_reduction(
14541336
intermediate_dtype=intermediate_dtype,
14551337
dtype=dtype,
14561338
keepdims=keepdims,
1457-
use_new_impl=use_new_impl,
14581339
split_every=split_every,
14591340
)
14601341

0 commit comments

Comments
 (0)