Skip to content

Commit bc30c2a

Browse files
authored
Fix MX + vllm (#2458)
stack-info: PR: #2458, branch: drisspg/stack/82
1 parent 6dfba04 commit bc30c2a

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

test/integration/test_vllm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
4242
from vllm import LLM, SamplingParams
4343

44+
from torchao.prototype.mx_formats import MXFPInferenceConfig
4445
from torchao.quantization.granularity import PerRow, PerTensor
4546
from torchao.quantization.quant_api import (
4647
CutlassInt4PackedLayout,
@@ -69,9 +70,7 @@ def get_tests() -> List[TorchAoConfig]:
6970
Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
7071
)
7172
]
72-
SM100_TESTS = [
73-
# TorchAoConfig(MXFPInferenceConfig())
74-
] # Failing for : https://github.com/pytorch/ao/issues/2239
73+
SM100_TESTS = [TorchAoConfig(MXFPInferenceConfig())]
7574

7675
# Check CUDA availability first
7776
if not torch.cuda.is_available():

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ def mx_slice(func, types, args, kwargs):
240240

241241
if dim == 0:
242242
# Slicing along the first dimension (rows) TODO assuming that dim 1 is reduciton dim for now
243-
sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step).flatten()
244-
sliced_data = aten.slice.Tensor(x._data, dim, start, end, step)
243+
sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step)
244+
sliced_data = aten.slice.Tensor(x._data, dim, start, end, step).unsqueeze(-1)
245245
elif dim == 1:
246246
# Slicing along reduciton dim
247247
if start is not None:
@@ -265,7 +265,7 @@ def mx_slice(func, types, args, kwargs):
265265
# Slice the scale tensor accordingly
266266
sliced_scale = aten.slice.Tensor(
267267
scale_shaped, 1, start_block, end_block, step
268-
).flatten()
268+
).unsqueeze(-1)
269269
else:
270270
raise ValueError(
271271
f"MXTensor only supports slicing along dimensions 0 and 1, got dim={dim}"

0 commit comments

Comments
 (0)