@@ -33,7 +33,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
33
33
assert (dpt .asnumpy (r ) == np .full (r .shape , 2 , dtype = r .dtype )).all ()
34
34
assert r .sycl_queue == ar1 .sycl_queue
35
35
36
- out = dpt .empty_like (ar1 , dtype = expected_dtype )
36
+ out = dpt .empty_like (ar1 , dtype = r . dtype )
37
37
dpt .add (ar1 , ar2 , out )
38
38
assert (dpt .asnumpy (out ) == np .full (out .shape , 2 , dtype = out .dtype )).all ()
39
39
@@ -49,7 +49,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
49
49
assert r .shape == ar3 .shape
50
50
assert (dpt .asnumpy (r ) == np .full (r .shape , 2 , dtype = r .dtype )).all ()
51
51
52
- out = dpt .empty_like (ar1 , dtype = expected_dtype )
52
+ out = dpt .empty_like (ar1 , dtype = r . dtype )
53
53
dpt .add (ar3 [::- 1 ], ar4 [::2 ], out )
54
54
assert (dpt .asnumpy (out ) == np .full (out .shape , 2 , dtype = out .dtype )).all ()
55
55
@@ -74,37 +74,49 @@ def test_add_usm_type_matrix(op1_usm_type, op2_usm_type):
74
74
def test_add_order ():
75
75
get_queue_or_skip ()
76
76
77
- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
78
- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
79
- r1 = dpt .add (ar1 , ar2 , order = "C" )
80
- assert r1 .flags .c_contiguous
81
- r2 = dpt .add (ar1 , ar2 , order = "F" )
82
- assert r2 .flags .f_contiguous
83
- r3 = dpt .add (ar1 , ar2 , order = "A" )
84
- assert r3 .flags .c_contiguous
85
- r4 = dpt .add (ar1 , ar2 , order = "K" )
86
- assert r4 .flags .c_contiguous
87
-
88
- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
89
- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
90
- r1 = dpt .add (ar1 , ar2 , order = "C" )
91
- assert r1 .flags .c_contiguous
92
- r2 = dpt .add (ar1 , ar2 , order = "F" )
93
- assert r2 .flags .f_contiguous
94
- r3 = dpt .add (ar1 , ar2 , order = "A" )
95
- assert r3 .flags .f_contiguous
96
- r4 = dpt .add (ar1 , ar2 , order = "K" )
97
- assert r4 .flags .f_contiguous
98
-
99
- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
100
- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
101
- r4 = dpt .add (ar1 , ar2 , order = "K" )
102
- assert r4 .strides == (20 , - 1 )
103
-
104
- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
105
- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
106
- r4 = dpt .add (ar1 , ar2 , order = "K" )
107
- assert r4 .strides == (- 1 , 20 )
77
+ test_shape = (
78
+ 20 ,
79
+ 20 ,
80
+ )
81
+ test_shape2 = tuple (2 * dim for dim in test_shape )
82
+ n = test_shape [- 1 ]
83
+
84
+ for dt1 , dt2 in zip (["i4" , "i4" , "f4" ], ["i4" , "f4" , "i4" ]):
85
+ ar1 = dpt .ones (test_shape , dtype = dt1 , order = "C" )
86
+ ar2 = dpt .ones (test_shape , dtype = dt2 , order = "C" )
87
+ r1 = dpt .add (ar1 , ar2 , order = "C" )
88
+ assert r1 .flags .c_contiguous
89
+ r2 = dpt .add (ar1 , ar2 , order = "F" )
90
+ assert r2 .flags .f_contiguous
91
+ r3 = dpt .add (ar1 , ar2 , order = "A" )
92
+ assert r3 .flags .c_contiguous
93
+ r4 = dpt .add (ar1 , ar2 , order = "K" )
94
+ assert r4 .flags .c_contiguous
95
+
96
+ ar1 = dpt .ones (test_shape , dtype = dt1 , order = "F" )
97
+ ar2 = dpt .ones (test_shape , dtype = dt2 , order = "F" )
98
+ r1 = dpt .add (ar1 , ar2 , order = "C" )
99
+ assert r1 .flags .c_contiguous
100
+ r2 = dpt .add (ar1 , ar2 , order = "F" )
101
+ assert r2 .flags .f_contiguous
102
+ r3 = dpt .add (ar1 , ar2 , order = "A" )
103
+ assert r3 .flags .f_contiguous
104
+ r4 = dpt .add (ar1 , ar2 , order = "K" )
105
+ assert r4 .flags .f_contiguous
106
+
107
+ ar1 = dpt .ones (test_shape2 , dtype = dt1 , order = "C" )[:20 , ::- 2 ]
108
+ ar2 = dpt .ones (test_shape2 , dtype = dt2 , order = "C" )[:20 , ::- 2 ]
109
+ r4 = dpt .add (ar1 , ar2 , order = "K" )
110
+ assert r4 .strides == (n , - 1 )
111
+ r5 = dpt .add (ar1 , ar2 , order = "C" )
112
+ assert r5 .strides == (n , 1 )
113
+
114
+ ar1 = dpt .ones (test_shape2 , dtype = dt1 , order = "C" )[:20 , ::- 2 ].mT
115
+ ar2 = dpt .ones (test_shape2 , dtype = dt2 , order = "C" )[:20 , ::- 2 ].mT
116
+ r4 = dpt .add (ar1 , ar2 , order = "K" )
117
+ assert r4 .strides == (- 1 , n )
118
+ r5 = dpt .add (ar1 , ar2 , order = "C" )
119
+ assert r5 .strides == (n , 1 )
108
120
109
121
110
122
def test_add_broadcasting ():
@@ -266,7 +278,7 @@ def test_add_dtype_error(
266
278
skip_if_dtype_not_supported (dtype , q )
267
279
268
280
ar1 = dpt .ones (5 , dtype = dtype )
269
- ar2 = dpt .ones_like (ar1 , dtype = "f8 " )
281
+ ar2 = dpt .ones_like (ar1 , dtype = "f4 " )
270
282
271
283
y = dpt .zeros_like (ar1 , dtype = "int8" )
272
284
assert_raises_regex (
0 commit comments