Skip to content

Commit 4589b94

Browse files
authored
[Bugfix] Fix benchmark_moe.py (#19016)
Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn>
1 parent cc867be commit 4589b94

File tree

1 file changed

+4
-10
lines changed

1 file changed

+4
-10
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 4 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -7,7 +7,6 @@
77
from contextlib import nullcontext
88
from datetime import datetime
99
from itertools import product
10-
from types import SimpleNamespace
1110
from typing import Any, TypedDict
1211

1312
import ray
@@ -43,7 +42,7 @@ def benchmark_config(
4342
use_fp8_w8a8: bool,
4443
use_int8_w8a16: bool,
4544
num_iters: int = 100,
46-
block_quant_shape: List[int] = None,
45+
block_quant_shape: list[int] = None,
4746
use_deep_gemm: bool = False,
4847
) -> float:
4948
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
@@ -400,7 +399,7 @@ def benchmark(
400399
dtype: torch.dtype,
401400
use_fp8_w8a8: bool,
402401
use_int8_w8a16: bool,
403-
block_quant_shape: List[int] = None,
402+
block_quant_shape: list[int] = None,
404403
use_deep_gemm: bool = False,
405404
) -> tuple[dict[str, int], float]:
406405
current_platform.seed_everything(self.seed)
@@ -532,7 +531,7 @@ def save_configs(
532531
dtype: torch.dtype,
533532
use_fp8_w8a8: bool,
534533
use_int8_w8a16: bool,
535-
block_quant_shape: List[int],
534+
block_quant_shape: list[int],
536535
) -> None:
537536
dtype_str = get_config_dtype_str(
538537
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
563562
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
564563
if args.model_prefix:
565564
config = getattr(config, args.model_prefix)
566-
config = SimpleNamespace(**config)
567565

568566
if config.architectures[0] == "DbrxForCausalLM":
569567
E = config.ffn_config.moe_num_experts
@@ -595,11 +593,7 @@ def main(args: argparse.Namespace):
595593
shard_intermediate_size = 2 * intermediate_size // args.tp_size
596594

597595
hidden_size = config.hidden_size
598-
dtype = (
599-
torch.float16
600-
if current_platform.is_rocm()
601-
else getattr(torch, config.torch_dtype)
602-
)
596+
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
603597
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
604598
use_int8_w8a16 = args.dtype == "int8_w8a16"
605599
block_quant_shape = get_weight_block_size_safety(config)

0 commit comments

Comments (0)