
Commit 6b80c5a

GDzhu01qyqc731 and tianyitang authored

Fix W8A8 fused moe bug (#1529)

### What this PR does / why we need it?
1. Drop some useless code for the W8A8 fused MoE path.
2. Add an int8 KV cache check.
3. Add more unit tests.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with the newly added tests.

---------

Signed-off-by: zhuyilin <809721801@qq.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>

1 parent 7fc1a98 commit 6b80c5a
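
For context, item 2 is a fail-fast configuration guard. The sketch below is a hypothetical illustration of that kind of check, not the patch itself; the function name is invented, and only the config keys (`kv_quant_type`, `fa_quant_type`) and the `"C8"` value come from the tests added in this commit.

```python
# Hypothetical sketch of the kind of int8 KV cache check item 2 describes;
# the function name is invented. Only the config keys ("kv_quant_type",
# "fa_quant_type") and the "C8" value appear in the tests added below.
def validate_int8_kv_cache(kv_cache_dtype: str, quant_description: dict) -> None:
    """Fail fast when int8 KV cache is requested without C8 quantization."""
    if kv_cache_dtype != "int8":
        return
    if "C8" not in (quant_description.get("kv_quant_type"),
                    quant_description.get("fa_quant_type")):
        raise NotImplementedError(
            "int8 KV cache requires a model quantized with kv_quant_type=C8")
```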

File tree

8 files changed: +1623 -53 lines changed


tests/ut/attention/test_attention_v1.py

Lines changed: 499 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import] # isort: skip # noqa

from unittest.mock import MagicMock, patch

import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)

from tests.ut.base import TestBase
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
                                                   AscendQuantConfig)

ASCEND_QUATIZATION_METHOD = "ascend"


class TestAscendQuantConfig(TestBase):

    def setUp(self):
        self.sample_config = {
            "weight": "INT8",
            "fa_quant_type": "C8",
            "kv_quant_type": "C8",
            "layer1.weight": "INT8",
            "layer2.weight": "FLOAT",
            "fused_layer.weight": "FLOAT",
            "fused_layer.shard1.weight": "FLOAT",
            "fused_layer.shard2.weight": "FLOAT",
            "shard1.weight": "FLOAT",
            "shard2.weight": "FLOAT",
        }
        self.ascend_config = AscendQuantConfig(self.sample_config)
        self.ascend_config.packed_modules_mapping = None

    def test_init(self):
        self.assertEqual(self.ascend_config.quant_description,
                         self.sample_config)

    def test_repr(self):
        repr_str = repr(self.ascend_config)
        self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))

    def test_get_name(self):
        self.assertEqual(AscendQuantConfig.get_name(),
                         ASCEND_QUATIZATION_METHOD)

    def test_get_supported_act_dtypes(self):
        supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
        self.assertEqual(len(supported_dtypes), 3)

    def test_get_min_capability(self):
        with self.assertRaises(NotImplementedError):
            AscendQuantConfig.get_min_capability()

    def test_get_config_filenames(self):
        filenames = AscendQuantConfig.get_config_filenames()
        self.assertEqual(filenames, ["quant_model_description.json"])

    def test_from_config(self):
        config = AscendQuantConfig.from_config(self.sample_config)
        self.assertIsInstance(config, AscendQuantConfig)
        self.assertEqual(config.quant_description, self.sample_config)

    @patch('torch.npu.is_available')
    def test_override_quantization_method(self, mock_is_available):
        # Test when NPU is available
        mock_is_available.return_value = True
        result = AscendQuantConfig.override_quantization_method(None, None)
        self.assertEqual(result, ASCEND_QUATIZATION_METHOD)

        # Test when NPU is not available
        mock_is_available.return_value = False
        result = AscendQuantConfig.override_quantization_method(None, None)
        self.assertIsNone(result)

    def test_get_quant_method_for_linear(self):
        linear_layer = MagicMock(spec=LinearBase)
        # Test skipped layer
        with patch.object(self.ascend_config,
                          'is_layer_skipped_ascend',
                          return_value=True):
            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIsInstance(method, UnquantizedLinearMethod)

        # Test quantized layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
             patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:

            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIs(method, mock_ascend_linear.return_value)
            mock_ascend_linear.assert_called_once_with(
                self.ascend_config, ".attn",
                self.ascend_config.packed_modules_mapping)

    def test_get_quant_method_for_attention(self):
        attention_layer = MagicMock(spec=Attention)
        with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
                   return_value=MagicMock()) as mock_ascend_kvcache:
            # Test with fa_quant_type
            method = self.ascend_config.get_quant_method(
                attention_layer, ".attn")
            self.assertIs(method, mock_ascend_kvcache.return_value)

        with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
                   return_value=MagicMock()) as mock_ascend_kvcache:
            # Test with kv_quant_type
            modified_config = {"kv_quant_type": "C8"}
            config = AscendQuantConfig(modified_config)
            config.packed_modules_mapping = None
            method = config.get_quant_method(attention_layer, "attn")
            self.assertIs(method, mock_ascend_kvcache.return_value)

    def test_get_quant_method_for_fused_moe(self):
        fused_moe_layer = MagicMock(spec=FusedMoE)

        # Test skipped layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
             patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
            method = self.ascend_config.get_quant_method(
                fused_moe_layer, "moe_layer")
            self.assertIs(method, mock_ascend_moe.return_value)

        # Test quantized layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
             patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
            method = self.ascend_config.get_quant_method(
                fused_moe_layer, "moe_layer")
            self.assertIs(method, mock_ascend_moe.return_value)

    def test_is_layer_skipped_ascend(self):
        # Test non-fused layer that should be quantized
        self.assertFalse(self.ascend_config.is_layer_skipped_ascend("layer1"))

        # Test non-fused layer that should be skipped
        self.assertTrue(self.ascend_config.is_layer_skipped_ascend("layer2"))

        # Test fused layer
        fused_mapping = {"fused_layer": ["shard1", "shard2"]}
        self.assertTrue(
            self.ascend_config.is_layer_skipped_ascend("fused_layer",
                                                       fused_mapping))

        # Test inconsistent fused layer shards
        bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
        config = AscendQuantConfig(bad_config)
        with self.assertRaises(ValueError):
            config.is_layer_skipped_ascend("fused_layer", fused_mapping)

    def test_get_scaled_act_names(self):
        self.assertEqual(self.ascend_config.get_scaled_act_names(), [])


class TestAscendKVCacheMethod(TestBase):

    def setUp(self):
        # Setup common test fixtures
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"some_config": "value"}
        self.prefix = "attention_layer"

        # Mock the quantizer and quant_method
        self.mock_quantizer = MagicMock()
        self.mock_quant_method = MagicMock()

        # Patch the AscendQuantizer
        self.quantizer_patcher = patch(
            'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
            return_value=self.mock_quantizer)
        self.mock_get_quantizer = self.quantizer_patcher.start()

        self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method

        # Create instance
        self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
                                                   self.prefix)

    def tearDown(self):
        self.quantizer_patcher.stop()

    def test_init(self):
        """Test initialization with proper quantizer setup."""
        self.mock_get_quantizer.assert_called_once_with(
            self.mock_quant_config.quant_description, self.prefix)
        self.mock_quantizer.build_attention_method.assert_called_once()

    def test_create_weights(self):
        """Test create_weights delegates to quant_method."""
        mock_layer = MagicMock()
        self.kv_cache_method.create_weights(mock_layer)
        self.mock_quant_method.create_weights.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_with_method(self):
        """Test process_weights when quant_method has the method."""
        mock_layer = MagicMock()
        self.kv_cache_method.process_weights_after_loading(mock_layer)
        self.mock_quant_method.process_weights_after_loading.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_without_method(self):
        """Test process_weights when quant_method lacks the method."""
        # Reset mock to remove the method
        del self.mock_quant_method.process_weights_after_loading
        mock_layer = MagicMock()

        # Should not raise exception
        self.kv_cache_method.process_weights_after_loading(mock_layer)

    def test_apply_delegation(self):
        """Test apply properly delegates to quant_method."""
        mock_layer = MagicMock()
        mock_query = torch.randn(1, 32, 128)
        mock_key = torch.randn(1, 32, 128)
        mock_value = torch.randn(1, 32, 128)
        mock_kv_cache = MagicMock()
        mock_attn_metadata = MagicMock()
        mock_scale = 1.0
        mock_output = torch.zeros(1, 32, 128)
        mock_attn_type = MagicMock()
        expected_result = torch.randn(1, 32, 128)
        self.mock_quant_method.apply.return_value = expected_result

        result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key,
                                            mock_value, mock_kv_cache,
                                            mock_attn_metadata, mock_attn_type,
                                            mock_scale, mock_output)

        self.mock_quant_method.apply.assert_called_once_with(
            mock_layer, mock_query, mock_key, mock_value, mock_kv_cache,
            mock_attn_metadata, mock_attn_type, mock_scale, mock_output)
        self.assertTrue(torch.equal(result, expected_result))
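
Two details worth noting in `TestAscendKVCacheMethod`: the concrete quant method is built by a quantizer rather than constructed directly, and `process_weights_after_loading` must not fail when that method lacks the hook. A minimal sketch of the guarded delegation those tests pin down, with assumed class and attribute names (only the observable behavior comes from the tests):

```python
# Minimal sketch of the guarded delegation verified above; class and attribute
# names are assumed, only the observable behavior comes from the tests.
class DelegatingKVCacheMethod:

    def __init__(self, quantizer):
        # Built via the quantizer, mirroring build_attention_method() in setUp.
        self.quant_method = quantizer.build_attention_method()

    def create_weights(self, layer):
        self.quant_method.create_weights(layer)

    def process_weights_after_loading(self, layer):
        # Tolerate backends without the hook, as
        # test_process_weights_after_loading_without_method requires.
        hook = getattr(self.quant_method, "process_weights_after_loading", None)
        if callable(hook):
            hook(layer)
```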
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import vllm_ascend.patch.worker.patch_common.patch_utils  # type: ignore[import] # isort: skip # noqa

from unittest.mock import MagicMock, patch

from tests.ut.base import TestBase
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
                                                W8A8Quantizer)

SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}


class TestGetQuantizer(TestBase):

    def setUp(self):
        # Setup common test fixtures
        self.supported_types = {
            'INT8': MagicMock(_instance=None),
            'FP16': MagicMock(_instance=None),
            'C8': MagicMock(_instance=None)
        }
        self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"some_config": "value"}

    def tearDown(self):
        # Restore original supported types
        SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)

    def test_get_quantizer_fa(self):
        """Test successful quantizer retrieval for different cases."""
        # Setup
        quant_description = {'fa_quant_type': 'C8'}
        prefix = '.attn'
        expected_type = 'C8'
        with patch.dict(
                'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
                SUPPORT_ASCEND_QUANTIZER_TYPE):

            result = VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result,
                         self.supported_types[expected_type]._instance)
        self.supported_types[expected_type].assert_called_once_with(
            quant_description)

    def test_get_quantizer_kv(self):
        """Test successful quantizer retrieval for different cases."""
        # Setup
        quant_description = {'kv_quant_type': 'C8'}
        prefix = '.attn'
        expected_type = 'C8'
        with patch.dict(
                'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
                SUPPORT_ASCEND_QUANTIZER_TYPE):

            result = VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result,
                         self.supported_types[expected_type]._instance)
        self.supported_types[expected_type].assert_called_once_with(
            quant_description)

    def test_get_quantizer_linear(self):
        """Test successful quantizer retrieval for different cases."""
        # Setup
        quant_description = {'linear_type': 'INT8'}
        prefix = 'nothing'
        expected_type = 'INT8'
        with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
                   return_value=expected_type), \
             patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE):

            result = VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

        # Verify
        self.assertIsNotNone(result)
        self.assertEqual(result,
                         self.supported_types[expected_type]._instance)
        self.supported_types[expected_type].assert_called_once_with(
            quant_description)


class TestW8A8Quantizer(TestBase):

    def setUp(self):
        self.quantizer = W8A8Quantizer(quant_description={})

    def test_build_linear_method(self):
        with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
                   return_value=MagicMock()) as mock_linear:
            result = self.quantizer.build_linear_method()
            mock_linear.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)

    def test_build_moe_method(self):
        with patch(
                'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
                return_value=MagicMock()) as mock_linear:
            result = self.quantizer.build_moe_method()
            mock_linear.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)

    def test_build_attention_method(self):
        with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
                   return_value=MagicMock()) as mock_linear:
            result = self.quantizer.build_attention_method()
            mock_linear.assert_called_once_with()
            self.assertIsInstance(result, MagicMock)
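
Taken together, the three `TestGetQuantizer` cases imply the lookup contract for `get_quantizer`: pick a type key from the quant description (`fa_quant_type` or `kv_quant_type` for attention prefixes, a linear type otherwise), resolve the class from `SUPPORT_ASCEND_QUANTIZER_TYPE`, and build at most one instance per class. A minimal sketch inferred from the assertions, not from the vllm_ascend source:

```python
# Minimal sketch inferred from the assertions above, not the real
# implementation: resolve the quantizer class and memoize one instance on it,
# which is why the tests compare the result against cls._instance.
def get_quantizer_sketch(quant_description, prefix, support_types,
                         get_linear_quant_type):
    if '.attn' in prefix and 'fa_quant_type' in quant_description:
        quant_type = quant_description['fa_quant_type']
    elif '.attn' in prefix and 'kv_quant_type' in quant_description:
        quant_type = quant_description['kv_quant_type']
    else:
        quant_type = get_linear_quant_type(quant_description, prefix)
    cls = support_types[quant_type]
    if getattr(cls, '_instance', None) is None:
        cls._instance = cls(quant_description)  # constructed once per type
    return cls._instance
```

Under this reading, each test only has to assert that the class was called exactly once with the quant description and that the cached `_instance` is returned.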
