Skip to content

Commit 0f1cb74

Browse files
naummojax authors
authored andcommitted
Prevent the XLA compiler from sharding the custom call in favour of Mosaic sharding based on user annotations.
PiperOrigin-RevId: 614336455
1 parent 632d095 commit 0f1cb74

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def to_json(self) -> bytes:
141141
if i + 1 != len(self.flags):
142142
config.write(b",")
143143
config.write(b"]")
144+
# Prevent the compiler from sharding the custom call beyond what Mosaic does
145+
# based on user annotations
146+
config.write(b', "implicit_sharding": {"type": "MANUAL"}')
144147
config.write(b"}")
145148
return config.getvalue()
146149

0 commit comments

Comments
 (0)