|
4 | 4 |
|
5 | 5 | # These functions are in both the main and linalg namespaces
|
6 | 6 | from cubed.array_api.data_type_functions import result_type
|
| 7 | +from cubed.array_api.dtypes import _floating_dtypes |
7 | 8 | from cubed.array_api.linear_algebra_functions import ( # noqa: F401
|
8 | 9 | matmul,
|
9 | 10 | matrix_transpose,
|
@@ -33,6 +34,9 @@ def qr(x, /, *, mode="reduced") -> QRResult:
|
33 | 34 | if mode != "reduced":
|
34 | 35 | raise ValueError("qr only supports mode='reduced'")
|
35 | 36 |
|
| 37 | + if x.dtype not in _floating_dtypes: |
| 38 | + raise TypeError("Only floating-point dtypes are allowed in qr") |
| 39 | + |
36 | 40 | if x.numblocks[1] > 1:
|
37 | 41 | raise ValueError(
|
38 | 42 | "qr only supports tall-and-skinny (single column chunk) arrays. "
|
@@ -80,7 +84,7 @@ def _qr_first_step(A):
|
80 | 84 | nxp.linalg.qr,
|
81 | 85 | A,
|
82 | 86 | shapes=[A.shape, R1_shape],
|
83 |
| - dtypes=[nxp.float64, nxp.float64], |
| 87 | + dtypes=[A.dtype, A.dtype], |
84 | 88 | chunkss=[A.chunks, R1_chunks],
|
85 | 89 | extra_projected_mem=extra_projected_mem,
|
86 | 90 | )
|
@@ -119,7 +123,7 @@ def _qr_second_step(R1):
|
119 | 123 | nxp.linalg.qr,
|
120 | 124 | R1_single,
|
121 | 125 | shapes=[Q2_shape, R2_shape],
|
122 |
| - dtypes=[nxp.float64, nxp.float64], |
| 126 | + dtypes=[R1.dtype, R1.dtype], |
123 | 127 | chunkss=[Q2_chunks, R2_chunks],
|
124 | 128 | extra_projected_mem=extra_projected_mem,
|
125 | 129 | )
|
@@ -148,7 +152,7 @@ def _qr_third_step(Q1, Q2):
|
148 | 152 | Q1,
|
149 | 153 | Q2,
|
150 | 154 | shape=Q1_shape,
|
151 |
| - dtype=nxp.float64, |
| 155 | + dtype=result_type(Q1, Q2), |
152 | 156 | chunks=Q1_chunks,
|
153 | 157 | extra_projected_mem=extra_projected_mem,
|
154 | 158 | q1_chunks=Q1_chunks,
|
|
0 commit comments