@@ -1517,6 +1517,7 @@ def add_one(x_ref, o_ref):
1517
1517
class PallasCallInterpreterVmapTest (PallasCallVmapTest ):
1518
1518
INTERPRET = True
1519
1519
1520
+
1520
1521
class PallasOpsTest (PallasTest ):
1521
1522
1522
1523
def test_pow_weak_dtype (self ):
@@ -1528,17 +1529,31 @@ def square(x_ref, o_ref):
1528
1529
x = jnp .array (42.0 )
1529
1530
np .testing .assert_allclose (square (x ), x * x )
1530
1531
1531
- def test_ne (self ):
1532
+ COMPARISON_OPS = [
1533
+ jnp .equal ,
1534
+ jnp .not_equal ,
1535
+ jnp .less ,
1536
+ jnp .less_equal ,
1537
+ jnp .greater ,
1538
+ jnp .greater_equal ,
1539
+ ]
1540
+
1541
+ @parameterized .named_parameters (
1542
+ (f"{ fn .__name__ } _{ dtype } " , fn , dtype )
1543
+ for fn , dtype in itertools .product (
1544
+ COMPARISON_OPS , ["int32" , "uint32" , "float32" ]
1545
+ )
1546
+ )
1547
+ def test_comparison (self , fn , dtype ):
1532
1548
@functools .partial (
1533
1549
self .pallas_call , out_shape = jax .ShapeDtypeStruct ((8 ,), jnp .bool_ ),
1534
1550
grid = 1 )
1535
- def ne (x_ref , y_ref , o_ref ):
1536
- o_ref [:] = x_ref [...] != y_ref [...]
1551
+ def kernel (x_ref , y_ref , o_ref ):
1552
+ o_ref [:] = fn ( x_ref [...], y_ref [...])
1537
1553
1538
- x = jnp .ones (8 , dtype = jnp .int32 )
1539
- y = jnp .arange (8 , dtype = jnp .int32 )
1540
- not_equal = ne (x , y )
1541
- np .testing .assert_allclose (not_equal , x != y )
1554
+ x = jnp .array ([1 , 3 , - 4 , - 6 , 2 , 5 , 4 , - 7 ]).astype (dtype )
1555
+ y = jnp .array ([3 , 1 , - 4 , - 5 , 2 , - 2 , 2 , 4 ]).astype (dtype )
1556
+ np .testing .assert_allclose (kernel (x , y ), fn (x , y ))
1542
1557
1543
1558
def test_isnan (self ):
1544
1559
@functools .partial (
@@ -1551,33 +1566,52 @@ def isnan(x_ref, o_ref):
1551
1566
x = x .at [3 ].set (jnp .nan )
1552
1567
np .testing .assert_allclose (isnan (x ), jnp .isnan (x ))
1553
1568
1554
- @parameterized .named_parameters (* (
1555
- (fn .__name__ , fn , out_dtype )
1556
- for fn , out_dtype in [
1557
- (jnp .add , jnp .int32 ),
1558
- (jnp .subtract , jnp .int32 ),
1559
- (jnp .multiply , jnp .int32 ),
1560
- (jnp .true_divide , jnp .float32 ),
1561
- (jnp .remainder , jnp .int32 ),
1562
- (jnp .less , jnp .bool_ ),
1563
- (jnp .less_equal , jnp .bool_ ),
1564
- (jnp .greater , jnp .bool_ ),
1565
- (jnp .greater_equal , jnp .bool_ ),
1566
- (jnp .equal , jnp .bool_ ),
1567
- (jnp .not_equal , jnp .bool_ ),
1568
- ]
1569
- ))
1570
- def test_signed_int_ops (self , f , out_dtype ):
1569
+ def test_true_divide (self ):
1571
1570
@functools .partial (
1572
1571
self .pallas_call ,
1573
- out_shape = jax .ShapeDtypeStruct ((8 ,), out_dtype ),
1574
- grid = 1 )
1575
- def f_i32 (x_ref , y_ref , o_ref ):
1572
+ out_shape = jax .ShapeDtypeStruct ((8 ,), jnp .float32 ),
1573
+ grid = 1 ,
1574
+ )
1575
+ def kernel (x_ref , y_ref , o_ref ):
1576
+ o_ref [...] = jnp .true_divide (x_ref [...], y_ref [...])
1577
+
1578
+ x = jnp .array ([1 , 3 , - 4 , - 6 , 2 , 5 , 4 , - 7 ], dtype = jnp .int32 )
1579
+ y = jnp .array ([3 , 1 , - 4 , - 5 , 2 , - 2 , 2 , 4 ], dtype = jnp .int32 )
1580
+ np .testing .assert_allclose (jnp .true_divide (x , y ), kernel (x , y ))
1581
+
1582
+ BINARY_OPS = [
1583
+ ([jnp .floor_divide ], ["int32" , "uint32" ]),
1584
+ (
1585
+ [jnp .add , jnp .subtract , jnp .multiply , jnp .remainder ],
1586
+ ["int32" , "uint32" , "float32" ],
1587
+ ),
1588
+ (
1589
+ [
1590
+ jnp .bitwise_and ,
1591
+ jnp .bitwise_or ,
1592
+ jnp .bitwise_xor ,
1593
+ jnp .bitwise_left_shift ,
1594
+ jnp .bitwise_right_shift ,
1595
+ ],
1596
+ ["int32" , "uint32" ],
1597
+ ),
1598
+ ]
1599
+
1600
+ @parameterized .named_parameters (
1601
+ (f"{ fn .__name__ } _{ dtype } " , fn , dtype )
1602
+ for args in BINARY_OPS
1603
+ for fn , dtype in itertools .product (* args )
1604
+ )
1605
+ def test_binary (self , f , dtype ):
1606
+ @functools .partial (
1607
+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((8 ,), dtype ), grid = 1
1608
+ )
1609
+ def kernel (x_ref , y_ref , o_ref ):
1576
1610
o_ref [...] = f (x_ref [...], y_ref [...])
1577
1611
1578
- x = jnp .int32 ([1 , 3 , - 4 , - 6 , 2 , 5 , 4 , - 7 ])
1579
- y = jnp .int32 ([3 , 1 , - 4 , - 5 , 2 , - 2 , 0 , 4 ])
1580
- np .testing .assert_allclose (f (x , y ), f_i32 (x , y ))
1612
+ x = jnp .array ([1 , 3 , - 4 , - 6 , 2 , 5 , 4 , - 7 ]). astype ( dtype )
1613
+ y = jnp .array ([3 , 1 , - 4 , - 5 , 2 , - 2 , 2 , 4 ]). astype ( dtype )
1614
+ np .testing .assert_allclose (f (x , y ), kernel (x , y ))
1581
1615
1582
1616
1583
1617
class PallasOpsInterpretTest (PallasOpsTest ):
0 commit comments