Skip to content

Commit c8d3e93

Browse files
authored
Add packed weight format for LUT based low bit quantization.
Differential Revision: D77615431 Pull Request resolved: #2530
1 parent 0da89e4 commit c8d3e93

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <torchao/experimental/ops/packed_weights_header.h>
#include <stdexcept>
#include <string>
11+
12+
namespace torchao::ops::groupwise_lowbit_weight_lut {
13+
14+
/**
15+
* @brief Defines the format parameters for the packed weights of the
16+
* groupwise LUT kernel.
17+
*/
18+
struct PackedWeightsFormat {
19+
torchao::ops::PackedWeightsType type;
20+
int weight_nbit;
21+
int scale_group_size;
22+
int lut_group_size;
23+
bool has_scales;
24+
bool has_bias;
25+
int nr;
26+
int kr;
27+
int sr;
28+
29+
PackedWeightsFormat(
30+
torchao::ops::PackedWeightsType type,
31+
int weight_nbit,
32+
int scale_group_size,
33+
int lut_group_size,
34+
bool has_scales,
35+
bool has_bias,
36+
int nr,
37+
int kr,
38+
int sr)
39+
: type{type},
40+
weight_nbit{weight_nbit},
41+
scale_group_size{scale_group_size},
42+
lut_group_size{lut_group_size},
43+
has_scales{has_scales},
44+
has_bias{has_bias},
45+
nr{nr},
46+
kr{kr},
47+
sr{sr} {}
48+
49+
/**
50+
* @brief Converts a generic PackedWeightsHeader into this specific format.
51+
*
52+
* This assumes the generic header's `params` array is populated in the
53+
* correct order.
54+
*/
55+
static PackedWeightsFormat from_packed_weights_header(
56+
const torchao::ops::PackedWeightsHeader& header) {
57+
return PackedWeightsFormat(
58+
header.type,
59+
header.params[0], // weight_nbit
60+
header.params[1], // scale_group_size
61+
header.params[2], // lut_group_size
62+
static_cast<bool>(header.params[3]), // has_scales
63+
static_cast<bool>(header.params[4]), // has_bias
64+
header.params[5], // nr
65+
header.params[6], // kr
66+
header.params[7], // sr
67+
);
68+
}
69+
70+
/**
71+
* @brief Converts this specific format into a generic PackedWeightsHeader.
72+
*/
73+
inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const {
74+
return torchao::ops::PackedWeightsHeader(
75+
type,
76+
{weight_nbit,
77+
scale_group_size,
78+
lut_group_size,
79+
has_scales,
80+
has_bias,
81+
nr,
82+
kr,
83+
sr});
84+
}
85+
};
86+
87+
/**
88+
* @brief Helper function to validate that the provided format matches the
89+
* expectations of a specific kernel.
90+
*/
91+
inline void check_format(
92+
const PackedWeightsFormat& format,
93+
torchao::ops::PackedWeightsType expected_type,
94+
int expected_weight_nbit) {
95+
if (format.type != expected_type) {
96+
throw std::runtime_error(
97+
"Kernel expects packed_weights type=" +
98+
std::to_string(static_cast<int>(expected_type)) +
99+
", but got packed_weights with type=" +
100+
std::to_string(static_cast<int>(format.type)));
101+
}
102+
if (format.weight_nbit != expected_weight_nbit) {
103+
throw std::runtime_error(
104+
"Kernel expects weight_nbit=" + std::to_string(expected_weight_nbit) +
105+
", but got packed_weights with weight_nbit=" +
106+
std::to_string(format.weight_nbit));
107+
}
108+
}
109+
110+
} // namespace torchao::ops::groupwise_lowbit_weight_lut

0 commit comments

Comments
 (0)