Skip to content

Commit 39b47f7

Browse files
#sdy Remove MHLO shardings from round-trip export
* Remove sdy attributes + mesh symbols in SdyRoundTripExportShardyAttrsPass since they are no longer needed for ExportStablehloShardingsPass. * Minor test cleanup to use .as_text() and StableHLO instead of .compiler_ir() and HLO. MHLO shardings no longer needed now that Pathways handles sdy attributes. PiperOrigin-RevId: 775381525
1 parent 5300b4d commit 39b47f7

File tree

1 file changed

+46
-32
lines changed

1 file changed

+46
-32
lines changed

tests/pjit_test.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def f(x, y):
446446
f(x1, x2)
447447

448448
@jtu.with_mesh([('x', 2), ('y', 1)])
449-
def testShardingConstraintStablehlo(self):
449+
def testShardingConstraintMeshContext(self):
450450
@partial(pjit, in_shardings=None, out_shardings=None)
451451
def f(x):
452452
y = x + 1
@@ -463,17 +463,17 @@ def f(x):
463463
self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected,
464464
check_dtypes=False)
465465

466-
hlo = f.lower(np.ones(shape)).compiler_ir()
466+
lowered_text = f.lower(np.ones(shape)).as_text()
467467
if config.use_shardy_partitioner.value:
468468
# Annotation from with_sharding_constraint
469-
self.assertIn('<@mesh, [{"x"}, {"y"}]>', str(hlo))
469+
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
470470
# Annotation from pjit
471-
self.assertIn('sharding = #sdy.sharding<@mesh, [{}, {}]>}', str(hlo))
471+
self.assertIn('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}', lowered_text)
472472
else:
473473
# Annotation from with_sharding_constraint
474-
self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo))
474+
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)
475475
# Annotation from pjit
476-
self.assertIn('sharding = "{replicated}"', str(hlo))
476+
self.assertIn('mhlo.sharding = "{replicated}"', lowered_text)
477477

478478
def testShardingConstraintWithArray(self):
479479
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
@@ -494,11 +494,17 @@ def f(x):
494494
self.assertLen(actual.addressable_shards, 2)
495495
self.assertAllClose(actual, expected, check_dtypes=False)
496496

497-
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
498-
# Annotation from with_sharding_constraint
499-
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
500-
# Annotation from pjit
501-
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
497+
lowered_text = f.lower(np.ones(shape)).as_text()
498+
if config.use_shardy_partitioner.value:
499+
# Annotation from with_sharding_constraint
500+
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
501+
# Annotation from pjit
502+
self.assertIn('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}', lowered_text)
503+
else:
504+
# Annotation from with_sharding_constraint
505+
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)
506+
# Annotation from pjit
507+
self.assertIn('mhlo.sharding = "{replicated}"', lowered_text)
502508

503509
def testShardingConstraintWithArrayOpSharding(self):
504510
shape = (8, 8)
@@ -521,11 +527,18 @@ def f(x):
521527
self.assertLen(actual.addressable_shards, 2)
522528
self.assertAllClose(actual, expected, check_dtypes=False)
523529

524-
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
525-
# Annotation from with_sharding_constraint
526-
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
527-
# Annotation from pjit
528-
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
530+
lowered_text = f.lower(np.ones(shape)).as_text()
531+
if config.use_shardy_partitioner.value:
532+
# Annotation from with_sharding_constraint, translated from GSPMD to SDY
533+
self.assertIn('@mesh_0 = <["_axis_0"=2]>', lowered_text)
534+
self.assertIn('<@mesh_0, [{"_axis_0"}, {}]>', lowered_text)
535+
# Annotation from pjit
536+
self.assertIn('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}', lowered_text)
537+
else:
538+
# Annotation from with_sharding_constraint
539+
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)
540+
# Annotation from pjit
541+
self.assertIn('mhlo.sharding = "{replicated}"', lowered_text)
529542

530543
def testShardingConstraintPyTreeWithArray(self):
531544
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
@@ -544,13 +557,15 @@ def f(x):
544557
self.assertLen(out[0].addressable_shards, 2)
545558
self.assertLen(out[1].addressable_shards, 2)
546559

547-
hlo = f.lower(x).compiler_ir(dialect="hlo")
548-
# Annotations from with_sharding_constraint
549-
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
550-
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
560+
lowered_text = f.lower(x).as_text()
561+
if config.use_shardy_partitioner.value:
562+
# Annotation from with_sharding_constraint
563+
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
564+
else:
565+
# Annotations from with_sharding_constraint
566+
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)
551567

552568
def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self):
553-
554569
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
555570
@jax.jit
556571
def f(x):
@@ -571,17 +586,16 @@ def f(x):
571586
self.assertAllClose(actual, expected, check_dtypes=False)
572587
self.assertLen(actual[0]['a'].addressable_shards, 4)
573588

574-
mlir_str = str(f.lower(x).compiler_ir())
589+
lowered_text = f.lower(x).as_text()
575590
if config.use_shardy_partitioner.value:
576-
self.assertIn('<@mesh, [{?}, {"y"}, {}]>', mlir_str)
577-
self.assertIn('<@mesh, [{"x"}, {?}, {}]>', mlir_str)
591+
self.assertIn('<@mesh, [{?}, {"y"}, {}]>', lowered_text)
592+
self.assertIn('<@mesh, [{"x"}, {?}, {}]>', lowered_text)
578593
else:
579-
self.assertIn("unspecified_dims=[0]", mlir_str)
580-
self.assertIn("unspecified_dims=[1]", mlir_str)
594+
self.assertIn("unspecified_dims=[0]", lowered_text)
595+
self.assertIn("unspecified_dims=[1]", lowered_text)
581596

582597
@jtu.with_mesh([('x', 2), ('y', 2)])
583598
def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self):
584-
585599
@partial(pjit, in_shardings=None, out_shardings=None)
586600
def f(x):
587601
x = jax.vmap(lambda x: with_sharding_constraint(
@@ -595,13 +609,13 @@ def f(x):
595609
v = np.arange(math.prod(shape)).reshape(shape)
596610
x = [{'a': v, 'b': v * 2}, v * 3]
597611

598-
mlir_str = str(f.lower(x).compiler_ir())
612+
lowered_text = f.lower(x).as_text()
599613
if config.use_shardy_partitioner.value:
600-
self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', mlir_str)
601-
self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', mlir_str)
614+
self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', lowered_text)
615+
self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', lowered_text)
602616
else:
603-
self.assertIn("unspecified_dims=[0,1]", mlir_str)
604-
self.assertIn("unspecified_dims=[0,2]", mlir_str)
617+
self.assertIn("unspecified_dims=[0,1]", lowered_text)
618+
self.assertIn("unspecified_dims=[0,2]", lowered_text)
605619

606620
def testCaching(self):
607621
def f(x):

0 commit comments

Comments
 (0)