diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h index e7fd56cf9b..efcce4bb2e 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h @@ -218,21 +218,7 @@ void kernel_impl( if constexpr (has_clamp) { res = clamp(res, clamp_min, clamp_max); } - - // Store result - int remaining = n - n_idx; - float* store_loc = output + m_idx * output_m_stride + n_idx; - if (remaining >= 4) { - vst1q_f32(store_loc, res); - } else if (remaining >= 3) { - vst1_f32(store_loc, vget_low_f32(res)); - *(store_loc + 2) = res[2]; - } else if (remaining >= 2) { - vst1_f32(store_loc, vget_low_f32(res)); - } else { - *(store_loc) = res[0]; - } - + vst1q_f32(output + m_idx * output_m_stride + n_idx, res); } // n_idx activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); } // m_idx diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h index 74ed288044..37f254c983 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h @@ -290,34 +290,8 @@ void kernel_impl( res_0123 = vec_clamp(res_0123, vec_min, vec_max); res_4567 = vec_clamp(res_4567, vec_min, vec_max); } - - // Store result - int remaining = n - n_idx; - float* store_loc = output + m_idx * output_m_stride + n_idx; - if (remaining >= 8) { - vst1q_f32(store_loc, res_0123); - vst1q_f32(store_loc + 4, res_4567); - } else if (remaining >= 7) { - vst1q_f32(store_loc, res_0123); - vst1_f32(store_loc + 4, vget_low_f32(res_4567)); - *(store_loc + 6) = res_4567[2]; - } else if (remaining >= 6) { - vst1q_f32(store_loc, res_0123); - vst1_f32(store_loc + 4, vget_low_f32(res_4567)); - } else if (remaining >= 5) { - vst1q_f32(store_loc, res_0123); - *(store_loc + 4) = res_4567[0]; - } else if (remaining >= 4) { - vst1q_f32(store_loc, res_0123); - } else if (remaining >= 3) { - vst1_f32(store_loc, vget_low_f32(res_0123)); - *(store_loc + 2) = res_0123[2]; - } else if (remaining >= 2) { - vst1_f32(store_loc, vget_low_f32(res_0123)); - } else { - *store_loc = res_0123[0]; - } - + vst1q_f32(output + m_idx * output_m_stride + n_idx, res_0123); + vst1q_f32(output + m_idx * output_m_stride + n_idx + 4, res_4567); } // n_idx activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); } // m_idx diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp index 4fe0cb2e8f..ab1f26180d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp @@ -1,18 +1,15 @@ // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. #include -#include int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum( const int8_t* vals, int size) { - assert(size >= 1); - int32_t res = 0; int i = 0; #pragma unroll(4) - for (; i + 15 < size; i += 16) { + for (; i < size; i += 16) { int8x16_t vec_vals = vld1q_s8(vals + i); res += (int)(vaddlvq_s8(vec_vals)); } diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp index d30940ca60..ed7ca01bb4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp @@ -1,33 +1,23 @@ // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. #include -#include void torchao::kernels::cpu::aarch64::reduction::find_min_and_max( float32_t& min, float32_t& max, const float32_t* vals, int size) { - assert(size > 0); - - // Needed in case size < 4 so we don't compare to - // uninitialized min/max values - min = vals[0]; - max = min; - + float32x4_t mins = vdupq_n_f32(0.0); + float32x4_t maxes = vdupq_n_f32(0.0); int i = 0; - if (i + 3 < size) { - float32x4_t mins = vld1q_f32(vals + i); - float32x4_t maxes = mins; - i += 4; - for (; i + 3 < size; i += 4) { - float32x4_t v = vld1q_f32(vals + i); - mins = vminq_f32(mins, v); - maxes = vmaxq_f32(maxes, v); - } - min = vminvq_f32(mins); - max = vmaxvq_f32(maxes); + for (; i < size; i += 8) { + float32x4_t v1 = vld1q_f32(vals + i); + float32x4_t v2 = vld1q_f32(vals + i + 4); + mins = vminq_f32(v1, v2); + maxes = vmaxq_f32(v1, v2); } + min = vminvq_f32(mins); + max = vmaxvq_f32(maxes); // Remainder while (i < size) { diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 4273c16785..1b78f25b9c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -35,14 +35,6 @@ target_link_libraries( dep ) -add_executable(test_reduction test_reduction.cpp) -target_link_libraries( - test_reduction - PRIVATE - GTest::gtest_main - dep -) - add_executable(test_bitpacking test_bitpacking.cpp) target_link_libraries( test_bitpacking @@ -69,7 +61,6 @@ target_link_libraries( include(GoogleTest) gtest_discover_tests(test_quantization) -gtest_discover_tests(test_reduction) gtest_discover_tests(test_bitpacking) gtest_discover_tests(test_linear) gtest_discover_tests(test_valpacking) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 6eddc520eb..308455206c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -7,8 +7,7 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/e cmake --build ${CMAKE_OUT} # Run -${CMAKE_OUT}/test_quantization -${CMAKE_OUT}/test_reduction -${CMAKE_OUT}/test_bitpacking -${CMAKE_OUT}/test_linear -${CMAKE_OUT}/test_valpacking + ${CMAKE_OUT}/test_quantization + ${CMAKE_OUT}/test_bitpacking + ${CMAKE_OUT}/test_linear + ${CMAKE_OUT}/test_valpacking diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 39db050c61..4b61c162e0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -10,11 +10,12 @@ float kTol = 0.0001; template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( - int m, - int k, - int n, - int group_size) { +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot() { + int m = 7; + int k = 128; + int n = 13; + int group_size = 32; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -49,7 +50,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * n); + std::vector output(m * k); kernel( output.data(), /*output_m_stride=*/n, @@ -71,53 +72,70 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasClamp) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( - int m, - int k, - int n, - int group_size) { +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot() { + int m = 7; + int k = 64; + int n = 13; + int group_size = 16; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -152,7 +170,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * n); + std::vector output(m * k); kernel( output.data(), /*output_m_stride=*/n, @@ -174,66 +192,70 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - NLessThan4) { - for (int n = 1; n < 4; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); - } + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( - int m, - int k, - int n, - int group_size) { +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot() { + int m = 7; + int k = 64; + int n = 13; + int group_size = 16; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -268,7 +290,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * n); + std::vector output(m * k); kernel( output.data(), /*output_m_stride=*/n, @@ -290,56 +312,59 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - NLessThan8) { - for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); - } + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); }