File tree Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Original file line number Diff line number Diff line change @@ -663,6 +663,13 @@ def process_weights_after_loading(self, layer):
663
663
1 , 2 ).contiguous ()
664
664
layer .w2_weight .data = layer .w2_weight .data .transpose (
665
665
1 , 2 ).contiguous ()
666
+ # This optimization relies on modifications in torch_npu; without them, accuracy
667
+ # problems will occur. Even so, we can still evaluate inference speed by casting
668
+ # the weights to the NZ format (29).
669
+ layer .w13_weight .data = torch_npu .npu_format_cast (
670
+ layer .w13_weight .data , 29 )
671
+ layer .w2_weight .data = torch_npu .npu_format_cast (
672
+ layer .w2_weight .data , 29 )
666
673
layer .w13_weight_scale .data = layer .w13_weight_scale .data .view (
667
674
layer .w13_weight_scale .data .shape [0 ], - 1 )
668
675
layer .w13_weight_offset .data = layer .w13_weight_offset .data .view (
You can’t perform that action at this time.
0 commit comments