Commit 26ade98

Update
[ghstack-poisoned]
2 parents: 681277a + f76e10b

4 files changed: +25, -9

.github/workflows/1xL4_tests.yml

Lines changed: 1 addition & 0 deletions

@@ -51,3 +51,4 @@ jobs:
     pytest test/dtypes/test_affine_quantized_float.py --verbose -s
     ./test/float8/test_everything_single_gpu.sh
     python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+    python test/kernel/test_blockwise_triton.py --verbose -s

torchao/float8/inference.py

Lines changed: 12 additions & 6 deletions

@@ -217,6 +217,16 @@ def _is_128_128_scaled(x: torch.Tensor) -> bool:
     return len(b) == 2 and b[0] == 128 and b[1] == 128
 
 
+def _granularity_is_a_1_128_w_128_128(
+    g: Union[
+        FP8Granularity,
+        Tuple[FP8Granularity, FP8Granularity],
+        list[FP8Granularity],
+    ],
+) -> bool:
+    return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
+
+
 def _normalize_granularity(
     granularity: Optional[
         Union[
@@ -238,9 +248,7 @@ def _normalize_granularity(
     is_per_row = isinstance(granularity[0], PerRow) and isinstance(
         granularity[1], PerRow
     )
-    is_a_1_128_w_128_128 = granularity[0] == PerBlock((1, 128)) and granularity[
-        1
-    ] == PerBlock((128, 128))
+    is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
 
     if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
         raise ValueError(f"Unsupported granularity types: {granularity}.")
@@ -273,9 +281,7 @@ def _check_hardware_support(
     is_per_row = isinstance(granularities[0], PerRow) and isinstance(
         granularities[1], PerRow
     )
-    is_a_1_128_w_128_128 = granularities[0] == PerBlock((1, 128)) and granularities[
-        1
-    ] == PerBlock((128, 128))
+    is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)
 
     if is_per_tensor or is_per_row:
         assert is_sm_at_least_89() or is_MI300(), (
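
For illustration, a minimal standalone sketch of the extracted predicate; the PerBlock below is a simplified stand-in for torchao's granularity class (assumed to compare by block size), not the real implementation:

# Standalone sketch of the extracted predicate. This PerBlock is a
# simplified stand-in for torchao's granularity class (assumed to
# compare by block size); it is not the real implementation.
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class PerBlock:
    block_size: Tuple[int, int]


def _granularity_is_a_1_128_w_128_128(g: Sequence[PerBlock]) -> bool:
    # True only for the ordered pair: activation scaled per (1, 128)
    # block, weight scaled per (128, 128) block.
    return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))


# The pair is ordered (activation first, weight second), so the
# reversed pair does not match.
assert _granularity_is_a_1_128_w_128_128([PerBlock((1, 128)), PerBlock((128, 128))])
assert not _granularity_is_a_1_128_w_128_128([PerBlock((128, 128)), PerBlock((1, 128))])

Factoring the comparison into one helper keeps _normalize_granularity and _check_hardware_support from drifting apart, since both now share a single definition of the blockwise pair.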

torchao/quantization/quant_api.py

Lines changed: 12 additions & 2 deletions

@@ -62,6 +62,7 @@
     Float8MMConfig,
     FP8Granularity,
     _check_hardware_support,
+    _granularity_is_a_1_128_w_128_128,
     _normalize_granularity,
 )
 from torchao.quantization.linear_activation_weight_observed_tensor import (
@@ -1770,13 +1771,22 @@ def __post_init__(self):
         torch._C._log_api_usage_once(
             "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
         )
-        if self.mm_config is None:
-            self.mm_config = Float8MMConfig(use_fast_accum=True)
         activation_granularity, weight_granularity = _normalize_granularity(
             self.granularity
         )
         self.granularity = [activation_granularity, weight_granularity]
 
+        default_use_fast_accum = True
+        if _granularity_is_a_1_128_w_128_128(self.granularity):
+            assert self.activation_value_lb is None, "unimplemented"
+            assert self.activation_value_ub is None, "unimplemented"
+            assert self.kernel_preference is KernelPreference.TORCH, "unimplemented"
+            assert self.mm_config is None, "unimplemented"
+            default_use_fast_accum = False
+
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum)
+
 
 # for bc
 float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
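
For illustration, a minimal standalone sketch of the new default-selection behavior; PerBlock and Float8MMConfig are simplified stand-ins for the torchao classes, and pick_mm_config is a hypothetical helper that isolates the logic rather than part of torchao:

# Standalone sketch of the new mm_config default. PerBlock and
# Float8MMConfig are simplified stand-ins for the torchao classes, and
# pick_mm_config is a hypothetical helper, not part of torchao.
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class PerBlock:
    block_size: Tuple[int, int]


@dataclass
class Float8MMConfig:
    use_fast_accum: bool = False


def pick_mm_config(
    granularity: Sequence[PerBlock], mm_config: Optional[Float8MMConfig]
) -> Float8MMConfig:
    # An explicit mm_config wins (for blockwise scaling the real
    # __post_init__ asserts it is None, since that path is unimplemented).
    if mm_config is not None:
        return mm_config
    # Blockwise (1, 128) x (128, 128) scaling now defaults to
    # use_fast_accum=False; every other granularity keeps the old
    # default of True.
    is_blockwise = (
        len(granularity) == 2
        and granularity[0] == PerBlock((1, 128))
        and granularity[1] == PerBlock((128, 128))
    )
    return Float8MMConfig(use_fast_accum=not is_blockwise)


blockwise = [PerBlock((1, 128)), PerBlock((128, 128))]
assert pick_mm_config(blockwise, None).use_fast_accum is False
assert pick_mm_config([PerBlock((1, 1))], None).use_fast_accum is True

In the real __post_init__, the blockwise path also asserts that activation_value_lb and activation_value_ub are unset and that kernel_preference is KernelPreference.TORCH, since other combinations are unimplemented.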

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 0 additions & 1 deletion

@@ -352,7 +352,6 @@ def _(func, types, args, kwargs):
     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
     if _is_128_128_scaled(weight_tensor):
-        # TODO(before land): ensure fast_accum is False for blockwise
         # TODO(future PR): add testing for torch._scaled_mm with
         # blockwise scaling on CUDA 12.9
         # TODO(future PR): add fbgemm_gpu_genai path if available
