2
2
3
3
import arrayfire_wrapper .dtypes as dtype
4
4
import arrayfire_wrapper .lib as wrapper
5
- from tests .utility_functions import check_type_supported , get_all_types , get_float_types , get_real_types
5
+ from tests .utility_functions import check_type_supported , get_all_types , get_real_types
6
+
6
7
7
8
@pytest .mark .parametrize (
8
9
"shape" ,
@@ -24,6 +25,7 @@ def test_accum_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
24
25
result = wrapper .accum (values , 0 )
25
26
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
26
27
28
+
27
29
@pytest .mark .parametrize (
28
30
"dim" ,
29
31
[
@@ -39,6 +41,8 @@ def test_accum_dims(dim: int) -> None:
39
41
values = wrapper .randu (shape , dtype .f32 )
40
42
result = wrapper .accum (values , dim )
41
43
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
44
+
45
+
42
46
@pytest .mark .parametrize (
43
47
"invdim" ,
44
48
[
@@ -54,6 +58,7 @@ def test_accum_invdims(invdim: int) -> None:
54
58
result = wrapper .accum (values , invdim )
55
59
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
56
60
61
+
57
62
@pytest .mark .parametrize (
58
63
"shape" ,
59
64
[
@@ -74,6 +79,7 @@ def test_scan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
74
79
result = wrapper .scan (values , 0 , wrapper .BinaryOperator .ADD , True )
75
80
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } , dtype { dtype_name } " # noqa
76
81
82
+
77
83
@pytest .mark .parametrize (
78
84
"dim" ,
79
85
[
@@ -89,6 +95,8 @@ def test_scan_dims(dim: int) -> None:
89
95
values = wrapper .randu (shape , dtype .f32 )
90
96
result = wrapper .scan (values , dim , wrapper .BinaryOperator .ADD , True )
91
97
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for dimension: { dim } " # noqa
98
+
99
+
92
100
@pytest .mark .parametrize (
93
101
"invdim" ,
94
102
[
@@ -103,6 +111,8 @@ def test_scan_invdims(invdim: int) -> None:
103
111
values = wrapper .randu (shape , dtype .f32 )
104
112
result = wrapper .scan (values , invdim , wrapper .BinaryOperator .ADD , True )
105
113
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
114
+
115
+
106
116
@pytest .mark .parametrize (
107
117
"binaryOp" ,
108
118
[
@@ -119,6 +129,7 @@ def test_scan_binaryOp(binaryOp: int) -> None:
119
129
result = wrapper .scan (values , 0 , wrapper .BinaryOperator (binaryOp ), True )
120
130
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for operation: { binaryOp } " # noqa
121
131
132
+
122
133
@pytest .mark .parametrize (
123
134
"shape" ,
124
135
[
@@ -133,12 +144,19 @@ def test_scan_binaryOp(binaryOp: int) -> None:
133
144
def test_scan_by_key_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
134
145
"""Test scan_by_key operation across all supported data types."""
135
146
check_type_supported (dtype_name )
136
- if dtype_name == dtype .f16 or dtype_name == dtype .f32 or dtype_name == dtype .uint16 or dtype_name == dtype .uint8 or dtype_name == dtype .int16 :
147
+ if (
148
+ dtype_name == dtype .f16
149
+ or dtype_name == dtype .f32
150
+ or dtype_name == dtype .uint16
151
+ or dtype_name == dtype .uint8
152
+ or dtype_name == dtype .int16
153
+ ):
137
154
pytest .skip ()
138
155
values = wrapper .randu (shape , dtype_name )
139
156
result = wrapper .scan_by_key (values , values , 0 , wrapper .BinaryOperator .ADD , True )
140
157
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } , dtype { dtype_name } " # noqa
141
158
159
+
142
160
@pytest .mark .parametrize (
143
161
"dim" ,
144
162
[
@@ -155,6 +173,7 @@ def test_scan_by_key_dims(dim: int) -> None:
155
173
result = wrapper .scan_by_key (values , values , dim , wrapper .BinaryOperator .ADD , True )
156
174
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for dimension: { dim } " # noqa
157
175
176
+
158
177
@pytest .mark .parametrize (
159
178
"invdim" ,
160
179
[
@@ -169,6 +188,8 @@ def test_scan_by_key_invdims(invdim: int) -> None:
169
188
values = wrapper .randu (shape , dtype .int32 )
170
189
result = wrapper .scan_by_key (values , values , invdim , wrapper .BinaryOperator .ADD , True )
171
190
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for shape: { shape } " # noqa
191
+
192
+
172
193
@pytest .mark .parametrize (
173
194
"binaryOp" ,
174
195
[
@@ -185,6 +206,7 @@ def test_scan_by_key_binaryOp(binaryOp: int) -> None:
185
206
result = wrapper .scan_by_key (values , values , 0 , wrapper .BinaryOperator (binaryOp ), True )
186
207
assert wrapper .get_dims (result )[0 : len (shape )] == shape , f"failed for operation: { binaryOp } " # noqa
187
208
209
+
188
210
@pytest .mark .parametrize (
189
211
"shape" ,
190
212
[
0 commit comments