Skip to content

Commit 1909171

Browse files
authored
Add high-level operator interface
Differential Revision: D60321449 Pull Request resolved: #708
1 parent 0991ba9 commit 1909171

22 files changed

+1759
-4
lines changed

torchao/experimental/kernels/cpu/build_and_run_benchmarks.sh renamed to torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Call script with sh build_and_run_benchmarks.sh {BENCHMARK}
33

44
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
5-
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../..
5+
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
66
export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks
77
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
88
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \

torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h>
66
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
77
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
8+
#include <cassert>
89

910
namespace torchao {
1011
namespace bitpacking {

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
77
#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
88
#include <cassert>
9+
#include <cstring>
910

1011
namespace torchao::kernels::cpu::aarch64::linear {
1112
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot::
@@ -251,7 +252,7 @@ int inline weight_data_size_impl(
251252
}
252253

253254
// Replace n with next multiple of 4 >= n
254-
n = ((n + 3) >> 2) << 2;
255+
n = ((n + 3) / 4) * 4;
255256

256257
return col_size * n;
257258
}

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
77
#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
88
#include <cassert>
9+
#include <cstring>
910

1011
namespace torchao::kernels::cpu::aarch64::linear {
1112
namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::
@@ -324,7 +325,7 @@ int inline weight_data_size_impl(
324325
}
325326

326327
// Replace n with next multiple of 8 >= n
327-
n = ((n + 3) >> 3) << 3;
328+
n = ((n + 7) / 8) * 8;
328329

329330
return col_size * n;
330331
}

torchao/experimental/kernels/cpu/build_and_run_tests.sh renamed to torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
3-
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../..
3+
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../..
44
export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests
55
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests -B ${CMAKE_OUT}
66

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44
#include <torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h>
55
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
6+
#include <cassert>
67
#include <functional>
78
#include <random>
89
#include <vector>
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
cmake_minimum_required(VERSION 3.19)
project(benchmarks)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default to an optimized build, but do not override an explicit
# -DCMAKE_BUILD_TYPE=... and do not set it for multi-config generators.
get_property(torchao_is_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT torchao_is_multi_config AND NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release)
endif()

# TORCHAO_LIBRARIES is the directory that contains the top-level `torchao/`
# folder; every source path and the include root below are derived from it.
if(NOT TORCHAO_LIBRARIES)
  message(FATAL_ERROR
    "TORCHAO_LIBRARIES is not set. Pass -DTORCHAO_LIBRARIES=<dir containing torchao/>.")
endif()
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")

# Google Benchmark is fetched at configure time. The original author noted
# that `main` is needed for the benchmark::benchmark target, so that stays the
# default; override the tag to get a reproducible build.
set(TORCHAO_GOOGLEBENCHMARK_TAG "main" CACHE STRING
  "Git tag/branch/commit of google/benchmark to fetch")

include(FetchContent)
FetchContent_Declare(
  googlebenchmark
  GIT_REPOSITORY https://github.com/google/benchmark.git
  GIT_TAG "${TORCHAO_GOOGLEBENCHMARK_TAG}")

set(BENCHMARK_ENABLE_TESTING OFF)
FetchContent_MakeAvailable(googlebenchmark)

# Warning flags are applied per target (after the fetch) so that they never
# leak into the third-party benchmark library.
option(TORCHAO_BENCHMARKS_WERROR "Treat compiler warnings as errors" ON)
set(torchao_benchmark_warning_flags -Wall)
if(TORCHAO_BENCHMARKS_WERROR)
  list(APPEND torchao_benchmark_warning_flags -Werror)
endif()

# Kernel support sources the benchmark links against.
add_library(
  dep
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
  ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)
# PUBLIC: the include root is also needed by everything that links `dep`.
target_include_directories(dep PUBLIC "${TORCHAO_LIBRARIES}")
target_compile_options(dep PRIVATE ${torchao_benchmark_warning_flags})

add_executable(benchmark_linear_operator benchmark_linear_operator.cpp)
target_compile_options(
  benchmark_linear_operator
  PRIVATE
  ${torchao_benchmark_warning_flags})
target_link_libraries(
  benchmark_linear_operator
  PRIVATE
  benchmark::benchmark
  dep
)

# Parallel backend selection. The definitions are PUBLIC on `dep` so that both
# `dep` and the benchmark executable see them, as with the previous
# directory-wide add_definitions().
# NOTE(review): both options can be ON at once (SINGLE_THREADED defaults to ON),
# in which case both macros are defined; confirm the headers give OMP priority.
option(TORCHAO_PARALLEL_OMP "Define TORCHAO_PARALLEL_OMP=1 and link OpenMP" OFF)
option(TORCHAO_PARALLEL_SINGLE_THREADED
  "Define TORCHAO_PARALLEL_SINGLE_THREADED=1" ON)

if(TORCHAO_PARALLEL_OMP)
  message("OpenMP_ROOT: ${OpenMP_ROOT}")
  # Only the C++ component is used; REQUIRED makes a missing OpenMP a
  # configure error instead of defining the macro without linking the runtime.
  find_package(OpenMP REQUIRED COMPONENTS CXX)
  target_compile_definitions(dep PUBLIC TORCHAO_PARALLEL_OMP=1)
  target_link_libraries(benchmark_linear_operator PRIVATE OpenMP::OpenMP_CXX)
endif()

if(TORCHAO_PARALLEL_SINGLE_THREADED)
  target_compile_definitions(dep PUBLIC TORCHAO_PARALLEL_SINGLE_THREADED=1)
endif()
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#include <benchmark/benchmark.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h>
#include <torchao/experimental/kernels/cpu/memory.h>
#include <cstddef>
#include <memory>
#include <vector>
8+
9+
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
10+
static void channelwise_8bit_activation_groupwise_lowbit_weight(
11+
benchmark::State& state) {
12+
int m = state.range(0);
13+
int n = state.range(1);
14+
int k = state.range(2);
15+
int group_size = state.range(3);
16+
int num_threads = state.range(4);
17+
18+
// OMP appears to cache when repeating the same task in the benchmark
19+
// To prevent this, we benchmark a number of tasks
20+
int num_test_cases = state.range(5);
21+
22+
// Initialize config and tiling params
23+
using namespace torchao::operators::cpu::linear::
24+
channelwise_8bit_activation_groupwise_lowbit_weight;
25+
26+
auto ukernel_config =
27+
get_ukernel_config<weight_nbit, has_weight_zeros, has_bias, has_clamp>();
28+
auto pack_weight_data_tiling_params =
29+
get_default_pack_weight_data_tiling_params(ukernel_config, n);
30+
auto linear_tiling_params =
31+
get_default_linear_tiling_params(ukernel_config, m, n);
32+
auto linear_scheduling_policy =
33+
LinearTileSchedulingPolicy::single_mc_parallel_nc;
34+
35+
// Set number of threads
36+
torchao::set_num_threads(num_threads);
37+
assert(num_threads == torchao::get_num_threads());
38+
39+
// Generate test cases
40+
std::vector<
41+
torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case>
42+
test_cases;
43+
for (int i = 0; i < num_test_cases; ++i) {
44+
test_cases.emplace_back(
45+
torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case::
46+
generate(
47+
m,
48+
k,
49+
n,
50+
group_size,
51+
weight_nbit,
52+
has_weight_zeros,
53+
has_bias,
54+
has_clamp));
55+
}
56+
57+
// Pack test case weights
58+
size_t packed_weight_data_size =
59+
get_packed_weight_data_size(ukernel_config, n, k, group_size);
60+
size_t packed_weight_data_alignment =
61+
get_packed_weight_data_alignment(ukernel_config);
62+
63+
std::vector<std::unique_ptr<char[], void (*)(void*)>> packed_weight_data;
64+
for (int i = 0; i < test_cases.size(); i++) {
65+
packed_weight_data.emplace_back(torchao::make_aligned_byte_array_unique_ptr(
66+
packed_weight_data_alignment, packed_weight_data_size));
67+
pack_weight_data_operator(
68+
ukernel_config,
69+
pack_weight_data_tiling_params,
70+
packed_weight_data[i].get(),
71+
n,
72+
k,
73+
group_size,
74+
test_cases[i].weight_qvals.data(),
75+
test_cases[i].weight_scales.data(),
76+
test_cases[i].weight_zeros.data());
77+
}
78+
79+
// Allocate activation data buffer for test cases
80+
size_t activation_data_buffer_size = get_activation_data_buffer_size(
81+
ukernel_config,
82+
linear_tiling_params,
83+
linear_scheduling_policy,
84+
m,
85+
k,
86+
group_size);
87+
size_t activation_data_buffer_alignment =
88+
get_activation_data_buffer_alignment(ukernel_config);
89+
90+
auto activation_data_buffer = torchao::make_aligned_byte_array_unique_ptr(
91+
activation_data_buffer_alignment, activation_data_buffer_size);
92+
93+
auto output = std::vector<float>(m * n);
94+
for (auto _ : state) {
95+
for (int i = 0; i < test_cases.size(); i++) {
96+
linear_operator(
97+
ukernel_config,
98+
linear_tiling_params,
99+
linear_scheduling_policy,
100+
activation_data_buffer.get(),
101+
output.data(),
102+
m,
103+
n,
104+
k,
105+
group_size,
106+
packed_weight_data[i].get(),
107+
test_cases[i].activations.data(),
108+
test_cases[i].bias.data(),
109+
test_cases[i].clamp_min,
110+
test_cases[i].clamp_max);
111+
}
112+
}
113+
}
114+
115+
#define BENCHMARK_PARAMS \
116+
{ \
117+
/*m*/ {1}, /*n*/ {4096}, /*k*/ {4096}, /*group_size*/ {16, 32, 256}, \
118+
/*num_threads*/ {1, 2, 4, 6, 8}, /*num_test_cases*/ { \
119+
10 \
120+
} \
121+
}
122+
123+
#define BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT( \
124+
weight_nbit) \
125+
BENCHMARK(channelwise_8bit_activation_groupwise_lowbit_weight< \
126+
weight_nbit, \
127+
false /*has_weight_zeros*/, \
128+
false /*has_bias*/, \
129+
false /*has_clamp*/>) \
130+
->ArgsProduct(BENCHMARK_PARAMS) \
131+
->ArgNames( \
132+
{"m", "n", "k", "group_size", "num_threads", "num_test_cases"});
133+
134+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT(3);
135+
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT(4);
136+
137+
// Run the benchmark
138+
BENCHMARK_MAIN();
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/bin/bash
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# Configures, builds and runs the linear operator benchmarks with OpenMP
# enabled.
#
# Usage: bash build_and_run_benchmarks.sh {BENCHMARK}
#   BENCHMARK: linear_operator
#
# Run with bash (not plain sh): the script relies on BASH_SOURCE.
#
# Environment:
#   OpenMP_ROOT  Optional OpenMP install prefix. When unset, the prefix is
#                taken from `brew --prefix libomp`.

# Stop at the first failing step so a failed configure/build never falls
# through to running a stale or missing binary.
set -euo pipefail

# Validate the requested benchmark before spending time on the build.
BENCHMARK="${1:-}"
case "${BENCHMARK}" in
    linear_operator) BENCHMARK_BINARY=benchmark_linear_operator ;;
    *) echo "Unknown benchmark: ${BENCHMARK}. Please specify one of: linear_operator."; exit 1 ;;
esac

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
# Six levels up from this script's directory; the -S path below expects
# TORCHAO_LIBRARIES to be the directory that contains torchao/.
export TORCHAO_LIBRARIES="${SCRIPT_DIR}/../../../../../.."
export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks

# Resolve the OpenMP prefix: an explicit OpenMP_ROOT wins, Homebrew is the
# fallback.
OPENMP_ROOT_DIR="${OpenMP_ROOT:-}"
if [[ -z "${OPENMP_ROOT_DIR}" ]]; then
    if ! command -v brew > /dev/null 2>&1; then
        echo "OpenMP_ROOT is not set and brew was not found; set OpenMP_ROOT to the libomp install prefix." >&2
        exit 1
    fi
    OPENMP_ROOT_DIR="$(brew --prefix libomp)"
fi

cmake -DTORCHAO_LIBRARIES="${TORCHAO_LIBRARIES}" \
    -S "${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/benchmarks" \
    -B "${CMAKE_OUT}" \
    -DOpenMP_ROOT="${OPENMP_ROOT_DIR}" \
    -DTORCHAO_PARALLEL_OMP=ON

cmake --build "${CMAKE_OUT}"

# Run
"${CMAKE_OUT}/${BENCHMARK_BINARY}"

0 commit comments

Comments
 (0)