Skip to content

Commit 5a19663

Browse files
yf225facebook-github-bot
authored andcommitted
Allow customizing inputs for cross_entropy benchmark (#281)
Summary: Stacked PRs: * #285 * #284 * #283 * #282 * __->__#281 --- --- --- ### Allow customizing inputs for cross_entropy benchmark Pull Request resolved: #281 Reviewed By: oulgen Differential Revision: D78230194 Pulled By: yf225 fbshipit-source-id: 4717a97a440e5f78245e2da1da8609583b137012
1 parent e3f6db6 commit 5a19663

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

tritonbench/operators/cross_entropy/operator.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,34 @@
2020
# blob/main/benchmark/scripts/benchmark_cross_entropy.py
2121

2222

23+
def parse_op_args(args: List[str]):
24+
parser = argparse.ArgumentParser()
25+
parser.add_argument("--B", type=int, default=8, help="Batch size")
26+
parser.add_argument("--T", type=int, default=2048, help="Sequence length")
27+
parser.add_argument(
28+
"--v-range",
29+
type=str,
30+
default="12,18",
31+
help="Vocabulary size range as 'start,end' (e.g., '10,15' for 2^10 to 2^14)",
32+
)
33+
return parser.parse_args(args)
34+
35+
2336
class Operator(BenchmarkOperator):
2437
def __init__(
2538
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
2639
):
2740
super().__init__(tb_args, extra_args)
28-
self.B = 8
29-
self.T = 2048
41+
args = parse_op_args(self.extra_args)
42+
self.B = args.B
43+
self.T = args.T
44+
start, end = map(int, args.v_range.split(","))
45+
self.v_range = range(start, end)
3046
self.baseline_model = CrossEntropyLoss()
3147
self.liger_model = LigerCrossEntropyLoss()
3248

3349
def get_input_iter(self) -> Generator:
34-
for V in [2**i for i in range(12, 18)]:
50+
for V in [2**i for i in self.v_range]:
3551
_input = torch.randn(
3652
self.B * self.T,
3753
V,

0 commit comments

Comments
 (0)