Skip to content

Commit 7513042

Browse files
metal lowbit kernels: add mps_ops_lib (#2059)
1 parent 88cd9c7 commit 7513042

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from torch import Tensor
9+
from torch.library import impl
10+
11+
torchao_lib = torch.library.Library("torchao", "IMPL")
12+
for nbit in range(1, 8):
13+
14+
@impl(torchao_lib, f"_linear_fp_act_{nbit}bit_weight", "Meta")
15+
def _(
16+
activations: Tensor,
17+
packed_weights: Tensor,
18+
group_size: int,
19+
scales: int,
20+
zeros: int,
21+
):
22+
assert activations.dtype in [torch.float32, torch.float16, torch.bfloat16]
23+
assert activations.is_contiguous()
24+
assert activations.dim() == 2
25+
26+
assert packed_weights.dtype == torch.uint8
27+
assert packed_weights.is_contiguous()
28+
29+
m = activations.size(0)
30+
k = activations.size(1)
31+
n = packed_weights.size(0)
32+
33+
assert k % 8 == 0
34+
assert n % 4 == 0
35+
36+
assert group_size in [32, 64, 128, 256]
37+
38+
assert scales.is_contiguous()
39+
assert scales.dim() == 2
40+
assert scales.size(1) == n
41+
42+
assert zeros.is_contiguous()
43+
assert zeros.dim() == 2
44+
assert zeros.size(1) == n
45+
46+
return torch.empty(m, n, dtype=activations.dtype, device="meta")

0 commit comments

Comments
 (0)