
Commit 841d104

move float8 blockwise kernels out of prototype

Summary: These will be useful as a fallback path for a_1_128_w_128_128 (DeepSeek) scaling support for float8 inference. Bringing them out of the prototype folder.

Test Plan:
```
pytest test/kernel/test_blockwise_triton.py -s -x
python benchmarks/benchmark_blockwise_scaled_linear_triton.py
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 0fb2407
ghstack-comment-id: 3460951786
Pull-Request: #3256

1 parent 0ac0305 commit 841d104

File tree

5 files changed: +6 −5 lines changed

benchmarks/benchmark_blockwise_scaled_linear_triton.py (1 addition, 1 deletion)

```diff
@@ -13,7 +13,7 @@
 from triton.testing import do_bench
 
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
```

test/prototype/test_blockwise_triton.py renamed to test/kernel/test_blockwise_triton.py (1 addition, 1 deletion)

```diff
@@ -11,7 +11,7 @@
 
 triton = pytest.importorskip("triton", reason="Triton required to run this test")
 
-from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_dequant,
```

torchao/prototype/blockwise_fp8_inference/__init__.py (3 additions, 2 deletions)

```diff
@@ -1,11 +1,12 @@
-from .blockwise_linear import BlockwiseQuantLinear
-from .blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_dequant,
     fp8_blockwise_weight_quant,
 )
 
+from .blockwise_linear import BlockwiseQuantLinear
+
 __all__ = [
     "blockwise_fp8_gemm",
     "BlockwiseQuantLinear",
```
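The `__init__.py` change above keeps the old prototype import path working by re-exporting the kernels from their new `torchao.kernel` home. Below is a minimal, self-contained sketch of that re-export pattern; the module names (`mylib.kernel`, `mylib.prototype`) and the toy gemm are hypothetical stand-ins, not torchao's actual code.

```python
import sys
import types

# Hypothetical stand-in package; the real code uses torchao.kernel and
# torchao.prototype.blockwise_fp8_inference.
pkg = types.ModuleType("mylib")
pkg.__path__ = []  # mark as a package so submodule imports resolve
sys.modules["mylib"] = pkg

# "New" home of the kernel after the move.
kernel = types.ModuleType("mylib.kernel")

def blockwise_fp8_gemm(a, b):
    # Placeholder for the real Triton kernel: plain matmul on nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

kernel.blockwise_fp8_gemm = blockwise_fp8_gemm
sys.modules["mylib.kernel"] = kernel

# "Old" prototype module re-exports from the new location, as in the diff,
# so existing `from mylib.prototype import ...` call sites keep working.
prototype = types.ModuleType("mylib.prototype")
prototype.blockwise_fp8_gemm = kernel.blockwise_fp8_gemm
prototype.__all__ = ["blockwise_fp8_gemm"]
sys.modules["mylib.prototype"] = prototype

# Both import paths resolve to the same function object.
from mylib.kernel import blockwise_fp8_gemm as new_path
from mylib.prototype import blockwise_fp8_gemm as old_path
assert old_path is new_path
print(new_path([[1, 2]], [[3], [4]]))  # [[11]]
```

This is why the commit can move the kernels without breaking downstream users: callers importing from the prototype path silently get the relocated implementation.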

torchao/prototype/blockwise_fp8_inference/blockwise_linear.py (1 addition, 1 deletion)

```diff
@@ -7,7 +7,7 @@
 import torch
 from torch import nn
 
-from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
 )
```
