Skip to content

#sdy Remove MHLO shardings from round-trip export #30091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 42 additions & 31 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,17 +463,17 @@ def f(x):
self.assertAllClose(np.asarray(actual.addressable_shards[0].data), expected,
check_dtypes=False)

hlo = f.lower(np.ones(shape)).compiler_ir()
lowered_text = f.lower(np.ones(shape)).as_text()
if config.use_shardy_partitioner.value:
# Annotation from with_sharding_constraint
self.assertIn('<@mesh, [{"x"}, {"y"}]>', str(hlo))
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
# Annotation from pjit
self.assertIn('sharding = #sdy.sharding<@mesh, [{}, {}]>}', str(hlo))
self.assertIn('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}', lowered_text)
else:
# Annotation from with_sharding_constraint
self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo))
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)
# Annotation from pjit
self.assertIn('sharding = "{replicated}"', str(hlo))
self.assertIn('mhlo.sharding = "{replicated}"', lowered_text)

def testShardingConstraintWithArray(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
Expand All @@ -494,11 +494,17 @@ def f(x):
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(actual, expected, check_dtypes=False)

hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
# Annotation from with_sharding_constraint
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
lowered_text = f.lower(np.ones(shape)).as_text()
if config.use_shardy_partitioner.value:
# Annotation from with_sharding_constraint
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
# Annotation from pjit
self.assertIn('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}', lowered_text)
else:
# Annotation from with_sharding_constraint
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)
# Annotation from pjit
self.assertIn('mhlo.sharding = "{replicated}"', lowered_text)

def testShardingConstraintWithArrayOpSharding(self):
shape = (8, 8)
Expand All @@ -521,11 +527,15 @@ def f(x):
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(actual, expected, check_dtypes=False)

hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
# Annotation from with_sharding_constraint
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
# Annotation from pjit
self.assertIn("sharding={replicated}", hlo.as_hlo_text())
lowered_text = f.lower(np.ones(shape)).as_text()
if config.use_shardy_partitioner.value:
# Annotation from pjit
self.assertIn('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}', lowered_text)
else:
# Annotation from with_sharding_constraint
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)
# Annotation from pjit
self.assertIn('mhlo.sharding = "{replicated}"', lowered_text)

def testShardingConstraintPyTreeWithArray(self):
mesh = jtu.create_mesh((2, 1), ('x', 'y'))
Expand All @@ -544,13 +554,15 @@ def f(x):
self.assertLen(out[0].addressable_shards, 2)
self.assertLen(out[1].addressable_shards, 2)

hlo = f.lower(x).compiler_ir(dialect="hlo")
# Annotations from with_sharding_constraint
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
self.assertIn('sharding={devices=[2,1]<=[2]}', hlo.as_hlo_text())
lowered_text = f.lower(x).as_text()
if config.use_shardy_partitioner.value:
# Annotation from with_sharding_constraint
self.assertIn('<@mesh, [{"x"}, {"y"}]>', lowered_text)
else:
# Annotations from with_sharding_constraint
self.assertIn('mhlo.sharding = "{devices=[2,1]<=[2]}"', lowered_text)

def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self):

mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jax.jit
def f(x):
Expand All @@ -571,17 +583,16 @@ def f(x):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertLen(actual[0]['a'].addressable_shards, 4)

mlir_str = str(f.lower(x).compiler_ir())
lowered_text = f.lower(x).as_text()
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {"y"}, {}]>', mlir_str)
self.assertIn('<@mesh, [{"x"}, {?}, {}]>', mlir_str)
self.assertIn('<@mesh, [{?}, {"y"}, {}]>', lowered_text)
self.assertIn('<@mesh, [{"x"}, {?}, {}]>', lowered_text)
else:
self.assertIn("unspecified_dims=[0]", mlir_str)
self.assertIn("unspecified_dims=[1]", mlir_str)
self.assertIn("unspecified_dims=[0]", lowered_text)
self.assertIn("unspecified_dims=[1]", lowered_text)

@jtu.with_mesh([('x', 2), ('y', 2)])
def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self):

@partial(pjit, in_shardings=None, out_shardings=None)
def f(x):
x = jax.vmap(lambda x: with_sharding_constraint(
Expand All @@ -595,13 +606,13 @@ def f(x):
v = np.arange(math.prod(shape)).reshape(shape)
x = [{'a': v, 'b': v * 2}, v * 3]

mlir_str = str(f.lower(x).compiler_ir())
lowered_text = f.lower(x).as_text()
if config.use_shardy_partitioner.value:
self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', mlir_str)
self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', mlir_str)
self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', lowered_text)
self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', lowered_text)
else:
self.assertIn("unspecified_dims=[0,1]", mlir_str)
self.assertIn("unspecified_dims=[0,2]", mlir_str)
self.assertIn("unspecified_dims=[0,1]", lowered_text)
self.assertIn("unspecified_dims=[0,2]", lowered_text)

def testCaching(self):
def f(x):
Expand Down
Loading