|
| 1 | +// Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +// All rights reserved. |
| 3 | +// |
| 4 | +// This source code is licensed under the license found in the |
| 5 | +// LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +#pragma once |
| 8 | + |
| 9 | +#if defined(__aarch64__) || defined(__ARM_NEON) |
| 10 | + |
| 11 | +#include <arm_neon.h> |
| 12 | +#include <torchao/experimental/kernels/cpu/aarch64/macro.h> |
| 13 | + |
| 14 | +namespace torchao::lut { |
| 15 | + |
// Loads the 16 floats table[0..15] (64 bytes) into `lut` as raw bytes.
// Each of the four 16-byte vectors in `lut` receives four consecutive floats:
// val[0] <- table[0..3], val[1] <- table[4..7], val[2] <- table[8..11],
// val[3] <- table[12..15].
//
// lut:   output; overwritten with the byte image of the table.
// table: pointer to at least 16 readable floats.
TORCHAO_ALWAYS_INLINE inline void load_fp32_lut(uint8x16x4_t& lut, const float* table) {
  // View the table as bytes; every vld1q_u8 below reads 16 bytes (4 floats).
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(table);

  // Finish all four loads before touching `lut`.
  const uint8x16_t q0 = vld1q_u8(bytes);
  const uint8x16_t q1 = vld1q_u8(bytes + 16);
  const uint8x16_t q2 = vld1q_u8(bytes + 32);
  const uint8x16_t q3 = vld1q_u8(bytes + 48);

  lut.val[0] = q0;
  lut.val[1] = q1;
  lut.val[2] = q2;
  lut.val[3] = q3;
}
| 24 | + |
// Gathers 16 floats from a 16-entry FP32 lookup table.
//
// lut:        the table as 64 raw bytes (16 consecutive floats held in a
//             uint8x16x4_t).
// idx:        16 one-byte indices; the scheme below is written for values in
//             the range 0-15, each selecting one float of the table.
// out0..out3: outputs; out0 receives the floats selected by idx lanes 0-3,
//             out1 lanes 4-7, out2 lanes 8-11 and out3 lanes 12-15.
//
// NOTE(review): vqtbl4q_u8 is listed as an A64-only intrinsic in the Arm
// intrinsics reference, while the guard around this header also admits
// __ARM_NEON -- confirm this header is never compiled for 32-bit Arm.
TORCHAO_ALWAYS_INLINE inline void lookup_from_fp32_lut(
    float32x4_t& out0,
    float32x4_t& out1,
    float32x4_t& out2,
    float32x4_t& out3,
    const uint8x16x4_t& lut,
    const uint8x16_t idx) {
  // Step 1: turn float indices into byte offsets inside the 64-byte table.
  // A float is 4 bytes, so float i occupies bytes 4*i .. 4*i + 3. Plane k holds,
  // for every lane, the offset of byte k of the float selected by that lane.
  const uint8x16_t first_byte = vshlq_n_u8(idx, 2);

  uint8x16x4_t byte_planes;
  byte_planes.val[0] = first_byte;
  byte_planes.val[1] = vaddq_u8(first_byte, vdupq_n_u8(1));
  byte_planes.val[2] = vaddq_u8(first_byte, vdupq_n_u8(2));
  byte_planes.val[3] = vaddq_u8(first_byte, vdupq_n_u8(3));

  // Step 2: 4-way interleave the planes so that the four byte offsets of one
  // float sit next to each other. Viewing byte_planes as a 64-byte table,
  // plane k starts at byte 16*k, so the selector {0, 16, 32, 48} picks bytes
  // 0..3 of lane 0, {1, 17, 33, 49} picks bytes 0..3 of lane 1, and so on.
  // Adding 4 to every selector entry advances to the next group of four lanes.
  const uint8x16_t sel0 = {0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51};
  const uint8x16_t lane_step = vdupq_n_u8(4);
  const uint8x16_t sel1 = vaddq_u8(sel0, lane_step);
  const uint8x16_t sel2 = vaddq_u8(sel1, lane_step);
  const uint8x16_t sel3 = vaddq_u8(sel2, lane_step);

  const uint8x16_t gather0 = vqtbl4q_u8(byte_planes, sel0);
  const uint8x16_t gather1 = vqtbl4q_u8(byte_planes, sel1);
  const uint8x16_t gather2 = vqtbl4q_u8(byte_planes, sel2);
  const uint8x16_t gather3 = vqtbl4q_u8(byte_planes, sel3);

  // Step 3: fetch the selected bytes from the table and reinterpret each
  // 16-byte result as four floats.
  out0 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, gather0));
  out1 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, gather1));
  out2 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, gather2));
  out3 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, gather3));
}
| 80 | + |
| 81 | +} // namespace torchao::lut |
| 82 | + |
| 83 | + |
| 84 | +#endif // defined(__aarch64__) || defined(__ARM_NEON) |
0 commit comments