Skip to content

Commit 53364b4

Browse files
gflegarjax authors
authored andcommitted
PiperOrigin-RevId: 614740360
1 parent 71ec6e3 commit 53364b4

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes):
492492
if consts:
493493
raise NotImplementedError("Associative scan with constants not supported.")
494494
element_types = [_element_type(arg.type) for arg in flat_args]
495-
scan_op = tt_dialect.ScanOp(flat_args, axis)
495+
scan_op = tt_dialect.ScanOp(flat_args, axis, reverse=False)
496496
param_types = element_types * 2
497497
entry = scan_op.regions[0].blocks.append(*param_types)
498498
with ir.InsertionPoint.at_block_begin(entry):

jaxlib/triton/dialect.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import annotations
2020

2121
from collections.abc import Sequence
22+
import inspect
2223

2324
from jaxlib.mlir._mlir_libs._triton_ext import (
2425
PointerType,
@@ -55,12 +56,18 @@ def __init__(
5556
self,
5657
operands: Sequence[ir.Value],
5758
axis: int,
59+
reverse: ir.Attribute,
5860
*,
5961
loc: ir.Location | None = None,
6062
ip: ir.InsertionPoint | None = None,
6163
):
6264
return_types = [op.type for op in operands]
63-
super().__init__(return_types, operands, axis, loc=loc, ip=ip)
65+
# OSS might have an old version of Triton, whose ScanOp doesn't have a
66+
# reverse parameter.
67+
if "reverse" in inspect.signature(super().__init__).parameters:
68+
super().__init__(return_types, operands, axis, reverse, loc=loc, ip=ip)
69+
else:
70+
super().__init__(return_types, operands, axis, loc=loc, ip=ip) # pylint: disable=no-value-for-parameter
6471

6572

6673
# TODO(slebedev): Consider overriding instead.

0 commit comments

Comments
 (0)