@@ -20,7 +20,8 @@
 import fastdeploy
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.layers.utils import (create_and_set_parameter,
+                                                    get_tensor)
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase
@@ -191,7 +192,7 @@ def apply(
 
         ffn2_input = paddle.incubate.nn.functional.swiglu(
             ffn1_out)
-
+
         ffn2_out = paddle.empty(
             (token_num * top_k, hidden_size),
             dtype=x.dtype,
@@ -484,3 +485,220 @@ def apply(
             tensor_model_parallel_all_reduce(out)
 
         return out
+
+class BlockWiseFP8MoEMethod(QuantMethodBase):
+    """
+    Fused MoE with block-wise FP8 quantized weights, computed with Triton
+    group GEMM kernels.
+    """
+
+    def __init__(self, quant_config):
+        """
+        Initialize the method with the block-wise FP8 quantization config.
+        """
+        self.quant_config = quant_config
+        self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
+        self.added_scale_attrs = [
+            "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
+        ]
+
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+        """Loading pre-quantized weights is not supported by this method."""
+
+        raise NotImplementedError()
+
+    def create_weights(self, layer: nn.Layer, state_dict):
+        """
+        Quantize the MoE FFN weights to block-wise FP8 and register them,
+        together with their scales, as parameters on the layer.
+        """
+        from fastdeploy.model_executor.layers.utils import \
+            per_block_cast_to_fp8
+
+        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
+
+        self.check(layer, ffn1_weights, ffn2_weights)
+
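+        # Quantize each expert independently, stack to [E, K, N], then
+        # transpose to [E, N, K], the layout the Triton group GEMM indexes
+        # via stride_be/stride_bn/stride_bk below.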
+        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
+            weight_name = self.added_weight_attrs[idx]
+            scale_name = self.added_scale_attrs[idx]
+
+            weight_list = []
+            weight_scale_list = []
+            for i in range(layer.num_local_experts):
+                quant_weight, scale = per_block_cast_to_fp8(
+                    weight_tensor[i], self.quant_config.weight_block_size)
+
+                weight_list.append(quant_weight)
+                weight_scale_list.append(scale)
+            quanted_weight = paddle.stack(weight_list, axis=0)
+            quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
+            create_and_set_parameter(layer, weight_name, quanted_weight)
+
+            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
+            quanted_weight_scale = quanted_weight_scale.transpose(
+                [0, 2, 1]).contiguous()
+            create_and_set_parameter(layer, scale_name, quanted_weight_scale)
+
+    def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
+        """
+        Check that the expert weight shapes are valid for this method.
+        """
+        assert ffn1_weights[0].shape == [
+            layer.hidden_size, layer.moe_intermediate_size * 2
+        ]
+        assert ffn2_weights[0].shape == [
+            layer.moe_intermediate_size, layer.hidden_size
+        ]
+
+    def apply(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Compute fused MoE with the block-wise FP8 Triton group GEMM kernels.
+        """
+
+        token_num = x.shape[0]
+        top_k = layer.top_k
+        num_local_experts = layer.num_local_experts
+        moe_intermediate_size = layer.moe_intermediate_size
+        hidden_size = layer.hidden_size
+        E, N1, _ = layer.moe_ffn1_weight.shape
+        N2 = layer.moe_ffn2_weight.shape[1]
+
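+        # Route each token to its top_k experts and get the gating weights.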
+        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+            gate_out,
+            layer.gate_correction_bias,
+            layer.top_k,
+            True,  # apply_norm_weight
+            False,
+        )
+
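+        # Kernel tile sizes; BLOCK_SIZE_N and BLOCK_SIZE_K are set to the
+        # FP8 weight quantization block size.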
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": self.quant_config.weight_block_size[1],
+            "BLOCK_SIZE_K": self.quant_config.weight_block_size[0],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
+        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
+
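+        # Sort the token->expert assignments by expert and pad each expert's
+        # segment to a multiple of BLOCK_SIZE_M, so every tile is
+        # single-expert.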
+        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
+            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
+        max_num_tokens_padded = sorted_token_ids.shape[0]
+
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
+
+        from .triton_moe_kernels import fused_moe_kernel_paddle
+
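+        # Dynamically quantize the activations to FP8 with per-token scales,
+        # grouped along K to match the weight quantization block size.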
+        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
+            x, self.quant_config.weight_block_size[0])
+
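+        # One flat buffer backs both GEMM outputs. The two views alias, which
+        # is safe here: cache1 is fully consumed by swiglu before cache3 is
+        # written.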
+        cache13 = paddle.empty([token_num * top_k * max(N1, N2)],
+                               dtype=x.dtype)
+        intermediate_cache1 = cache13[:token_num * top_k * N1].view(
+            [token_num * top_k, N1])
+        intermediate_cache3 = cache13[:token_num * top_k * N2].view(
+            [token_num * top_k, N2])
+
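+        # FFN1 group GEMM over the expert-sorted token blocks: x_q @ W1.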
+        fused_moe_kernel_paddle[grid](
+            x_q,
+            layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
+            intermediate_cache1,
+            x_scale,
+            layer.moe_ffn1_weight_scale,
+            None,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            max_num_tokens_padded,
+            token_num * top_k,
+            N=moe_intermediate_size * 2,
+            K=hidden_size,
+            stride_am=x_q.strides[0],
+            stride_ak=x_q.strides[1],
+            stride_be=layer.moe_ffn1_weight.strides[0],
+            stride_bk=layer.moe_ffn1_weight.strides[2],
+            stride_bn=layer.moe_ffn1_weight.strides[1],
+            stride_cm=intermediate_cache1.strides[0],
+            stride_cn=intermediate_cache1.strides[1],
+            #
+            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
+            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
+            stride_bse=layer.moe_ffn1_weight_scale.strides[0],
+            stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
+            stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
+            group_n=self.quant_config.weight_block_size[1],
+            group_k=self.quant_config.weight_block_size[0],
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=False,
+            top_k=top_k,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
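+        # SwiGLU activation halves the width: [M, 2 * inter] -> [M, inter].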
+        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
+            intermediate_cache1)
+
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
+
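+        # Re-quantize the activated intermediate for the FFN2 GEMM.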
+        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
+            intermediate_cache2, self.quant_config.weight_block_size[0])
+
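+        # FFN2 group GEMM. Rows are already expanded to token_num * top_k,
+        # so top_k=1 here; MUL_ROUTED_WEIGHT=True folds the gating weights in.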
+        fused_moe_kernel_paddle[grid](
+            x_q,
+            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
+            intermediate_cache3,
+            x_scale,
+            layer.moe_ffn2_weight_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            max_num_tokens_padded,
+            token_num * top_k,
+            N=hidden_size,
+            K=moe_intermediate_size,
+            stride_am=x_q.strides[0],
+            stride_ak=x_q.strides[1],
+            stride_be=layer.moe_ffn2_weight.strides[0],
+            stride_bk=layer.moe_ffn2_weight.strides[2],
+            stride_bn=layer.moe_ffn2_weight.strides[1],
+            stride_cm=intermediate_cache3.strides[0],
+            stride_cn=intermediate_cache3.strides[1],
+            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
+            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
+            stride_bse=layer.moe_ffn2_weight_scale.strides[0],
+            stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
+            stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
+            group_n=self.quant_config.weight_block_size[1],
+            group_k=self.quant_config.weight_block_size[0],
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=True,
+            top_k=1,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
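+        # Sum the top_k expert outputs for each token.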
+        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
+        out = intermediate_cache3.sum(axis=1)
+
+        if layer.tp_size > 1:
+            tensor_model_parallel_all_reduce(out)
+
+        return out