
Commit e9644da

[mps] Add offsets to enable aoti (#2484)
* Update [ghstack-poisoned]
* Update (base update) [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 61d49d4 commit e9644da

File tree: 5 files changed (+71, −35)

torchao/experimental/kernels/mps/src/lowbit.h

Lines changed: 20 additions & 20 deletions
@@ -73,11 +73,11 @@ using DispatchFn =
     void (*)(id<MTLComputeCommandEncoder>, int32_t, int32_t, int32_t, int32_t);
 
 inline void linear_lowbit_quant_weights_mps_impl(
-    id<MTLBuffer> a_buf,
-    id<MTLBuffer> b_buf,
-    id<MTLBuffer> s_buf,
-    id<MTLBuffer> z_buf,
-    id<MTLBuffer> out_buf,
+    std::pair<id<MTLBuffer>, size_t> a_buf_offset,
+    std::pair<id<MTLBuffer>, size_t> b_buf_offset,
+    std::pair<id<MTLBuffer>, size_t> s_buf_offset,
+    std::pair<id<MTLBuffer>, size_t> z_buf_offset,
+    std::pair<id<MTLBuffer>, size_t> out_buf_offset,
     int32_t M,
     int32_t K,
     int32_t N,
@@ -97,11 +97,11 @@ inline void linear_lowbit_quant_weights_mps_impl(
       metal_lowbit_quantized_lib.getPipelineStateForFunc(shader_func);
   const auto maxThreadsPerGroup = [cpl maxTotalThreadsPerThreadgroup];
   [computeEncoder setComputePipelineState:cpl];
-  [computeEncoder setBuffer:a_buf offset:0 atIndex:0];
-  [computeEncoder setBuffer:b_buf offset:0 atIndex:1];
-  [computeEncoder setBuffer:s_buf offset:0 atIndex:2];
-  [computeEncoder setBuffer:z_buf offset:0 atIndex:3];
-  [computeEncoder setBuffer:out_buf offset:0 atIndex:4];
+  [computeEncoder setBuffer:a_buf_offset.first offset:a_buf_offset.second atIndex:0];
+  [computeEncoder setBuffer:b_buf_offset.first offset:b_buf_offset.second atIndex:1];
+  [computeEncoder setBuffer:s_buf_offset.first offset:s_buf_offset.second atIndex:2];
+  [computeEncoder setBuffer:z_buf_offset.first offset:z_buf_offset.second atIndex:3];
+  [computeEncoder setBuffer:out_buf_offset.first offset:out_buf_offset.second atIndex:4];
   [computeEncoder setBytes:sizes.data()
                     length:sizeof(uint32_t) * sizes.size()
                    atIndex:5];
@@ -133,12 +133,12 @@ std::tuple<const std::string, DispatchFn> get_shader_func_and_dispatch(
 // LowBit Quantized Weights Linear on Metal
 template <int nbit>
 void linear_lowbit_quant_weights_mps(
-    id<MTLBuffer> a_buf,
-    id<MTLBuffer> b_buf,
+    std::pair<id<MTLBuffer>, size_t> a_buf_offset,
+    std::pair<id<MTLBuffer>, size_t> b_buf_offset,
     int64_t qGroupSize,
-    id<MTLBuffer> s_buf,
-    id<MTLBuffer> z_buf,
-    id<MTLBuffer> out_buf,
+    std::pair<id<MTLBuffer>, size_t> s_buf_offset,
+    std::pair<id<MTLBuffer>, size_t> z_buf_offset,
+    std::pair<id<MTLBuffer>, size_t> out_buf_offset,
     int32_t M,
     int32_t K,
     int32_t N,
@@ -154,11 +154,11 @@ void linear_lowbit_quant_weights_mps(
   const DispatchFn dispatch_fn = std::get<1>(shader_func_and_dispatch);
 
   return linear_lowbit_quant_weights_mps_impl(
-      a_buf,
-      b_buf,
-      s_buf,
-      z_buf,
-      out_buf,
+      a_buf_offset,
+      b_buf_offset,
+      s_buf_offset,
+      z_buf_offset,
+      out_buf_offset,
       M,
       K,
       N,
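The header change threads a byte offset alongside each id<MTLBuffer>: callers now pass a std::pair<id<MTLBuffer>, size_t>, and the encoder feeds the pair into Metal's setBuffer:offset:atIndex: parameter that was previously hard-coded to offset:0. A minimal sketch of the binding pattern (the 256-byte offset is a made-up value for illustration):

  // .first is the MTLBuffer, .second the byte offset at which the
  // kernel's data begins inside that buffer.
  std::pair<id<MTLBuffer>, size_t> a_buf_offset = {a_buf, 256};
  [computeEncoder setBuffer:a_buf_offset.first
                     offset:a_buf_offset.second
                    atIndex:0];

Carrying the offset through the call chain lets a caller hand the kernel a sub-range of a larger allocation without first copying it into a dedicated buffer.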

torchao/experimental/kernels/mps/test/test_lowbit.mm

Lines changed: 5 additions & 5 deletions
@@ -118,12 +118,12 @@ void pack() {
 
 void linear() {
   LowBitQuantWeights<nbit>::linear(
-      buf_A,
-      buf_B,
+      {buf_A, 0},
+      {buf_B, 0},
       qGroupSize,
-      buf_S,
-      buf_Z,
-      buf_C,
+      {buf_S, 0},
+      {buf_Z, 0},
+      {buf_C, 0},
       M,
       K,
       N,
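The standalone kernel test allocates each MTLBuffer whole, so every pair carries a zero byte offset ({buf_A, 0} and so on); the ATen and ExecuTorch adapters below are where non-trivial offsets actually get computed.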

torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm

Lines changed: 5 additions & 5 deletions
@@ -97,12 +97,12 @@ Tensor linear_mps_kernel_out(
   auto K = A.size(1);
 
   LowBitQuantWeights<nbit>::linear(
-      getMTLBufferStorage(A),
-      getMTLBufferStorage(B),
+      {getMTLBufferStorage(A), A.storage_offset() * A.element_size()},
+      {getMTLBufferStorage(B), B.storage_offset() * B.element_size()},
       group_size,
-      getMTLBufferStorage(S),
-      getMTLBufferStorage(Z),
-      getMTLBufferStorage(C),
+      {getMTLBufferStorage(S), S.storage_offset() * S.element_size()},
+      {getMTLBufferStorage(Z), Z.storage_offset() * Z.element_size()},
+      {getMTLBufferStorage(C), C.storage_offset() * C.element_size()},
       M,
       K,
       N,
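An ATen tensor's storage_offset() counts elements from the start of its storage, so each call site multiplies it by element_size() to produce the byte offset Metal expects. The repeated expression could be factored into a small helper; a hypothetical sketch (the buffer_with_offset name is not in the patch, and it assumes the file's existing getMTLBufferStorage):

  // Hypothetical helper: pair a tensor's underlying MTLBuffer with the
  // byte offset at which the tensor's data begins inside that buffer.
  static std::pair<id<MTLBuffer>, size_t> buffer_with_offset(const at::Tensor& t) {
    return {getMTLBufferStorage(t),
            static_cast<size_t>(t.storage_offset()) * t.element_size()};
  }

With such a helper the call would collapse to LowBitQuantWeights<nbit>::linear(buffer_with_offset(A), buffer_with_offset(B), group_size, ...).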

torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm

Lines changed: 5 additions & 5 deletions
@@ -95,12 +95,12 @@ bool check_linear_mps_args(
   auto K = A.size(1);
 
   torchao::kernels::mps::lowbit::LowBitQuantWeights<nbit>::linear(
-      getMTLBufferStorage(A),
-      getMTLBufferStorage(B),
+      {getMTLBufferStorage(A), A.storage_offset() * A.element_size()},
+      {getMTLBufferStorage(B), B.storage_offset() * B.element_size()},
       group_size,
-      getMTLBufferStorage(S),
-      getMTLBufferStorage(Z),
-      getMTLBufferStorage(out),
+      {getMTLBufferStorage(S), S.storage_offset() * S.element_size()},
+      {getMTLBufferStorage(Z), Z.storage_offset() * Z.element_size()},
+      {getMTLBufferStorage(out), out.storage_offset() * out.element_size()},
       M,
       K,
       N,

torchao/experimental/ops/mps/test/test_quantizer.py

Lines changed: 36 additions & 0 deletions
@@ -86,6 +86,42 @@ def test_export(self, nbit):
             == f"torchao._linear_fp_act_{nbit}bit_weight.default"
         )
 
+    @parameterized.expand(BITWIDTHS)
+    def test_export_accuracy(self, nbit):
+        group_size = 32
+        m = 3
+        n = 12
+        k = 64
+        with torch.no_grad():
+            activations = torch.rand(m, k, dtype=torch.float32, device="mps")
+            model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+            # Compute expected result
+            weight_cpu = model[0].weight.data
+            weight_qvals_cpu, weight_scales_cpu, weight_zeros_cpu = _quantize(
+                weight_cpu, group_size, nbit, True, torch.uint8
+            )
+            weight_zeros_cpu = -weight_zeros_cpu * weight_scales_cpu
+            expected = self._reference_linear_lowbit_quant_weights(
+                activations.cpu(),
+                weight_qvals_cpu,
+                group_size,
+                weight_scales_cpu,
+                weight_zeros_cpu,
+            )
+
+            quantized_model = self._quantize_model(
+                model, torch.float32, nbit, group_size
+            )
+
+            ep = torch.export.export(quantized_model, (activations,), strict=True)
+            path = torch._inductor.aoti_compile_and_package(ep)
+            compiled_model = torch._inductor.aoti_load_package(path)
+            result = compiled_model(activations)
+
+            # Compare results
+            torch.testing.assert_close(result.cpu(), expected, rtol=0.001, atol=0.001)
+
     @parameterized.expand(BITWIDTHS)
     def test_2d_output_device_and_shape(self, nbit):
         model, group_size, k0, n = self._model_setup()
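The new test_export_accuracy case is what exercises the offsets end to end: an AOTI-compiled package can hand the kernel tensors that are views into a larger packed allocation, so storage_offset() is no longer guaranteed to be zero. A C++ illustration of the kind of view the ATen call sites above must now handle (the sizes are made up):

  // A float32 view that begins 10 elements into its storage.
  at::Tensor base = at::zeros({100}, at::kFloat);
  at::Tensor view = base.slice(/*dim=*/0, /*start=*/10, /*end=*/20);
  // view.storage_offset() == 10; as a byte offset into the MTLBuffer:
  size_t byte_offset = view.storage_offset() * view.element_size();  // 40 bytes

Before this change the encoder always bound buffers at offset:0, so such a view would have been read from the start of the allocation rather than from the data it actually points at.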
