Skip to content

Commit 9a565c2

Browse files
authored
Add FP32 LUT lookup functions
Differential Revision: D76926009 Pull Request resolved: #2472
1 parent 6821971 commit 9a565c2

File tree

4 files changed

+130
-0
lines changed

4 files changed

+130
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#if defined(__aarch64__) || defined(__ARM_NEON)
10+
11+
#include <arm_neon.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
13+
14+
namespace torchao::lut {
15+
16+
// Loads a 16-entry FP32 lookup table into four 16-byte NEON registers.
//
//   lut   - destination; lut.val[k] receives the raw bytes of
//           table[4 * k] .. table[4 * k + 3].
//   table - must point to at least 16 readable floats (64 bytes).
//
// The floats are kept as raw bytes so that byte-wise table-lookup
// instructions can later gather them.
TORCHAO_ALWAYS_INLINE inline void load_fp32_lut(uint8x16x4_t& lut, const float* table) {
  const uint8_t* table_bytes = reinterpret_cast<const uint8_t*>(table);
  lut.val[0] = vld1q_u8(table_bytes);
  lut.val[1] = vld1q_u8(table_bytes + 16);
  lut.val[2] = vld1q_u8(table_bytes + 32);
  lut.val[3] = vld1q_u8(table_bytes + 48);
}
24+
25+
// Gathers 16 FP32 values from a 16-entry lookup table with byte-wise table
// lookups.
//
//   out0 - receives the table entries selected by idx lanes 0..3
//   out1 - receives the table entries selected by idx lanes 4..7
//   out2 - receives the table entries selected by idx lanes 8..11
//   out3 - receives the table entries selected by idx lanes 12..15
//   lut  - the table as 64 raw bytes (16 consecutive floats), in the layout
//          produced by load_fp32_lut
//   idx  - 16 table indices, one per byte lane; each is expected to be in
//          [0, 15]. No range check is performed.
//
// NOTE(review): vqtbl4q_u8 is an AArch64 intrinsic, while the guard around
// this header also admits 32-bit targets that define __ARM_NEON — confirm
// that 32-bit builds never include this file.
TORCHAO_ALWAYS_INLINE inline void lookup_from_fp32_lut(
    float32x4_t& out0,
    float32x4_t& out1,
    float32x4_t& out2,
    float32x4_t& out3,
    const uint8x16x4_t& lut,
    const uint8x16_t idx) {
  // Step 1: turn each float index into byte addresses inside `lut`.
  // A float is 4 bytes, so entry i occupies bytes 4*i .. 4*i+3. Plane k holds,
  // for every lane, the address of byte k of the float that lane selects.
  const uint8x16_t first_byte = vshlq_n_u8(idx, 2);
  uint8x16x4_t byte_planes;
  byte_planes.val[0] = first_byte;
  byte_planes.val[1] = vaddq_u8(first_byte, vdupq_n_u8(1));
  byte_planes.val[2] = vaddq_u8(first_byte, vdupq_n_u8(2));
  byte_planes.val[3] = vaddq_u8(first_byte, vdupq_n_u8(3));

  // Step 2: interleave the planes so that the 4 byte addresses of one float
  // sit next to each other. Viewing the planes as one 64-byte table, lane j of
  // plane k is at position 16*k + j, so {0, 16, 32, 48} picks bytes 0..3 of
  // lane 0, {1, 17, 33, 49} picks those of lane 1, and so on. Adding 4 to the
  // selector moves on to the next group of four lanes.
  const uint8x16_t next_group = vdupq_n_u8(4);
  const uint8x16_t select_lanes_0_3 = {
      0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51};
  const uint8x16_t select_lanes_4_7 = vaddq_u8(select_lanes_0_3, next_group);
  const uint8x16_t select_lanes_8_11 = vaddq_u8(select_lanes_4_7, next_group);
  const uint8x16_t select_lanes_12_15 = vaddq_u8(select_lanes_8_11, next_group);

  const uint8x16_t addr0 = vqtbl4q_u8(byte_planes, select_lanes_0_3);
  const uint8x16_t addr1 = vqtbl4q_u8(byte_planes, select_lanes_4_7);
  const uint8x16_t addr2 = vqtbl4q_u8(byte_planes, select_lanes_8_11);
  const uint8x16_t addr3 = vqtbl4q_u8(byte_planes, select_lanes_12_15);

  // Step 3: fetch the addressed bytes from the table and view every 4 bytes
  // as one float.
  out0 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, addr0));
  out1 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, addr1));
  out2 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, addr2));
  out3 = vreinterpretq_f32_u8(vqtbl4q_u8(lut, addr3));
}
80+
81+
} // namespace torchao::lut
82+
83+
84+
#endif // defined(__aarch64__) || defined(__ARM_NEON)

torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ target_link_libraries(
120120
dep
121121
)
122122

123+
# Build the test_lut executable from test_lut.cpp and link it against
# GTest::gtest_main and the `dep` target.
add_executable(test_lut test_lut.cpp)
124+
target_link_libraries(
125+
test_lut
126+
PRIVATE
127+
GTest::gtest_main
128+
dep
129+
)
130+
123131
include(GoogleTest)
124132
gtest_discover_tests(test_quantization)
125133
gtest_discover_tests(test_reduction)
@@ -128,3 +136,4 @@ gtest_discover_tests(test_linear)
128136
gtest_discover_tests(test_embedding)
129137
gtest_discover_tests(test_weight_packing)
130138
gtest_discover_tests(test_qmatmul)
139+
gtest_discover_tests(test_lut)

torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,4 @@ ${CMAKE_OUT}/test_linear
6161
${CMAKE_OUT}/test_embedding
6262
${CMAKE_OUT}/test_weight_packing
6363
${CMAKE_OUT}/test_qmatmul
64+
${CMAKE_OUT}/test_lut
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#if defined(__aarch64__) || defined(__ARM_NEON)
8+
9+
#include <arm_neon.h>
10+
#include <gtest/gtest.h>
11+
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/lut/lut.h>
13+
#include <vector>
14+
15+
16+
// Verifies that lookup_from_fp32_lut returns lut[idx[k]] for each of the 16
// index lanes. The table and the indices come from the test_utils helpers
// get_random_vector and get_random_lowbit_vector.
TEST(test_fp32_lut, LutLookup) {
  auto lut = torchao::get_random_vector(16, -1.0, 1.0);
  auto idx = torchao::get_random_lowbit_vector(16, 4);

  uint8x16_t idx_vec = vld1q_u8(idx.data());
  uint8x16x4_t lut_vec;
  torchao::lut::load_fp32_lut(lut_vec, lut.data());

  float32x4_t out0, out1, out2, out3;
  torchao::lut::lookup_from_fp32_lut(out0, out1, out2, out3, lut_vec, idx_vec);

  // Store the results to memory instead of subscripting the NEON vector
  // types: operator[] on float32x4_t is a compiler extension, whereas
  // vst1q_f32 is part of the intrinsics API.
  float actual[16];
  vst1q_f32(actual, out0);
  vst1q_f32(actual + 4, out1);
  vst1q_f32(actual + 8, out2);
  vst1q_f32(actual + 12, out3);

  for (int lane = 0; lane < 16; ++lane) {
    // Exact equality is intended: the lookup copies bytes and does no
    // floating-point arithmetic.
    EXPECT_EQ(actual[lane], lut[idx[lane]]) << "lane " << lane;
  }
}
34+
35+
36+
#endif // defined(__aarch64__) || defined(__ARM_NEON)

0 commit comments

Comments
 (0)