Skip to content

Commit c509d84

Browse files
xuzhao9 and facebook-github-bot
authored and committed
Fix backward pass on causal
Summary: For the Triton kernel, only the causal case is implemented on the backward pass (see details at D66574768). The FA3 backward pass implements both the causal and non-causal cases correctly.
Reviewed By: manman-ren, sijiac
Differential Revision: D66574710
fbshipit-source-id: 5117ea6c25066ca91aeed241e0e7a0dd2ae51b5d
1 parent deb9183 commit c509d84

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

tritonbench/kernels/triton_fused_attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,8 @@ def grid_tma(META):
18981898

18991899
@staticmethod
19001900
def backward(ctx, do):
1901+
if not ctx.causal:
1902+
raise NotImplementedError("only causal backward is implemented on Triton")
19011903
q, k, v, o, M = ctx.saved_tensors
19021904
assert do.is_contiguous()
19031905
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()

tritonbench/operators/flash_attention/operator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ def parse_op_args(args: List[str]):
144144
parser.add_argument("--seq-len", type=int, default=11, help="Batch size")
145145
parser.add_argument("--n-heads", type=int, default=48, help="Number of heads")
146146
parser.add_argument("--d-head", type=int, default=64, help="specify head dimension")
147-
parser.add_argument("--causal", action="store_true", help="enable causal")
147+
parser.add_argument(
148+
"--causal",
149+
action="store_true",
150+
help="enable causal (always true on backward)",
151+
)
148152
return parser.parse_args(args)
149153

150154

@@ -164,6 +168,10 @@ def __init__(
164168
self.D_HEAD = args.d_head
165169
self.N_CTX = None
166170
self.causal = args.causal
171+
# We always turn on causal for backward
172+
# Because Triton-Flash-V2 does not support backward with non-causal
173+
if self.mode == BenchmarkMode.BWD or self.mode == BenchmarkMode.FWD_BWD:
174+
self.causal = True
167175
self.sm_scale = 1.3
168176

169177
@register_benchmark()

0 commit comments

Comments
 (0)