Skip to content

Commit 3e9746c

Browse files
authored
Add ATEN parallel backend
Differential Revision: D62303555 Pull Request resolved: #857
1 parent 00d18ec commit 3e9746c

14 files changed

+222
-191
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Configures `target_name` to build against one of the torchao parallel backends.
#
# Arguments:
#   target_name               - existing target to configure.
#   torchao_parallel_backend  - backend name, compared case-insensitively. Accepted
#                               values: aten_openmp, openmp, pthreadpool,
#                               single_threaded, test_dummy.
#
# Every backend adds its TORCHAO_PARALLEL_*=1 compile definition to the target and
# links whatever library it needs. Any other value aborts configuration with a
# FATAL_ERROR.
function(target_link_torchao_parallel_backend target_name torchao_parallel_backend)
    # Quoted so that an empty argument still forms a valid string() call and falls
    # through to the FATAL_ERROR branch below instead of failing inside string().
    string(TOUPPER "${torchao_parallel_backend}" TORCHAO_PARALLEL_BACKEND_TOUPPER)

    if(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "ATEN_OPENMP")
        message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=ATEN_OPENMP")

        set(_OMP_CXX_COMPILE_FLAGS "-fopenmp")
        if (APPLE)
            set(_OMP_CXX_COMPILE_FLAGS "-Xclang -fopenmp")
        endif()

        # PARENT_SCOPE: the flag is appended to the caller's CMAKE_CXX_FLAGS, not
        # only to this target.
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${_OMP_CXX_COMPILE_FLAGS}" PARENT_SCOPE)

        find_package(Torch REQUIRED)
        include_directories("${TORCH_INCLUDE_DIRS}")
        target_link_libraries(${target_name} PRIVATE "${TORCH_LIBRARIES}")

        target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_ATEN=1 AT_PARALLEL_OPENMP=1 INTRA_OP_PARALLEL=1)
        # Links the libomp found under the Torch install prefix.
        target_link_libraries(${target_name} PRIVATE ${TORCH_INSTALL_PREFIX}/lib/libomp${CMAKE_SHARED_LIBRARY_SUFFIX})

    elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "OPENMP")
        message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=OPENMP. You must set the CMake variable OpenMP_ROOT to the OMP library location before compiling. Do not use this option if Torch was built with OPENMP; use ATEN_OPENMP instead.")
        find_package(OpenMP REQUIRED)
        target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_OPENMP=1)
        target_link_libraries(${target_name} PRIVATE OpenMP::OpenMP_CXX)

    elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "PTHREADPOOL")
        message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=PTHREADPOOL")
        # pthreadpool is fetched from its upstream repository at configure time.
        include(FetchContent)
        FetchContent_Declare(pthreadpool
            GIT_REPOSITORY https://github.com/Maratyszcza/pthreadpool.git
            GIT_TAG master)

        FetchContent_MakeAvailable(
            pthreadpool)

        target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_PTHREADPOOL=1)
        target_link_libraries(${target_name} PRIVATE pthreadpool)

    elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "SINGLE_THREADED")
        message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=SINGLE_THREADED")
        target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_SINGLE_THREADED=1)

    elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "TEST_DUMMY")
        message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=TEST_DUMMY")
        target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_TEST_DUMMY=1)

    else()
        # Report the value that was actually passed to this function. CMake variable
        # names are case-sensitive, so the argument is torchao_parallel_backend.
        message(FATAL_ERROR "Unknown TORCHAO_PARALLEL_BACKEND: ${torchao_parallel_backend}. Please choose one of: aten_openmp, openmp, pthreadpool, single_threaded.")
    endif()
endfunction()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# kernel_aarch64: library target built from the aarch64 CPU kernel sources
# (reduction, quantization and valpacking), located through TORCHAO_LIBRARIES.
add_library(kernel_aarch64)
target_sources(
    kernel_aarch64
    PRIVATE
        ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
        ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
        ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
        ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)

torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,8 @@ void pack_weight_data_operator(
5757
int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr);
5858
int num_nc_panels = (n + nc - 1) / nc;
5959

60-
torchao::parallel_for(0, num_nc_panels, 1, [&](int64_t begin, int64_t end) {
61-
// TODO(T200106949): decide how to handle at::parallel_for not respecting
62-
// user-supplied grain_size
63-
assert(end == begin + 1);
64-
65-
int nc_tile_idx = begin;
60+
torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) {
61+
int nc_tile_idx = idx;
6662
int n_idx = nc_tile_idx * nc;
6763
int nc_tile_size = std::min(nc, n - n_idx);
6864

@@ -178,12 +174,8 @@ void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
178174
group_size,
179175
activations + activations_offset);
180176

181-
torchao::parallel_for(0, num_nc_panels, 1, [&](int64_t begin, int64_t end) {
182-
// TODO(T200106949): decide how to handle at::parallel_for not respecting
183-
// user-supplied grain_size
184-
assert(end == begin + 1);
185-
186-
int nc_tile_idx = begin;
177+
torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) {
178+
int nc_tile_idx = idx;
187179
int n_idx = nc_tile_idx * nc;
188180
int nc_tile_size = std::min(nc, n - n_idx);
189181

@@ -234,8 +226,8 @@ void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
234226
int activation_data_size =
235227
ukernel_config.activation_data_size_fn(mr, k, group_size);
236228

237-
torchao::parallel_for(0, num_mc_panels, 1, [&](int64_t begin, int64_t end) {
238-
int mc_tile_idx = begin;
229+
torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) {
230+
int mc_tile_idx = idx;
239231
int m_idx = mc_tile_idx * mc;
240232
int mc_tile_size = std::min(mc, m - m_idx);
241233
int activations_offset = m_idx * k;
@@ -249,34 +241,33 @@ void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
249241
activations + activations_offset);
250242
});
251243

252-
torchao::parallel_for(
253-
0, num_mc_panels * num_nc_panels, 1, [&](int64_t begin, int64_t end) {
254-
int mc_tile_idx = begin / num_nc_panels;
255-
int m_idx = mc_tile_idx * mc;
256-
int mc_tile_size = std::min(mc, m - m_idx);
257-
258-
int nc_tile_idx = begin % num_nc_panels;
259-
int n_idx = nc_tile_idx * nc;
260-
int nc_tile_size = std::min(nc, n - n_idx);
261-
262-
int activation_data_offset = (m_idx / mr) * activation_data_size;
263-
int output_offset = m_idx * n + n_idx;
264-
int weight_data_offset = (n_idx / nr) * weight_data_size;
265-
int bias_offset = m_idx;
266-
267-
ukernel_config.kernel_fn(
268-
output + output_offset,
269-
/*output_m_stride=*/n,
270-
/*m=*/mc_tile_size,
271-
/*n=*/nc_tile_size,
272-
k,
273-
group_size,
274-
/*weight_data=*/(char*)weight_data + weight_data_offset,
275-
/*activation_data=*/activation_data_buffer + activation_data_offset,
276-
/*bias=*/bias + bias_offset,
277-
clamp_min,
278-
clamp_max);
279-
});
244+
torchao::parallel_1d(0, num_mc_panels * num_nc_panels, [&](int64_t idx) {
245+
int mc_tile_idx = idx / num_nc_panels;
246+
int m_idx = mc_tile_idx * mc;
247+
int mc_tile_size = std::min(mc, m - m_idx);
248+
249+
int nc_tile_idx = idx % num_nc_panels;
250+
int n_idx = nc_tile_idx * nc;
251+
int nc_tile_size = std::min(nc, n - n_idx);
252+
253+
int activation_data_offset = (m_idx / mr) * activation_data_size;
254+
int output_offset = m_idx * n + n_idx;
255+
int weight_data_offset = (n_idx / nr) * weight_data_size;
256+
int bias_offset = m_idx;
257+
258+
ukernel_config.kernel_fn(
259+
output + output_offset,
260+
/*output_m_stride=*/n,
261+
/*m=*/mc_tile_size,
262+
/*n=*/nc_tile_size,
263+
k,
264+
group_size,
265+
/*weight_data=*/(char*)weight_data + weight_data_offset,
266+
/*activation_data=*/activation_data_buffer + activation_data_offset,
267+
/*bias=*/bias + bias_offset,
268+
clamp_min,
269+
clamp_max);
270+
});
280271
}
281272
} // namespace internal
282273

torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,23 @@ include(CMakePrintHelpers)
1616
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
1717
include_directories(${TORCHAO_LIBRARIES})
1818

19-
add_library(
20-
torchao_dep
21-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
22-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
23-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
24-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
25-
)
26-
19+
add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
2720

2821
add_executable(separate_function_wrappers separate_function_wrappers.cpp)
2922
target_link_libraries(
3023
separate_function_wrappers
3124
PRIVATE
32-
torchao_dep
25+
kernel_aarch64
3326
)
3427

3528
add_executable(stateful_class_wrapper stateful_class_wrapper.cpp)
3629
target_link_libraries(
3730
stateful_class_wrapper
3831
PRIVATE
39-
torchao_dep
32+
kernel_aarch64
4033
)
4134

35+
include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
4236

43-
find_package(OpenMP)
44-
if(OpenMP_CXX_FOUND)
45-
target_link_libraries(separate_function_wrappers PUBLIC OpenMP::OpenMP_CXX)
46-
target_compile_definitions(separate_function_wrappers PRIVATE TORCHAO_PARALLEL_OMP=1)
47-
48-
target_link_libraries(stateful_class_wrapper PUBLIC OpenMP::OpenMP_CXX)
49-
target_compile_definitions(stateful_class_wrapper PRIVATE TORCHAO_PARALLEL_OMP=1)
50-
else()
51-
target_compile_definitions(separate_function_wrappers PRIVATE TORCHAO_PARALLEL_SINGLE_THREADED=1)
52-
target_compile_definitions(stateful_class_wrapper PRIVATE TORCHAO_PARALLEL_SINGLE_THREADED=1)
53-
endif()
37+
target_link_torchao_parallel_backend(stateful_class_wrapper "openmp")
38+
target_link_torchao_parallel_backend(separate_function_wrappers "openmp")

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
project(examples)
7+
project(torch_custom_op)
88

99
cmake_minimum_required(VERSION 3.19)
1010
set(CMAKE_CXX_STANDARD 17)
@@ -16,27 +16,15 @@ include(CMakePrintHelpers)
1616
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
1717
include_directories(${TORCHAO_LIBRARIES})
1818

19-
add_library(
20-
torchao_dep
21-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
22-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
23-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
24-
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
25-
)
26-
27-
include(FetchContent)
28-
FetchContent_Declare(pthreadpool
29-
GIT_REPOSITORY https://github.com/Maratyszcza/pthreadpool.git
30-
GIT_TAG master)
31-
FetchContent_MakeAvailable(
32-
pthreadpool)
19+
add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
3320

3421
find_package(Torch REQUIRED)
35-
message("TORCH_INCLUDE_DIRS: ${TORCH_INCLUDE_DIRS}")
3622
include_directories("${TORCH_INCLUDE_DIRS}")
3723

3824
add_library(torch_custom_op SHARED torch_custom_op.cpp)
3925
target_link_libraries(torch_custom_op PRIVATE "${TORCH_LIBRARIES}")
40-
target_link_libraries(torch_custom_op PRIVATE torchao_dep)
41-
target_compile_definitions(torch_custom_op PRIVATE TORCHAO_PARALLEL_PTHREADPOOL=1)
42-
target_link_libraries(torch_custom_op PRIVATE pthreadpool)
26+
target_link_libraries(torch_custom_op PRIVATE kernel_aarch64)
27+
28+
include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
29+
set(TORCHAO_PARALLEL_BACKEND "ATEN_OPENMP" CACHE STRING "Choose parallel backend to use for torchao parallelism (aten_openmp, openmp, pthreadpool, single_threaded)")
30+
target_link_torchao_parallel_backend(torch_custom_op "${TORCHAO_PARALLEL_BACKEND}")

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
1313
export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
1414
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
1515
-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
16+
-DTORCHAO_PARALLEL_BACKEND="aten_openmp" \
1617
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
1718
-B ${CMAKE_OUT}
1819
cmake --build ${CMAKE_OUT}

torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/torch_custom_op.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ at::Tensor pack_weights_without_zeros_cpu(
5050
auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
5151
ukernel_config, n, /*target_panels_per_thread=*/1);
5252

53-
torchao::set_num_threads(torch::get_num_threads());
54-
5553
auto packed_weight_data_size =
5654
get_packed_weight_data_size(ukernel_config, n, k, group_size);
5755
at::Tensor packed_weights =
@@ -117,8 +115,6 @@ at::Tensor pack_weights_with_zeros_cpu(
117115
auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
118116
ukernel_config, n, /*target_panels_per_thread=*/1);
119117

120-
torchao::set_num_threads(torch::get_num_threads());
121-
122118
auto packed_weight_data_size =
123119
get_packed_weight_data_size(ukernel_config, n, k, group_size);
124120
at::Tensor packed_weights =
@@ -227,8 +223,6 @@ at::Tensor linear_cpu(
227223
auto linear_scheduling_policy =
228224
LinearTileSchedulingPolicy::single_mc_parallel_nc;
229225

230-
torchao::set_num_threads(torch::get_num_threads());
231-
232226
auto activation_data_buffer_size = get_activation_data_buffer_size(
233227
ukernel_config,
234228
linear_tiling_params,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
#include <torch/library.h>
#include <torch/torch.h>
// The directory is spelled "ATen"; the exact case matters on case-sensitive
// filesystems.
#include <ATen/Parallel.h>

// Implementation of the torchao parallel primitives on top of ATen/Torch.

// Calls f(idx) once for every idx in [begin, end), handing the range to
// at::parallel_for with a grain size of 1.
// F has signature [&](int64_t idx)
template <typename F>
void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) {
  at::parallel_for(
      begin,
      end,
      /*grain_size=*/1,
      [&](int64_t chunk_begin, int64_t chunk_end) {
        // at::parallel_for may pass a sub-range wider than one index, so every
        // index of the chunk is visited here.
        for (int64_t idx = chunk_begin; idx < chunk_end; idx++) {
          f(idx);
        }
      });
}

// Forwards the requested thread count to torch::set_num_threads.
// `inline` because this is a non-template function defined in a header:
// without it, including the header from more than one translation unit
// yields multiple definitions at link time.
inline void torchao::set_num_threads(int num_threads) {
  torch::set_num_threads(num_threads);
}

// Returns the value reported by torch::get_num_threads.
// `inline` for the same reason as set_num_threads above.
inline int torchao::get_num_threads() {
  return torch::get_num_threads();
}

torchao/experimental/kernels/cpu/parallel-impl.h

Lines changed: 0 additions & 61 deletions
This file was deleted.

0 commit comments

Comments
 (0)