We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c509d84 commit 2474f1e — Copy full SHA for 2474f1e
tritonbench/operators/layer_norm/operator.py
@@ -34,6 +34,14 @@ def torch_layer_norm(self, *args):
34
35
@register_benchmark()
36
def torch_compile_layer_norm(self, *args):
37
+ # We need to run backward multiple times for proper benchmarking
38
+ # so donated buffers have to be disabled
39
+ if self.mode == Mode.BWD or self.mode == Mode.FWD_BWD:
40
+ import torch._functorch.config
41
+
42
+ torch._functorch.config.donated_buffer = False
43
+ import torch
44
45
@torch.compile
46
def inner(*args):
47
return F.layer_norm(*args)
0 commit comments