@@ -76,19 +76,30 @@ def unary_assert_against_refimpl(
76
76
in_stype = dh .get_scalar_type (in_ .dtype )
77
77
if res_stype is None :
78
78
res_stype = in_stype
79
+ if res .dtype != xp .bool :
80
+ m , M = dh .dtype_ranges [res .dtype ]
79
81
for idx in sh .ndindex (in_ .shape ):
80
82
scalar_i = in_stype (in_ [idx ])
81
83
if not filter_ (scalar_i ):
82
84
continue
83
85
expected = refimpl (scalar_i )
86
+ if res .dtype != xp .bool :
87
+ if expected <= m or expected >= M :
88
+ continue
84
89
scalar_o = res_stype (res [idx ])
85
90
f_i = sh .fmt_idx ("x" , idx )
86
91
f_o = sh .fmt_idx ("out" , idx )
87
92
expr = expr_template .format (f_i , expected )
88
- assert scalar_o == expected , (
89
- f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
90
- f"{ f_i } ={ scalar_i } "
91
- )
93
+ if dh .is_float_dtype (res .dtype ):
94
+ assert isclose (scalar_o , expected ), (
95
+ f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
96
+ f"{ f_i } ={ scalar_i } "
97
+ )
98
+ else :
99
+ assert scalar_o == expected , (
100
+ f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
101
+ f"{ f_i } ={ scalar_i } "
102
+ )
92
103
93
104
94
105
def binary_assert_against_refimpl (
@@ -1257,29 +1268,35 @@ def test_sin(x):
1257
1268
out = xp .sin (x )
1258
1269
ph .assert_dtype ("sin" , x .dtype , out .dtype )
1259
1270
ph .assert_shape ("sin" , out .shape , x .shape )
1260
- # TODO
1271
+ unary_assert_against_refimpl ( "sin" , x , out , math . sin , "sin({})={}" )
1261
1272
1262
1273
1263
1274
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
1264
1275
def test_sinh (x ):
1265
1276
out = xp .sinh (x )
1266
1277
ph .assert_dtype ("sinh" , x .dtype , out .dtype )
1267
1278
ph .assert_shape ("sinh" , out .shape , x .shape )
1268
- # TODO
1279
+ unary_assert_against_refimpl ( "sinh" , x , out , math . sinh , "sinh({})={}" )
1269
1280
1270
1281
1271
1282
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
1272
1283
def test_square (x ):
1273
1284
out = xp .square (x )
1274
1285
ph .assert_dtype ("square" , x .dtype , out .dtype )
1275
1286
ph .assert_shape ("square" , out .shape , x .shape )
1287
+ unary_assert_against_refimpl ("square" , x , out , lambda s : s ** 2 , "{}²={}" )
1276
1288
1277
1289
1278
- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
1290
+ @given (
1291
+ xps .arrays (
1292
+ dtype = xps .floating_dtypes (), shape = hh .shapes (), elements = {"min_value" : 0 }
1293
+ )
1294
+ )
1279
1295
def test_sqrt (x ):
1280
1296
out = xp .sqrt (x )
1281
1297
ph .assert_dtype ("sqrt" , x .dtype , out .dtype )
1282
1298
ph .assert_shape ("sqrt" , out .shape , x .shape )
1299
+ unary_assert_against_refimpl ("sqrt" , x , out , math .sqrt , "sqrt({})={}" )
1283
1300
1284
1301
1285
1302
@pytest .mark .parametrize ("ctx" , make_binary_params ("subtract" , xps .numeric_dtypes ()))
@@ -1305,15 +1322,15 @@ def test_tan(x):
1305
1322
out = xp .tan (x )
1306
1323
ph .assert_dtype ("tan" , x .dtype , out .dtype )
1307
1324
ph .assert_shape ("tan" , out .shape , x .shape )
1308
- # TODO
1325
+ unary_assert_against_refimpl ( "tan" , x , out , math . tan , "tan({})={}" )
1309
1326
1310
1327
1311
1328
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
1312
1329
def test_tanh (x ):
1313
1330
out = xp .tanh (x )
1314
1331
ph .assert_dtype ("tanh" , x .dtype , out .dtype )
1315
1332
ph .assert_shape ("tanh" , out .shape , x .shape )
1316
- # TODO
1333
+ unary_assert_against_refimpl ( "tanh" , x , out , math . tanh , "tanh({})={}" )
1317
1334
1318
1335
1319
1336
@given (xps .arrays (dtype = hh .numeric_dtypes , shape = xps .array_shapes ()))
0 commit comments