Skip to content

Commit b107be2

Browse files
author
jax authors
committed
[Pallas TPU] Add missing Mosaic lowering rules for float comparisons, tests.
PiperOrigin-RevId: 623876922
1 parent 36bedee commit b107be2

File tree

2 files changed

+140
-1
lines changed

2 files changed

+140
-1
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,8 @@ def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method):
15481548
lowering_rules[lax.round_p] = _round_lowering_rule
15491549

15501550

1551+
# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpi-arithcmpiop for
1552+
# the mapping from comparison type to integer predicates for int comparisons.
15511553
_cmpi_lowering_types = {
15521554
lax.eq_p: 0,
15531555
lax.ne_p: 1,
@@ -1557,10 +1559,15 @@ def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method):
15571559
lax.ge_p: 5,
15581560
}
15591561

1562+
# See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithcmpf-arithcmpfop for
1563+
# the mapping from comparison type to integer predicate for float comparisons.
15601564
_cmpf_lowering_types = {
15611565
lax.eq_p: 1,
1562-
lax.gt_p: 2,
15631566
lax.ne_p: 6,
1567+
lax.lt_p: 4,
1568+
lax.le_p: 5,
1569+
lax.gt_p: 2,
1570+
lax.ge_p: 3,
15641571
}
15651572

15661573

tests/pallas/pallas_call_tpu_test.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,5 +2254,137 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
22542254
np.testing.assert_array_equal(out, expected)
22552255

22562256

2257+
class PallasCallComparisonTest(PallasTPUTest):
2258+
2259+
def setUp(self):
2260+
super().setUp()
2261+
if jtu.device_under_test() != 'tpu':
2262+
self.skipTest('Test only works on TPU')
2263+
2264+
@parameterized.named_parameters(
2265+
('integer_1_1', (1, 1)),
2266+
('integer_1_16', (1, 16)),
2267+
('integer_16_1', (16, 1)),
2268+
('integer_-1_1', (-1, 1)),
2269+
('integer_1_-1', (1, -1)),
2270+
('float_1_1', (1.0, 1.0)),
2271+
('float_1_16', (1.0, 16.0)),
2272+
('float_16_1', (16.0, 1.0)),
2273+
('float_-1_1', (-1.0, 1.0)),
2274+
('float_1_-1', (1.0, -1.0)),
2275+
('float_1_inf', (1.0, float('inf'))),
2276+
('float_inf_1', (float('inf'), 1.0)),
2277+
('float_inf_inf', (float('inf'), float('inf'))),
2278+
('float_1_nan', (1.0, float('nan'))),
2279+
('float_nan_1', (float('nan'), 1.0)),
2280+
('float_nan_nan', (float('nan'), float('nan'))),
2281+
('float_inf_nan', (float('inf'), float('nan'))),
2282+
('float_nan_inf', (float('inf'), float('inf'))),
2283+
)
2284+
def test_scalar_compare(self, params):
2285+
"""Test some scalar compares.
2286+
2287+
We don't really expect that the results would be wrong, but rather we want
2288+
to exercise the lowering rules.
2289+
"""
2290+
2291+
def kernel(x_ref, y_ref, o_ref):
2292+
x = x_ref[0, 0]
2293+
y = y_ref[0, 0]
2294+
o_ref[0, 0] = jax.lax.select(x == y, 1, 0)
2295+
o_ref[0, 1] = jax.lax.select(x != y, 1, 0)
2296+
o_ref[0, 2] = jax.lax.select(x < y, 1, 0)
2297+
o_ref[0, 3] = jax.lax.select(x <= y, 1, 0)
2298+
o_ref[0, 4] = jax.lax.select(x > y, 1, 0)
2299+
o_ref[0, 5] = jax.lax.select(x >= y, 1, 0)
2300+
2301+
x, y = params
2302+
r = jnp.array(
2303+
[
2304+
[x == y, x != y, x < y, x <= y, x > y, x >= y],
2305+
],
2306+
jnp.int32,
2307+
)
2308+
x = jnp.array([[x]])
2309+
y = jnp.array([[y]])
2310+
2311+
result = pl.pallas_call(
2312+
kernel,
2313+
out_shape=jax.ShapeDtypeStruct([1, 128], jnp.int32),
2314+
in_specs=[
2315+
pl.BlockSpec(memory_space=pltpu.SMEM),
2316+
pl.BlockSpec(memory_space=pltpu.SMEM),
2317+
],
2318+
out_specs=pl.BlockSpec(
2319+
lambda i: (0, 0), (1, 128), memory_space=pltpu.SMEM
2320+
),
2321+
grid=(1,),
2322+
)(x, y)
2323+
np.testing.assert_array_equal(r, result[..., 0:6])
2324+
2325+
@parameterized.named_parameters(
2326+
('integer_1_1', (1, 1)),
2327+
('integer_1_16', (1, 16)),
2328+
('integer_16_1', (16, 1)),
2329+
('integer_-1_1', (-1, 1)),
2330+
('integer_1_-1', (1, -1)),
2331+
('float_1_1', (1.0, 1.0)),
2332+
('float_1_16', (1.0, 16.0)),
2333+
('float_16_1', (16.0, 1.0)),
2334+
('float_-1_1', (-1.0, 1.0)),
2335+
('float_1_-1', (1.0, -1.0)),
2336+
('float_1_inf', (1.0, float('inf'))),
2337+
('float_inf_1', (float('inf'), 1.0)),
2338+
('float_inf_inf', (float('inf'), float('inf'))),
2339+
('float_1_nan', (1.0, float('nan'))),
2340+
('float_nan_1', (float('nan'), 1.0)),
2341+
('float_nan_nan', (float('nan'), float('nan'))),
2342+
('float_inf_nan', (float('inf'), float('nan'))),
2343+
('float_nan_inf', (float('inf'), float('inf'))),
2344+
)
2345+
def test_vector_compare(self, params):
2346+
"""Test some vector compares.
2347+
2348+
We don't really expect that the results would be wrong, but rather we want
2349+
to exercise the lowering rules.
2350+
"""
2351+
2352+
def kernel(x_ref, y_ref, o_ref):
2353+
x = x_ref[:]
2354+
y = y_ref[:]
2355+
one = jnp.ones([8, 128], dtype=jnp.int32)
2356+
zero = jnp.zeros([8, 128], dtype=jnp.int32)
2357+
o_ref[0] = jax.lax.select(x == y, one, zero)
2358+
o_ref[1] = jax.lax.select(x != y, one, zero)
2359+
o_ref[2] = jax.lax.select(x < y, one, zero)
2360+
o_ref[3] = jax.lax.select(x <= y, one, zero)
2361+
o_ref[4] = jax.lax.select(x > y, one, zero)
2362+
o_ref[5] = jax.lax.select(x >= y, one, zero)
2363+
2364+
# Widen out our params to (8, 128) vectors.
2365+
x, y = params
2366+
x = jnp.full([8, 128], x)
2367+
y = jnp.full([8, 128], y)
2368+
2369+
r = [x == y, x != y, x < y, x <= y, x > y, x >= y]
2370+
2371+
result = pl.pallas_call(
2372+
kernel,
2373+
out_shape=jax.ShapeDtypeStruct([6, 8, 128], jnp.int32),
2374+
in_specs=[
2375+
pl.BlockSpec(lambda *_: (0, 0), (8, 128)),
2376+
pl.BlockSpec(lambda *_: (0, 0), (8, 128)),
2377+
],
2378+
out_specs=pl.BlockSpec(lambda *_: (0, 0, 0), (6, 8, 128)),
2379+
grid=(1,),
2380+
)(x, y)
2381+
np.testing.assert_array_equal(r[0], result[0])
2382+
np.testing.assert_array_equal(r[1], result[1])
2383+
np.testing.assert_array_equal(r[2], result[2])
2384+
np.testing.assert_array_equal(r[3], result[3])
2385+
np.testing.assert_array_equal(r[4], result[4])
2386+
np.testing.assert_array_equal(r[5], result[5])
2387+
2388+
22572389
if __name__ == '__main__':
22582390
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)