Commit e5ca515

Add kernel config for LUT based low bit quantization.

Differential Revision: D77616131
Pull Request resolved: #2533
1 parent c8d3e93 commit e5ca515

1 file changed: 229 additions, 0 deletions
@@ -0,0 +1,229 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torchao/experimental/ops/library.h>
#include <array>
#include <cassert>
#include <vector>

namespace torchao::ops::groupwise_lowbit_weight_lut {

constexpr int kMaxConfigs = 4;

/**
 * @brief Defines the configuration for a Universal Kernel (UKernel) for the
 * groupwise low-bit LUT-based kernel.
 */
struct UKernelConfig {
  // Calculates the required size for the packed activations buffer.
  using packed_activations_size_fn_type =
      size_t (*)(int m, int k, int mr, int kr, int sr);

  // Calculates the required size for the packed weights buffer.
  using packed_weights_size_fn_type = size_t (*)(
      int n,
      int k,
      int weight_nbit,
      int scale_group_size,
      bool has_scales,
      bool has_bias,
      int nr,
      int kr,
      int sr);

  // Packs activations into a kernel-friendly layout.
  using pack_activations_fn_type = void (*)(
      float* packed_activations,
      int m,
      int k,
      const float* activations,
      int mr,
      int kr,
      int sr);

  // Packs weights, scales, and LUTs into the target buffer.
  using pack_weights_fn_type = void (*)(
      void* packed_weights_ptr,
      const uint8_t* weight_qvals_indices,
      const float* weight_scales,
      const float* weight_luts,
      int n,
      int k,
      int scale_group_size,
      int lut_group_size,
      bool has_scales,
      bool has_bias,
      const float* bias,
      int nr,
      int kr,
      int sr);

  // Offset into the packed_activations buffer, used for multithreaded packing.
  using packed_activations_offset_fn_type =
      size_t (*)(int m_idx, int k, int mr, int kr, int sr);

  // Offset into the packed_weights buffer, used for multithreaded packing.
  using packed_weights_offset_fn_type = size_t (*)(
      int n_idx,
      int k,
      int weight_nbit,
      int scale_group_size,
      bool has_scales,
      bool has_bias,
      int nr,
      int kr,
      int sr);

  // The main computation kernel.
  using kernel_fn_type = void (*)(
      float* output,
      int output_m_stride,
      int m,
      int n,
      int k,
      int scale_group_size,
      int lut_group_size,
      const void* packed_weights,
      const void* packed_activations,
      float clamp_min,
      float clamp_max,
      bool has_bias,
      bool has_clamp);

  // Configuration for a single kernel.
  struct config_type {
    int m_step{0};
    int mr{0};
    packed_activations_size_fn_type packed_activations_size{nullptr};
    packed_activations_offset_fn_type packed_activations_offset{nullptr};
    pack_activations_fn_type pack_activations{nullptr};
    kernel_fn_type kernel{nullptr};
  };

  // Preferred memory alignment for buffers.
  size_t preferred_alignment{0};
  int n_step{0};
  int nr{0};
  int kr{0};
  int sr{0};
  int weight_nbit{0};
  bool has_scales{false};
  bool has_bias{false};

  packed_weights_size_fn_type packed_weights_size{nullptr};
  packed_weights_offset_fn_type packed_weights_offset{nullptr};
  pack_weights_fn_type pack_weights{nullptr};

  std::array<config_type, kMaxConfigs> configs;

  static UKernelConfig make(
      size_t preferred_alignment,
      int n_step,
      int nr,
      int kr,
      int sr,
      int weight_nbit,
      bool has_scales,
      bool has_bias,
      packed_weights_size_fn_type packed_weights_size,
      packed_weights_offset_fn_type packed_weights_offset,
      pack_weights_fn_type pack_weights,
      std::array<config_type, kMaxConfigs> configs);

  // Validation function to ensure all pointers are properly initialized.
  inline void validate() const {
    // 1. Validate top-level UKernelConfig parameters.
    TORCHAO_CHECK(preferred_alignment >= 1, "preferred_alignment must be >= 1");
    TORCHAO_CHECK(nr >= 1, "nr must be >= 1");
    TORCHAO_CHECK(kr >= 1, "kr must be >= 1");
    TORCHAO_CHECK(sr >= 1, "sr must be >= 1");
    TORCHAO_CHECK(weight_nbit >= 1, "weight_nbit must be >= 1");
    TORCHAO_CHECK(weight_nbit <= 4, "weight_nbit must be <= 4");
    TORCHAO_CHECK(
        packed_weights_size != nullptr,
        "packed_weights_size must be set");
    TORCHAO_CHECK(
        packed_weights_offset != nullptr,
        "packed_weights_offset must be set");
    TORCHAO_CHECK(pack_weights != nullptr, "pack_weights must be set");

    // 2. Validate the array of kernel configurations.
    // At least one configuration must be defined.
    TORCHAO_CHECK(
        !configs.empty(),
        "At least one valid kernel configuration must be provided.");

    for (size_t i = 0; i < configs.size(); ++i) {
      const auto& config = configs[i];

      TORCHAO_CHECK(
          config.packed_activations_size != nullptr,
          "config.packed_activations_size must be set");
      TORCHAO_CHECK(
          config.pack_activations != nullptr,
          "config.pack_activations must be set");
      TORCHAO_CHECK(config.kernel != nullptr, "config.kernel must be set");

      if (i > 0) {
        const auto& prev_config = configs[i - 1];
        TORCHAO_CHECK(
            prev_config.m_step > 0,
            "There cannot be a gap in configurations (m_step=0 followed by m_step>0)");
        TORCHAO_CHECK(
            prev_config.m_step < config.m_step,
            "m_step values in configs must be strictly increasing.");
      }
    }
  }

  // Selects the appropriate configuration index based on m.
  inline int select_config_idx(int m) const {
    assert(m >= 1);
    assert(configs[0].m_step >= 1);

    size_t i = 0;
    while (i + 1 < configs.size() && configs[i + 1].m_step >= 1 &&
           configs[i + 1].m_step <= m) {
      assert(configs[i].m_step < configs[i + 1].m_step);
      i++;
    }

    assert(i < configs.size());
    assert(configs[i].m_step >= 1);
    assert(i == 0 || configs[i].m_step <= m);
    return static_cast<int>(i);
  }
};

inline UKernelConfig UKernelConfig::make(
    size_t preferred_alignment,
    int n_step,
    int nr,
    int kr,
    int sr,
    int weight_nbit,
    bool has_scales,
    bool has_bias,
    packed_weights_size_fn_type packed_weights_size,
    packed_weights_offset_fn_type packed_weights_with_lut_offset,
    pack_weights_fn_type pack_weights,
    std::array<config_type, kMaxConfigs> configs) {
  return UKernelConfig{
      preferred_alignment,
      n_step,
      nr,
      kr,
      sr,
      weight_nbit,
      has_scales,
      has_bias,
      packed_weights_size,
      packed_weights_with_lut_offset,
      pack_weights,
      std::move(configs)};
}
} // namespace torchao::ops::groupwise_lowbit_weight_lut
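
For reference, a backend providing this kernel would fill the function pointers with concrete implementations, build the config through UKernelConfig::make, and call validate() before dispatch. The sketch below is not part of this commit: the stub_* functions, the chosen tile sizes, and make_example_config are placeholder assumptions that only illustrate the intended call pattern, assuming the header above is already included.

// Hypothetical usage sketch (not part of this commit). Assumes the header
// above is in scope; stub_* functions and tile sizes are placeholders.
#include <array>
#include <cstddef>
#include <cstdint>

namespace lut = torchao::ops::groupwise_lowbit_weight_lut;

// Stubs that only match the required function-pointer signatures.
size_t stub_act_size(int m, int k, int /*mr*/, int /*kr*/, int /*sr*/) {
  return static_cast<size_t>(m) * k * sizeof(float);
}
size_t stub_act_offset(int m_idx, int k, int /*mr*/, int /*kr*/, int /*sr*/) {
  return static_cast<size_t>(m_idx) * k * sizeof(float);
}
void stub_pack_act(float*, int, int, const float*, int, int, int) {}
size_t stub_wt_size(int, int, int, int, bool, bool, int, int, int) {
  return 0;
}
size_t stub_wt_offset(int, int, int, int, bool, bool, int, int, int) {
  return 0;
}
void stub_pack_wt(
    void*, const uint8_t*, const float*, const float*,
    int, int, int, int, bool, bool, const float*, int, int, int) {}
void stub_kernel(
    float*, int, int, int, int, int, int,
    const void*, const void*, float, float, bool, bool) {}

lut::UKernelConfig make_example_config() {
  // Fill every slot so validate() sees non-null pointers in each entry;
  // m_step must be strictly increasing across the array.
  std::array<lut::UKernelConfig::config_type, lut::kMaxConfigs> configs{};
  const int m_steps[lut::kMaxConfigs] = {1, 2, 4, 8};
  for (int i = 0; i < lut::kMaxConfigs; ++i) {
    configs[i] = {m_steps[i], /*mr=*/m_steps[i],
                  stub_act_size, stub_act_offset, stub_pack_act, stub_kernel};
  }

  auto cfg = lut::UKernelConfig::make(
      /*preferred_alignment=*/16,
      /*n_step=*/8, /*nr=*/8, /*kr=*/16, /*sr=*/2,
      /*weight_nbit=*/4, /*has_scales=*/true, /*has_bias=*/false,
      stub_wt_size, stub_wt_offset, stub_pack_wt, configs);
  cfg.validate();

  // m = 3 falls between m_step = 2 and m_step = 4, so index 1 is selected.
  int idx = cfg.select_config_idx(/*m=*/3);
  (void)idx;
  return cfg;
}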

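The packed_weights_offset and packed_activations_offset hooks exist so that packing can be sharded across threads at n_step / m_step tile granularity, with each worker writing at the byte offset reported for its starting index. The loop below sketches that pattern for weights in single-threaded form; the function name is hypothetical, and the per-tile offsets into the source arrays depend on the concrete packing layout, so they are noted in a comment rather than guessed.

// Hypothetical tiling loop (not part of this commit); a thread pool could run
// the iterations concurrently because the destination tiles are disjoint.
#include <algorithm>
#include <cstddef>
#include <cstdint>

namespace lut = torchao::ops::groupwise_lowbit_weight_lut;

void pack_weights_tiled(
    const lut::UKernelConfig& cfg,
    void* packed_weights,
    const uint8_t* qval_indices,
    const float* scales,
    const float* luts,
    const float* bias,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size) {
  for (int n_idx = 0; n_idx < n; n_idx += cfg.n_step) {
    // Where this tile starts in the destination buffer.
    const size_t dst_offset = cfg.packed_weights_offset(
        n_idx, k, cfg.weight_nbit, scale_group_size,
        cfg.has_scales, cfg.has_bias, cfg.nr, cfg.kr, cfg.sr);
    const int tile_n = std::min(cfg.n_step, n - n_idx);
    // NOTE: the source pointers would also need to be advanced to this
    // tile's rows/groups; that arithmetic depends on the packing layout.
    cfg.pack_weights(
        static_cast<uint8_t*>(packed_weights) + dst_offset,
        qval_indices, scales, luts,
        tile_n, k, scale_group_size, lut_group_size,
        cfg.has_scales, cfg.has_bias, bias,
        cfg.nr, cfg.kr, cfg.sr);
  }
}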