@@ -2254,5 +2254,137 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
2254
2254
np .testing .assert_array_equal (out , expected )
2255
2255
2256
2256
2257
+ class PallasCallComparisonTest (PallasTPUTest ):
2258
+
2259
+ def setUp (self ):
2260
+ super ().setUp ()
2261
+ if jtu .device_under_test () != 'tpu' :
2262
+ self .skipTest ('Test only works on TPU' )
2263
+
2264
+ @parameterized .named_parameters (
2265
+ ('integer_1_1' , (1 , 1 )),
2266
+ ('integer_1_16' , (1 , 16 )),
2267
+ ('integer_16_1' , (16 , 1 )),
2268
+ ('integer_-1_1' , (- 1 , 1 )),
2269
+ ('integer_1_-1' , (1 , - 1 )),
2270
+ ('float_1_1' , (1.0 , 1.0 )),
2271
+ ('float_1_16' , (1.0 , 16.0 )),
2272
+ ('float_16_1' , (16.0 , 1.0 )),
2273
+ ('float_-1_1' , (- 1.0 , 1.0 )),
2274
+ ('float_1_-1' , (1.0 , - 1.0 )),
2275
+ ('float_1_inf' , (1.0 , float ('inf' ))),
2276
+ ('float_inf_1' , (float ('inf' ), 1.0 )),
2277
+ ('float_inf_inf' , (float ('inf' ), float ('inf' ))),
2278
+ ('float_1_nan' , (1.0 , float ('nan' ))),
2279
+ ('float_nan_1' , (float ('nan' ), 1.0 )),
2280
+ ('float_nan_nan' , (float ('nan' ), float ('nan' ))),
2281
+ ('float_inf_nan' , (float ('inf' ), float ('nan' ))),
2282
+ ('float_nan_inf' , (float ('inf' ), float ('inf' ))),
2283
+ )
2284
+ def test_scalar_compare (self , params ):
2285
+ """Test some scalar compares.
2286
+
2287
+ We don't really expect that the results would be wrong, but rather we want
2288
+ to exercise the lowering rules.
2289
+ """
2290
+
2291
+ def kernel (x_ref , y_ref , o_ref ):
2292
+ x = x_ref [0 , 0 ]
2293
+ y = y_ref [0 , 0 ]
2294
+ o_ref [0 , 0 ] = jax .lax .select (x == y , 1 , 0 )
2295
+ o_ref [0 , 1 ] = jax .lax .select (x != y , 1 , 0 )
2296
+ o_ref [0 , 2 ] = jax .lax .select (x < y , 1 , 0 )
2297
+ o_ref [0 , 3 ] = jax .lax .select (x <= y , 1 , 0 )
2298
+ o_ref [0 , 4 ] = jax .lax .select (x > y , 1 , 0 )
2299
+ o_ref [0 , 5 ] = jax .lax .select (x >= y , 1 , 0 )
2300
+
2301
+ x , y = params
2302
+ r = jnp .array (
2303
+ [
2304
+ [x == y , x != y , x < y , x <= y , x > y , x >= y ],
2305
+ ],
2306
+ jnp .int32 ,
2307
+ )
2308
+ x = jnp .array ([[x ]])
2309
+ y = jnp .array ([[y ]])
2310
+
2311
+ result = pl .pallas_call (
2312
+ kernel ,
2313
+ out_shape = jax .ShapeDtypeStruct ([1 , 128 ], jnp .int32 ),
2314
+ in_specs = [
2315
+ pl .BlockSpec (memory_space = pltpu .SMEM ),
2316
+ pl .BlockSpec (memory_space = pltpu .SMEM ),
2317
+ ],
2318
+ out_specs = pl .BlockSpec (
2319
+ lambda i : (0 , 0 ), (1 , 128 ), memory_space = pltpu .SMEM
2320
+ ),
2321
+ grid = (1 ,),
2322
+ )(x , y )
2323
+ np .testing .assert_array_equal (r , result [..., 0 :6 ])
2324
+
2325
+ @parameterized .named_parameters (
2326
+ ('integer_1_1' , (1 , 1 )),
2327
+ ('integer_1_16' , (1 , 16 )),
2328
+ ('integer_16_1' , (16 , 1 )),
2329
+ ('integer_-1_1' , (- 1 , 1 )),
2330
+ ('integer_1_-1' , (1 , - 1 )),
2331
+ ('float_1_1' , (1.0 , 1.0 )),
2332
+ ('float_1_16' , (1.0 , 16.0 )),
2333
+ ('float_16_1' , (16.0 , 1.0 )),
2334
+ ('float_-1_1' , (- 1.0 , 1.0 )),
2335
+ ('float_1_-1' , (1.0 , - 1.0 )),
2336
+ ('float_1_inf' , (1.0 , float ('inf' ))),
2337
+ ('float_inf_1' , (float ('inf' ), 1.0 )),
2338
+ ('float_inf_inf' , (float ('inf' ), float ('inf' ))),
2339
+ ('float_1_nan' , (1.0 , float ('nan' ))),
2340
+ ('float_nan_1' , (float ('nan' ), 1.0 )),
2341
+ ('float_nan_nan' , (float ('nan' ), float ('nan' ))),
2342
+ ('float_inf_nan' , (float ('inf' ), float ('nan' ))),
2343
+ ('float_nan_inf' , (float ('inf' ), float ('inf' ))),
2344
+ )
2345
+ def test_vector_compare (self , params ):
2346
+ """Test some vector compares.
2347
+
2348
+ We don't really expect that the results would be wrong, but rather we want
2349
+ to exercise the lowering rules.
2350
+ """
2351
+
2352
+ def kernel (x_ref , y_ref , o_ref ):
2353
+ x = x_ref [:]
2354
+ y = y_ref [:]
2355
+ one = jnp .ones ([8 , 128 ], dtype = jnp .int32 )
2356
+ zero = jnp .zeros ([8 , 128 ], dtype = jnp .int32 )
2357
+ o_ref [0 ] = jax .lax .select (x == y , one , zero )
2358
+ o_ref [1 ] = jax .lax .select (x != y , one , zero )
2359
+ o_ref [2 ] = jax .lax .select (x < y , one , zero )
2360
+ o_ref [3 ] = jax .lax .select (x <= y , one , zero )
2361
+ o_ref [4 ] = jax .lax .select (x > y , one , zero )
2362
+ o_ref [5 ] = jax .lax .select (x >= y , one , zero )
2363
+
2364
+ # Widen out our params to (8, 128) vectors.
2365
+ x , y = params
2366
+ x = jnp .full ([8 , 128 ], x )
2367
+ y = jnp .full ([8 , 128 ], y )
2368
+
2369
+ r = [x == y , x != y , x < y , x <= y , x > y , x >= y ]
2370
+
2371
+ result = pl .pallas_call (
2372
+ kernel ,
2373
+ out_shape = jax .ShapeDtypeStruct ([6 , 8 , 128 ], jnp .int32 ),
2374
+ in_specs = [
2375
+ pl .BlockSpec (lambda * _ : (0 , 0 ), (8 , 128 )),
2376
+ pl .BlockSpec (lambda * _ : (0 , 0 ), (8 , 128 )),
2377
+ ],
2378
+ out_specs = pl .BlockSpec (lambda * _ : (0 , 0 , 0 ), (6 , 8 , 128 )),
2379
+ grid = (1 ,),
2380
+ )(x , y )
2381
+ np .testing .assert_array_equal (r [0 ], result [0 ])
2382
+ np .testing .assert_array_equal (r [1 ], result [1 ])
2383
+ np .testing .assert_array_equal (r [2 ], result [2 ])
2384
+ np .testing .assert_array_equal (r [3 ], result [3 ])
2385
+ np .testing .assert_array_equal (r [4 ], result [4 ])
2386
+ np .testing .assert_array_equal (r [5 ], result [5 ])
2387
+
2388
+
2257
2389
if __name__ == '__main__' :
2258
2390
absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments