Skip to content

Commit 4fb30c8

Browse files
authored
Fix dtypes for qr (#594)
1 parent d361e00 commit 4fb30c8

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

cubed/array_api/linalg.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# These functions are in both the main and linalg namespaces
66
from cubed.array_api.data_type_functions import result_type
7+
from cubed.array_api.dtypes import _floating_dtypes
78
from cubed.array_api.linear_algebra_functions import ( # noqa: F401
89
matmul,
910
matrix_transpose,
@@ -33,6 +34,9 @@ def qr(x, /, *, mode="reduced") -> QRResult:
3334
if mode != "reduced":
3435
raise ValueError("qr only supports mode='reduced'")
3536

37+
if x.dtype not in _floating_dtypes:
38+
raise TypeError("Only floating-point dtypes are allowed in qr")
39+
3640
if x.numblocks[1] > 1:
3741
raise ValueError(
3842
"qr only supports tall-and-skinny (single column chunk) arrays. "
@@ -80,7 +84,7 @@ def _qr_first_step(A):
8084
nxp.linalg.qr,
8185
A,
8286
shapes=[A.shape, R1_shape],
83-
dtypes=[nxp.float64, nxp.float64],
87+
dtypes=[A.dtype, A.dtype],
8488
chunkss=[A.chunks, R1_chunks],
8589
extra_projected_mem=extra_projected_mem,
8690
)
@@ -119,7 +123,7 @@ def _qr_second_step(R1):
119123
nxp.linalg.qr,
120124
R1_single,
121125
shapes=[Q2_shape, R2_shape],
122-
dtypes=[nxp.float64, nxp.float64],
126+
dtypes=[R1.dtype, R1.dtype],
123127
chunkss=[Q2_chunks, R2_chunks],
124128
extra_projected_mem=extra_projected_mem,
125129
)
@@ -148,7 +152,7 @@ def _qr_third_step(Q1, Q2):
148152
Q1,
149153
Q2,
150154
shape=Q1_shape,
151-
dtype=nxp.float64,
155+
dtype=result_type(Q1, Q2),
152156
chunks=Q1_chunks,
153157
extra_projected_mem=extra_projected_mem,
154158
q1_chunks=Q1_chunks,

0 commit comments

Comments
 (0)