Skip to content

Commit c666f87

Browse files
adamomainz authored and
facebook-github-bot committed
adding param to shuffle production data shapes
Summary: TSIA. Right now this is only done for production shapes. Noticed that we sometimes error out and do not run all the shapes. Since we run multiple times a day, randomly shuffling the shapes and aggregating over the day will produce a more stable output. Reviewed By: danzimm, xuzhao9. Differential Revision: D66519495. fbshipit-source-id: 56993a8bb196174e05c4224bd386190c18883603
1 parent 0ca9f40 commit c666f87

File tree

4 files changed

+12
-3
lines changed

4 files changed

+12
-3
lines changed

tritonbench/operators/flash_attention/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def __additional_example_input(self, standard_shapes: Generator) -> Generator:
507507
shapes = chain(
508508
shapes,
509509
productionDataLoader.get_shapes_from_frozen_durin(
510-
self.name, "attention"
510+
self.name, "attention", shuffle_shapes=self.tb_args.shuffle_shapes
511511
),
512512
)
513513
return shapes

tritonbench/operators/gemm/operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def __init__(
144144
gemm_args = parse_args(self.extra_args)
145145
self.layout = gemm_args.layout
146146
if IS_FBCODE and tb_args.production_shapes:
147-
self.shapes = get_production_shapes(self.name, f"{tb_args.precision}_gemm")
147+
self.shapes = get_production_shapes(
148+
self.name, f"{tb_args.precision}_gemm", self.tb_args.shuffle_shapes
149+
)
148150
elif gemm_args.input:
149151
self.shapes = read_shapes_from_csv(gemm_args.input)
150152
elif gemm_args.splitk:

tritonbench/operators/softmax/operator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def get_input_iter(self):
107107
M = 4096
108108
shapes = [(M, 128 * i) for i in range(2, 100)]
109109
if IS_FBCODE and self.tb_args.production_shapes:
110-
shapes = get_production_shapes(self.name, "softmax")
110+
shapes = get_production_shapes(
111+
self.name, "softmax", self.tb_args.shuffle_shapes
112+
)
111113
for M, N in shapes:
112114
yield (torch.randn([M, N], dtype=self.dtype, device=self.device),)
113115

tritonbench/utils/parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ def get_parser(args=None):
165165
action="store_true",
166166
help="bypass and continue on operator failure.",
167167
)
168+
parser.add_argument(
169+
"--shuffle-shapes",
170+
action="store_true",
171+
help="when true randomly shuffles the inputs before running benchmarks where possible.",
172+
)
168173

169174
if IS_FBCODE:
170175
parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")

0 commit comments

Comments
 (0)