@@ -7,7 +7,6 @@
 from contextlib import nullcontext
 from datetime import datetime
 from itertools import product
-from types import SimpleNamespace
 from typing import Any, TypedDict
 
 import ray
@@ -43,7 +42,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
-    block_quant_shape: List[int] = None,
+    block_quant_shape: list[int] = None,
     use_deep_gemm: bool = False,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
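For context on the annotation change (an illustrative sketch, not part of the diff): since Python 3.9, PEP 585 lets built-in containers be subscripted directly, so `list[int]` works without `from typing import List`. The function below is hypothetical; only the parameter name `block_quant_shape` mirrors the diff.

```python
# Hypothetical example of PEP 585 built-in generics (Python 3.9+); the
# `X | None` union syntax additionally requires Python 3.10+.
def describe_block_shape(block_quant_shape: list[int] | None = None) -> str:
    # `block_quant_shape` stands in for the benchmark's optional
    # [block_n, block_k] quantization shape.
    if block_quant_shape is None:
        return "no block quantization"
    return "x".join(str(dim) for dim in block_quant_shape)


print(describe_block_shape())             # -> no block quantization
print(describe_block_shape([128, 128]))   # -> 128x128
```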
@@ -400,7 +399,7 @@ def benchmark(
         dtype: torch.dtype,
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
-        block_quant_shape: List[int] = None,
+        block_quant_shape: list[int] = None,
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
@@ -532,7 +531,7 @@ def save_configs(
     dtype: torch.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
-    block_quant_shape: List[int],
+    block_quant_shape: list[int],
 ) -> None:
     dtype_str = get_config_dtype_str(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -563,7 +562,6 @@ def main(args: argparse.Namespace):
     config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
     if args.model_prefix:
         config = getattr(config, args.model_prefix)
-    config = SimpleNamespace(**config)
 
     if config.architectures[0] == "DbrxForCausalLM":
         E = config.ffn_config.moe_num_experts
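A minimal sketch of what the removed line did, assuming `get_config()` now returns an object that already supports attribute access (for example a Hugging Face `PretrainedConfig`), which would make the extra wrapping redundant; this is an inference from the diff, not something it states.

```python
# SimpleNamespace(**d) turns a plain dict into an object with attribute
# access; the dict contents below are made up for illustration.
from types import SimpleNamespace

raw = {"architectures": ["MixtralForCausalLM"], "hidden_size": 4096}
cfg = SimpleNamespace(**raw)

assert cfg.hidden_size == 4096       # attribute access after wrapping
assert raw["hidden_size"] == 4096    # dict-style access before wrapping
```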
@@ -595,11 +593,7 @@ def main(args: argparse.Namespace):
     shard_intermediate_size = 2 * intermediate_size // args.tp_size
 
     hidden_size = config.hidden_size
-    dtype = (
-        torch.float16
-        if current_platform.is_rocm()
-        else getattr(torch, config.torch_dtype)
-    )
+    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
     block_quant_shape = get_weight_block_size_safety(config)