Skip to content

Commit 8a11659

Browse files
vwbakerGoogle-ML-Automation
authored andcommitted
Always lower ragged dot for cpu, gpu, and tpu
PiperOrigin-RevId: 780546970
1 parent 719ea96 commit 8a11659

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

jax/_src/lax/lax.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6372,16 +6372,11 @@ def use_default_lowering():
63726372
return [result]
63736373

63746374

6375-
mlir.register_lowering(ragged_dot_general_p,
6376-
mlir.lower_fun(_ragged_dot_general_impl,
6377-
multiple_results=False))
6378-
6379-
for platform in ['tpu']:
6380-
mlir.register_lowering(
6381-
ragged_dot_general_p,
6382-
partial(_ragged_dot_general_lower, platform=platform),
6383-
platform=platform,
6384-
)
6375+
mlir.register_lowering(
6376+
ragged_dot_general_p,
6377+
partial(_ragged_dot_general_lower, platform=platform),
6378+
platform=platform,
6379+
)
63856380

63866381

63876382
def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions,

0 commit comments

Comments
 (0)