@@ -318,8 +318,8 @@ def invoke_moe_batched_triton_kernel(
318
318
expert_num_tokens : torch .Tensor , # [E]
319
319
compute_type : tl .dtype ,
320
320
# Quantization data
321
- A_scale : torch .Tensor , # Optional
322
- B_scale : torch .Tensor , # Optional
321
+ A_scale : Optional [ torch .Tensor ],
322
+ B_scale : Optional [ torch .Tensor ],
323
323
B_zp : torch .Tensor ,
324
324
# Quantization schemes
325
325
use_fp8_w8a8 : bool ,
@@ -453,61 +453,18 @@ def prepare(
453
453
dtype = b_type ,
454
454
device = a1 .device )
455
455
456
- if quant_config .quant_dtype is not None :
457
- if quant_config .block_shape is not None :
458
- _ , block_k = quant_config .block_shape
459
- k_tiles = (hidden_dim + block_k - 1 ) // block_k
460
- scale_shape = (num_local_experts , self .max_num_tokens , k_tiles )
461
- else :
462
- if quant_config .per_act_token_quant :
463
- num = self .max_num_tokens
464
- else :
465
- num = 1
466
- scale_shape = (num_local_experts , num , 1 )
456
+ b_a1_scale = None
467
457
468
- #print(f"SCALE_SHAPE {block_shape} {b_a1.shape} {scale_shape}")
469
-
470
- b_a1_scale = torch .zeros (scale_shape ,
471
- dtype = torch .float32 ,
472
- device = a1 .device )
473
- else :
474
- assert a1_scale is None
475
- b_a1_scale = None
458
+ assert quant_config .quant_dtype is None , "quantization NYI"
476
459
477
460
first_expert = num_local_experts * self .rank
478
461
last_expert = first_expert + num_local_experts
479
462
480
463
for expert_id in range (first_expert , last_expert ):
481
464
topks = torch .any (topk_ids == expert_id , dim = 1 ).flatten ()
482
465
rows = torch .count_nonzero (topks .flatten ())
483
- rhs = a1 [:topks .numel ()][topks ]
484
466
idx = expert_id - first_expert
485
- if quant_config .quant_dtype is not None :
486
- if a1_scale is not None :
487
- assert False , "NYI"
488
- rhs_a1_scale = a1_scale [:topks .numel ()][topks ]
489
- else :
490
- rhs_a1_scale = None
491
- b_a1 [idx , :rows , :], b_s = moe_kernel_quantize_input (
492
- rhs ,
493
- rhs_a1_scale ,
494
- quant_config .quant_dtype ,
495
- quant_config .per_act_token_quant ,
496
- quant_config .block_shape ,
497
- )
498
- assert b_s is not None
499
- if (quant_config .block_shape is None
500
- and not quant_config .per_act_token_quant ):
501
- print (f"SCALE { idx } , { b_a1_scale [idx , :].shape } { b_s .shape } " )
502
- b_a1_scale [idx , :] = b_s
503
- else :
504
- #print(f"XXXXX rhs={rhs.shape} b_s={b_s.shape}")
505
- assert rows == b_s .shape [0 ] and b_a1_scale .shape [
506
- - 1 ] == b_s .shape [- 1 ]
507
- b_a1_scale [idx , :rows ] = b_s
508
- else :
509
- b_a1 [idx , :rows , :] = rhs
510
-
467
+ b_a1 [idx , :rows , :] = a1 [:topks .numel ()][topks ]
511
468
tokens_per_expert [idx ] = rows
512
469
513
470
assert b_a1_scale is None or b_a1_scale .ndim == 3
0 commit comments