Skip to content

Commit 227d4bf

Browse files
authored
Bug fixes
Differential Revision: D60773448 Pull Request resolved: #717
1 parent 2c8e3f3 commit 227d4bf

10 files changed

+246
-148
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,21 @@ void kernel_impl(
218218
if constexpr (has_clamp) {
219219
res = clamp(res, clamp_min, clamp_max);
220220
}
221-
vst1q_f32(output + m_idx * output_m_stride + n_idx, res);
221+
222+
// Store result
223+
int remaining = n - n_idx;
224+
float* store_loc = output + m_idx * output_m_stride + n_idx;
225+
if (remaining >= 4) {
226+
vst1q_f32(store_loc, res);
227+
} else if (remaining >= 3) {
228+
vst1_f32(store_loc, vget_low_f32(res));
229+
*(store_loc + 2) = res[2];
230+
} else if (remaining >= 2) {
231+
vst1_f32(store_loc, vget_low_f32(res));
232+
} else {
233+
*(store_loc) = res[0];
234+
}
235+
222236
} // n_idx
223237
activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr);
224238
} // m_idx

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,34 @@ void kernel_impl(
290290
res_0123 = vec_clamp(res_0123, vec_min, vec_max);
291291
res_4567 = vec_clamp(res_4567, vec_min, vec_max);
292292
}
293-
vst1q_f32(output + m_idx * output_m_stride + n_idx, res_0123);
294-
vst1q_f32(output + m_idx * output_m_stride + n_idx + 4, res_4567);
293+
294+
// Store result
295+
int remaining = n - n_idx;
296+
float* store_loc = output + m_idx * output_m_stride + n_idx;
297+
if (remaining >= 8) {
298+
vst1q_f32(store_loc, res_0123);
299+
vst1q_f32(store_loc + 4, res_4567);
300+
} else if (remaining >= 7) {
301+
vst1q_f32(store_loc, res_0123);
302+
vst1_f32(store_loc + 4, vget_low_f32(res_4567));
303+
*(store_loc + 6) = res_4567[2];
304+
} else if (remaining >= 6) {
305+
vst1q_f32(store_loc, res_0123);
306+
vst1_f32(store_loc + 4, vget_low_f32(res_4567));
307+
} else if (remaining >= 5) {
308+
vst1q_f32(store_loc, res_0123);
309+
*(store_loc + 4) = res_4567[0];
310+
} else if (remaining >= 4) {
311+
vst1q_f32(store_loc, res_0123);
312+
} else if (remaining >= 3) {
313+
vst1_f32(store_loc, vget_low_f32(res_0123));
314+
*(store_loc + 2) = res_0123[2];
315+
} else if (remaining >= 2) {
316+
vst1_f32(store_loc, vget_low_f32(res_0123));
317+
} else {
318+
*store_loc = res_0123[0];
319+
}
320+
295321
} // n_idx
296322
activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr);
297323
} // m_idx

torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

33
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
4+
#include <cassert>
45

56
int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum(
67
const int8_t* vals,
78
int size) {
9+
assert(size >= 1);
10+
811
int32_t res = 0;
912
int i = 0;
1013

1114
#pragma unroll(4)
12-
for (; i < size; i += 16) {
15+
for (; i + 15 < size; i += 16) {
1316
int8x16_t vec_vals = vld1q_s8(vals + i);
1417
res += (int)(vaddlvq_s8(vec_vals));
1518
}

torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

33
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
4+
#include <cassert>
45

56
void torchao::kernels::cpu::aarch64::reduction::find_min_and_max(
67
float32_t& min,
78
float32_t& max,
89
const float32_t* vals,
910
int size) {
10-
float32x4_t mins = vdupq_n_f32(0.0);
11-
float32x4_t maxes = vdupq_n_f32(0.0);
11+
assert(size > 0);
12+
13+
// Needed in case size < 4 so we don't compare to
14+
// uninitialized min/max values
15+
min = vals[0];
16+
max = min;
17+
1218
int i = 0;
13-
for (; i < size; i += 8) {
14-
float32x4_t v1 = vld1q_f32(vals + i);
15-
float32x4_t v2 = vld1q_f32(vals + i + 4);
16-
mins = vminq_f32(v1, v2);
17-
maxes = vmaxq_f32(v1, v2);
19+
if (i + 3 < size) {
20+
float32x4_t mins = vld1q_f32(vals + i);
21+
float32x4_t maxes = mins;
22+
i += 4;
23+
for (; i + 3 < size; i += 4) {
24+
float32x4_t v = vld1q_f32(vals + i);
25+
mins = vminq_f32(mins, v);
26+
maxes = vmaxq_f32(maxes, v);
27+
}
28+
min = vminvq_f32(mins);
29+
max = vmaxvq_f32(maxes);
1830
}
19-
min = vminvq_f32(mins);
20-
max = vmaxvq_f32(maxes);
2131

2232
// Remainder
2333
while (i < size) {

torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ target_link_libraries(
3535
dep
3636
)
3737

38+
add_executable(test_reduction test_reduction.cpp)
39+
target_link_libraries(
40+
test_reduction
41+
PRIVATE
42+
GTest::gtest_main
43+
dep
44+
)
45+
3846
add_executable(test_bitpacking test_bitpacking.cpp)
3947
target_link_libraries(
4048
test_bitpacking
@@ -61,6 +69,7 @@ target_link_libraries(
6169

6270
include(GoogleTest)
6371
gtest_discover_tests(test_quantization)
72+
gtest_discover_tests(test_reduction)
6473
gtest_discover_tests(test_bitpacking)
6574
gtest_discover_tests(test_linear)
6675
gtest_discover_tests(test_valpacking)

torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/e
77
cmake --build ${CMAKE_OUT}
88

99
# Run
10-
${CMAKE_OUT}/test_quantization
11-
${CMAKE_OUT}/test_bitpacking
12-
${CMAKE_OUT}/test_linear
13-
${CMAKE_OUT}/test_valpacking
10+
${CMAKE_OUT}/test_quantization
11+
${CMAKE_OUT}/test_reduction
12+
${CMAKE_OUT}/test_bitpacking
13+
${CMAKE_OUT}/test_linear
14+
${CMAKE_OUT}/test_valpacking

0 commit comments

Comments
 (0)