Skip to content

Commit 778933d

Browse files
superbobryjax authors
authored andcommitted
Removed inspect.signature() call from jaxlib.triton.dialect.ScanOp
PiperOrigin-RevId: 614772594
1 parent 93e5bbe commit 778933d

File tree

2 files changed

+3
-9
lines changed

2 files changed

+3
-9
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes):
492492
if consts:
493493
raise NotImplementedError("Associative scan with constants not supported.")
494494
element_types = [_element_type(arg.type) for arg in flat_args]
495-
scan_op = tt_dialect.ScanOp(flat_args, axis, reverse=False)
495+
scan_op = tt_dialect.ScanOp(flat_args, axis)
496496
param_types = element_types * 2
497497
entry = scan_op.regions[0].blocks.append(*param_types)
498498
with ir.InsertionPoint.at_block_begin(entry):

jaxlib/triton/dialect.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from __future__ import annotations
2020

2121
from collections.abc import Sequence
22-
import inspect
2322

2423
from jaxlib.mlir._mlir_libs._triton_ext import (
2524
PointerType,
@@ -56,18 +55,13 @@ def __init__(
5655
self,
5756
operands: Sequence[ir.Value],
5857
axis: int,
59-
reverse: ir.Attribute,
58+
reverse: bool = False,
6059
*,
6160
loc: ir.Location | None = None,
6261
ip: ir.InsertionPoint | None = None,
6362
):
6463
return_types = [op.type for op in operands]
65-
# OSS might have an old version of Triton, whose ScanOp doesn't have a
66-
# reverse parameter.
67-
if "reverse" in inspect.signature(super().__init__).parameters:
68-
super().__init__(return_types, operands, axis, reverse, loc=loc, ip=ip)
69-
else:
70-
super().__init__(return_types, operands, axis, loc=loc, ip=ip) # pylint: disable=no-value-for-parameter
64+
super().__init__(return_types, operands, axis, reverse, loc=loc, ip=ip)
7165

7266

7367
# TODO(slebedev): Consider overriding instead.

0 commit comments

Comments
 (0)