|
| 1 | +// Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +// All rights reserved. |
| 3 | +// |
| 4 | +// This source code is licensed under the license found in the |
| 5 | +// LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +#pragma once |
| 8 | + |
| 9 | +#include <torchao/experimental/ops/packed_weights_header.h> |
| 10 | +#include <stdexcept> |
| 11 | + |
| 12 | +namespace torchao::ops::groupwise_lowbit_weight_lut { |
| 13 | + |
| 14 | +/** |
| 15 | + * @brief Defines the format parameters for the packed weights of the |
| 16 | + * groupwise LUT kernel. |
| 17 | + */ |
| 18 | +struct PackedWeightsFormat { |
| 19 | + torchao::ops::PackedWeightsType type; |
| 20 | + int weight_nbit; |
| 21 | + int scale_group_size; |
| 22 | + int lut_group_size; |
| 23 | + bool has_scales; |
| 24 | + bool has_bias; |
| 25 | + int nr; |
| 26 | + int kr; |
| 27 | + int sr; |
| 28 | + |
| 29 | + PackedWeightsFormat( |
| 30 | + torchao::ops::PackedWeightsType type, |
| 31 | + int weight_nbit, |
| 32 | + int scale_group_size, |
| 33 | + int lut_group_size, |
| 34 | + bool has_scales, |
| 35 | + bool has_bias, |
| 36 | + int nr, |
| 37 | + int kr, |
| 38 | + int sr) |
| 39 | + : type{type}, |
| 40 | + weight_nbit{weight_nbit}, |
| 41 | + scale_group_size{scale_group_size}, |
| 42 | + lut_group_size{lut_group_size}, |
| 43 | + has_scales{has_scales}, |
| 44 | + has_bias{has_bias}, |
| 45 | + nr{nr}, |
| 46 | + kr{kr}, |
| 47 | + sr{sr} {} |
| 48 | + |
| 49 | + /** |
| 50 | + * @brief Converts a generic PackedWeightsHeader into this specific format. |
| 51 | + * |
| 52 | + * This assumes the generic header's `params` array is populated in the |
| 53 | + * correct order. |
| 54 | + */ |
| 55 | + static PackedWeightsFormat from_packed_weights_header( |
| 56 | + const torchao::ops::PackedWeightsHeader& header) { |
| 57 | + return PackedWeightsFormat( |
| 58 | + header.type, |
| 59 | + header.params[0], // weight_nbit |
| 60 | + header.params[1], // scale_group_size |
| 61 | + header.params[2], // lut_group_size |
| 62 | + static_cast<bool>(header.params[3]), // has_scales |
| 63 | + static_cast<bool>(header.params[4]), // has_bias |
| 64 | + header.params[5], // nr |
| 65 | + header.params[6], // kr |
| 66 | + header.params[7], // sr |
| 67 | + ); |
| 68 | + } |
| 69 | + |
| 70 | + /** |
| 71 | + * @brief Converts this specific format into a generic PackedWeightsHeader. |
| 72 | + */ |
| 73 | + inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const { |
| 74 | + return torchao::ops::PackedWeightsHeader( |
| 75 | + type, |
| 76 | + {weight_nbit, |
| 77 | + scale_group_size, |
| 78 | + lut_group_size, |
| 79 | + has_scales, |
| 80 | + has_bias, |
| 81 | + nr, |
| 82 | + kr, |
| 83 | + sr}); |
| 84 | + } |
| 85 | +}; |
| 86 | + |
| 87 | +/** |
| 88 | + * @brief Helper function to validate that the provided format matches the |
| 89 | + * expectations of a specific kernel. |
| 90 | + */ |
| 91 | +inline void check_format( |
| 92 | + const PackedWeightsFormat& format, |
| 93 | + torchao::ops::PackedWeightsType expected_type, |
| 94 | + int expected_weight_nbit) { |
| 95 | + if (format.type != expected_type) { |
| 96 | + throw std::runtime_error( |
| 97 | + "Kernel expects packed_weights type=" + |
| 98 | + std::to_string(static_cast<int>(expected_type)) + |
| 99 | + ", but got packed_weights with type=" + |
| 100 | + std::to_string(static_cast<int>(format.type))); |
| 101 | + } |
| 102 | + if (format.weight_nbit != expected_weight_nbit) { |
| 103 | + throw std::runtime_error( |
| 104 | + "Kernel expects weight_nbit=" + std::to_string(expected_weight_nbit) + |
| 105 | + ", but got packed_weights with weight_nbit=" + |
| 106 | + std::to_string(format.weight_nbit)); |
| 107 | + } |
| 108 | +} |
| 109 | + |
| 110 | +} // namespace torchao::ops::groupwise_lowbit_weight_lut |
0 commit comments