From 84f52f6cbd8d339be68d633b5178b3bd756ef878 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 21 Aug 2024 08:13:17 -0700 Subject: [PATCH] Setup dirsync for torchao experimental (#719) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/719 In order to build and test kernels internally for macOS, Android, and iOS, we need to add build support in xplat, which means we need to set up mirroring between fbcode and xplat. For now, this is set up only for the experimental folder, pending a larger decision on whether to sync the two folders entirely. Reviewed By: mzlee Differential Revision: D61480451 --- ...se_lowbit_weight_1x4x16_f32_neondot-impl.h | 16 +- ...se_lowbit_weight_1x8x16_f32_neondot-impl.h | 30 +-- .../cpu/aarch64/reduction/compute_sum.cpp | 5 +- .../aarch64/reduction/find_min_and_max.cpp | 28 +-- .../kernels/cpu/aarch64/tests/CMakeLists.txt | 9 - .../cpu/aarch64/tests/build_and_run_tests.sh | 9 +- .../kernels/cpu/aarch64/tests/test_linear.cpp | 233 ++++++++++-------- 7 files changed, 146 insertions(+), 184 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h index e7fd56cf9b..efcce4bb2e 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h @@ -218,21 +218,7 @@ void kernel_impl( if constexpr (has_clamp) { res = clamp(res, clamp_min, clamp_max); } - - // Store result - int remaining = n - n_idx; - float* store_loc = output + m_idx * output_m_stride + n_idx; - if (remaining >= 4) { - vst1q_f32(store_loc, res); - } else if (remaining >= 3) { - vst1_f32(store_loc, vget_low_f32(res)); - *(store_loc + 2) = res[2]; - }
else if (remaining >= 2) { - vst1_f32(store_loc, vget_low_f32(res)); - } else { - *(store_loc) = res[0]; - } - + vst1q_f32(output + m_idx * output_m_stride + n_idx, res); } // n_idx activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); } // m_idx diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h index 74ed288044..37f254c983 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h @@ -290,34 +290,8 @@ void kernel_impl( res_0123 = vec_clamp(res_0123, vec_min, vec_max); res_4567 = vec_clamp(res_4567, vec_min, vec_max); } - - // Store result - int remaining = n - n_idx; - float* store_loc = output + m_idx * output_m_stride + n_idx; - if (remaining >= 8) { - vst1q_f32(store_loc, res_0123); - vst1q_f32(store_loc + 4, res_4567); - } else if (remaining >= 7) { - vst1q_f32(store_loc, res_0123); - vst1_f32(store_loc + 4, vget_low_f32(res_4567)); - *(store_loc + 6) = res_4567[2]; - } else if (remaining >= 6) { - vst1q_f32(store_loc, res_0123); - vst1_f32(store_loc + 4, vget_low_f32(res_4567)); - } else if (remaining >= 5) { - vst1q_f32(store_loc, res_0123); - *(store_loc + 4) = res_4567[0]; - } else if (remaining >= 4) { - vst1q_f32(store_loc, res_0123); - } else if (remaining >= 3) { - vst1_f32(store_loc, vget_low_f32(res_0123)); - *(store_loc + 2) = res_0123[2]; - } else if (remaining >= 2) { - vst1_f32(store_loc, vget_low_f32(res_0123)); - } else { - *store_loc = res_0123[0]; - } - + vst1q_f32(output + m_idx * output_m_stride + n_idx, res_0123); + vst1q_f32(output + m_idx * output_m_stride + n_idx + 4, res_4567); } // n_idx 
activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); } // m_idx diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp index 4fe0cb2e8f..ab1f26180d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp @@ -1,18 +1,15 @@ // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. #include -#include int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum( const int8_t* vals, int size) { - assert(size >= 1); - int32_t res = 0; int i = 0; #pragma unroll(4) - for (; i + 15 < size; i += 16) { + for (; i < size; i += 16) { int8x16_t vec_vals = vld1q_s8(vals + i); res += (int)(vaddlvq_s8(vec_vals)); } diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp index d30940ca60..ed7ca01bb4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp @@ -1,33 +1,23 @@ // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
#include -#include void torchao::kernels::cpu::aarch64::reduction::find_min_and_max( float32_t& min, float32_t& max, const float32_t* vals, int size) { - assert(size > 0); - - // Needed in case size < 4 so we don't compare to - // uninitialized min/max values - min = vals[0]; - max = min; - + float32x4_t mins = vdupq_n_f32(0.0); + float32x4_t maxes = vdupq_n_f32(0.0); int i = 0; - if (i + 3 < size) { - float32x4_t mins = vld1q_f32(vals + i); - float32x4_t maxes = mins; - i += 4; - for (; i + 3 < size; i += 4) { - float32x4_t v = vld1q_f32(vals + i); - mins = vminq_f32(mins, v); - maxes = vmaxq_f32(maxes, v); - } - min = vminvq_f32(mins); - max = vmaxvq_f32(maxes); + for (; i < size; i += 8) { + float32x4_t v1 = vld1q_f32(vals + i); + float32x4_t v2 = vld1q_f32(vals + i + 4); + mins = vminq_f32(v1, v2); + maxes = vmaxq_f32(v1, v2); } + min = vminvq_f32(mins); + max = vmaxvq_f32(maxes); // Remainder while (i < size) { diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index 4273c16785..1b78f25b9c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -35,14 +35,6 @@ target_link_libraries( dep ) -add_executable(test_reduction test_reduction.cpp) -target_link_libraries( - test_reduction - PRIVATE - GTest::gtest_main - dep -) - add_executable(test_bitpacking test_bitpacking.cpp) target_link_libraries( test_bitpacking @@ -69,7 +61,6 @@ target_link_libraries( include(GoogleTest) gtest_discover_tests(test_quantization) -gtest_discover_tests(test_reduction) gtest_discover_tests(test_bitpacking) gtest_discover_tests(test_linear) gtest_discover_tests(test_valpacking) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 6eddc520eb..308455206c 100644 --- 
a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -7,8 +7,7 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/e cmake --build ${CMAKE_OUT} # Run -${CMAKE_OUT}/test_quantization -${CMAKE_OUT}/test_reduction -${CMAKE_OUT}/test_bitpacking -${CMAKE_OUT}/test_linear -${CMAKE_OUT}/test_valpacking + ${CMAKE_OUT}/test_quantization + ${CMAKE_OUT}/test_bitpacking + ${CMAKE_OUT}/test_linear + ${CMAKE_OUT}/test_valpacking diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 39db050c61..4b61c162e0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -10,11 +10,12 @@ float kTol = 0.0001; template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( - int m, - int k, - int n, - int group_size) { +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot() { + int m = 7; + int k = 128; + int n = 13; + int group_size = 32; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -49,7 +50,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * n); + std::vector output(m * k); kernel( output.data(), /*output_m_stride=*/n, @@ -71,53 +72,70 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + 
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasClamp) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( - int m, - 
int k, - int n, - int group_size) { +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot() { + int m = 7; + int k = 64; + int n = 13; + int group_size = 16; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -152,7 +170,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * n); + std::vector output(m * k); kernel( output.data(), /*output_m_stride=*/n, @@ -174,66 +192,70 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + 
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); -} + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, - NLessThan4) { - for (int n = 1; n < 4; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); - } + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } template -void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( - int m, - int k, - int n, - int group_size) { +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot() { + int m = 7; + int k = 64; + int n = 13; + int group_size = 16; + auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( m, @@ -268,7 +290,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot test_case.weight_scales.data(), /*weight_zeros=*/test_case.weight_zeros.data()); - std::vector output(m * n); + std::vector output(m * k); kernel( output.data(), /*output_m_stride=*/n, @@ -290,56 +312,59 @@ void 
test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); } TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasClamp) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); 
-} + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; -TEST( - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, - NLessThan8) { - for (int n = 1; n < 8; n++) { - test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( - /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); - } + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); }