Skip to content

Commit d6c2bb0

Browse files
tomwhitethodson-usgs
authored andcommitted
Cast inputs to Cubed arrays in apply_ufunc (#551)
* Cast inputs to Cubed arrays in `apply_ufunc` * Add comment about specs being the same
1 parent 46ef5cc commit d6c2bb0

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

cubed/core/gufunc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ def apply_gufunc(
6868
# Main code:
6969

7070
# Cast all input arrays to cubed
71-
# args = [asarray(a) for a in args] # TODO: do we need to support casting?
71+
# Use a spec if there is one. Note that all args have to have the same spec, and
72+
# this will be checked later when constructing the plan (see check_array_specs).
73+
from cubed.array_api.creation_functions import asarray
74+
75+
specs = [a.spec for a in args if hasattr(a, "spec")]
76+
spec = specs[0] if len(specs) > 0 else None
77+
args = [asarray(a, spec=spec) for a in args]
7278

7379
if len(input_coredimss) != len(args):
7480
raise ValueError(

cubed/tests/test_gufunc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ def add(x, y):
3838
assert_equal(z, np.array([2, 4, 6]))
3939

4040

41+
def test_apply_gufunc_elemwise_01_non_cubed_input(spec):
42+
def add(x, y):
43+
return x + y
44+
45+
a = cubed.from_array(np.array([1, 2, 3]), chunks=3, spec=spec)
46+
b = np.array([1, 2, 3])
47+
z = apply_gufunc(add, "(),()->()", a, b, output_dtypes=a.dtype)
48+
assert_equal(z, np.array([2, 4, 6]))
49+
50+
4151
def test_apply_gufunc_elemwise_loop(spec):
4252
def foo(x):
4353
assert x.shape in ((2,), (1,))

0 commit comments

Comments
 (0)