This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 70127a6
Merge pull request #219 from facebookresearch/pcpumap: Parallel CPU mapper
2 parents: aa3293f + 6be0ca0

File tree: 8 files changed (+150 lines, -30 lines)

build.sh

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@ else
 fi
 WITH_PYTHON_C2=${WITH_PYTHON_C2:=OFF}
 WITH_NNPACK=${WITH_NNPACK:=OFF}
+WITH_TAPIR=${WITH_TAPIR:=ON}
 PYTHON=${PYTHON:="`which python3`"}
 PROTOC=${PROTOC:="`which protoc`"}
 CORES=${CORES:=32}
@@ -401,6 +402,7 @@ function install_tc() {
   rm -rf *
   VERBOSE=${VERBOSE} ${CMAKE_VERSION} -DWITH_CAFFE2=${WITH_CAFFE2} \
     -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
+    -DWITH_TAPIR=${WITH_TAPIR} \
     -DPYTHON_EXECUTABLE=${PYTHON} \
     -DHALIDE_PREFIX=${INSTALL_PREFIX} \
     -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
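
The new flag follows the script's existing WITH_X=${WITH_X:=default} override pattern: Tapir support now defaults to ON and is forwarded to CMake, so an opt-out build is simply a matter of invoking the script as WITH_TAPIR=OFF ./build.sh.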

include/tc/core/polyhedral/codegen_llvm.h

Lines changed: 18 additions & 0 deletions
@@ -20,11 +20,29 @@
 
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetMachine.h"
 
 #include "Halide.h"
 
 namespace tc {
+
+static inline std::string toString(llvm::Value* llvmObject) {
+  std::string output;
+  llvm::raw_string_ostream rso(output);
+  llvmObject->print(rso);
+  rso.str();
+  return output;
+}
+
+static inline std::string toString(llvm::Module* llvmObject) {
+  std::string output;
+  llvm::raw_string_ostream rso(output);
+  llvmObject->print(rso, nullptr, false, true);
+  rso.str();
+  return output;
+}
+
 namespace polyhedral {
 struct Scop;
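
The two overloads render LLVM IR to an ordinary std::string by way of a raw_string_ostream, so modules and values can go straight into log messages. A minimal, self-contained usage sketch (the demo module and helper name are hypothetical, not part of this commit):

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"

#include "tc/core/polyhedral/codegen_llvm.h" // the tc::toString overloads above

std::string dumpDemoModule() {
  llvm::LLVMContext context;
  llvm::Module module("demo", context); // hypothetical empty module
  // The Module* overload prints the whole module; passing an
  // llvm::Function* (or any other llvm::Value*) picks the Value* overload.
  return tc::toString(&module);
}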

include/tc/core/polyhedral/llvm_jit.h

Lines changed: 3 additions & 4 deletions
@@ -38,12 +38,11 @@ class Jit {
  public:
   Jit();
 
-  void codegenScop(
+  using ModuleHandle = decltype(compileLayer_)::ModuleHandleT;
+  std::shared_ptr<llvm::Module> codegenScop(
       const std::string& specializedName,
       const polyhedral::Scop& scop);
-
-  using ModuleHandle = decltype(compileLayer_)::ModuleHandleT;
-  ModuleHandle addModule(std::unique_ptr<llvm::Module> M);
+  ModuleHandle addModule(std::shared_ptr<llvm::Module> M);
   void removeModule(ModuleHandle H);
 
   llvm::JITSymbol findSymbol(const std::string name);
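
Since codegenScop now returns the module it just JIT-compiled, a caller can both execute and inspect the same kernel. A sketch of the new flow, mirroring the test added at the bottom of this commit ("kernel_anon" is the name used there; the helper itself is illustrative):

#include "tc/core/polyhedral/llvm_jit.h"
#include "tc/core/polyhedral/scop.h"

// JIT a specialized, scheduled scop and hand back the generated function
// so the caller can walk its IR (e.g. to look for runtime calls).
llvm::Function* jitAndInspect(
    tc::polyhedral::Jit& jit,
    const tc::polyhedral::Scop& scop) {
  std::shared_ptr<llvm::Module> mod = jit.codegenScop("kernel_anon", scop);
  return mod->getFunction("kernel_anon");
}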

include/tc/core/polyhedral/scop.h

Lines changed: 0 additions & 1 deletion
@@ -341,7 +341,6 @@ struct Scop {
   static std::unique_ptr<Scop> makeScheduled(
       const Scop& scop,
       const SchedulerOptionsView& schedulerOptions);
-
   // Tile the outermost band.
   // Splits the band into tile loop band and point loop band where point loops
   // have fixed trip counts specified in "tiling", and returns a pointer to the

src/core/polyhedral/codegen_llvm.cc

Lines changed: 9 additions & 19 deletions
@@ -55,17 +55,6 @@ using namespace Halide;
 
 namespace tc {
 
-namespace {
-template <typename T>
-std::string toString(T* llvmObject) {
-  std::string output;
-  llvm::raw_string_ostream rso(output);
-  llvmObject->print(rso, nullptr, false, true);
-  rso.str();
-  return output;
-}
-} // namespace
-
 namespace halide2isl {
 isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
 }
@@ -217,6 +206,9 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
   using CodeGen_X86::sym_push;
 
   void init_module() override {
+    const char* llvm_args[] = {"tc (LLVM argument parsing)", nullptr};
+    llvm::cl::ParseCommandLineOptions(
+        sizeof(llvm_args) / sizeof(*llvm_args) - 1, llvm_args);
     init_context();
     module =
         llvm::make_unique<llvm::Module>("TensorComprehensionsModule", *context);
@@ -311,14 +303,13 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
     functionPassManager.doInitialization();
     for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) {
       functionPassManager.run(*i);
+    }
 
-      functionPassManager.doFinalization();
-      modulePassManager.run(*module);
+    functionPassManager.doFinalization();
+    modulePassManager.run(*module);
 
-      LOG_IF(INFO, FLAGS_llvm_dump_after_opt)
-          << "[LLVM-IR] After optimization:\n"
-          << toString(module.get());
-    }
+    LOG_IF(INFO, FLAGS_llvm_dump_after_opt) << "[LLVM-IR] After optimization:\n"
+                                            << toString(module.get());
   }
 };
 
@@ -492,8 +483,7 @@ class LLVMCodegen {
 
   // TODO: integrate query ISL as to whether the relevant loop ought be
   // parallelized
-  bool parallel = false;
-
+  bool parallel = isl_ast_node_for_is_coincident(node.get());
   llvm::Value* SyncRegion = nullptr;
 
 #ifdef TAPIR_VERSION_MAJOR
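
The last hunk is the semantic core of the commit: instead of hardcoding parallel = false, codegen now trusts the coincidence marker that isl attaches to "for" AST nodes whose scheduled iterations carry no dependences, and emits such loops as Tapir parallel loops. A sketch of that decision point, reusing the call from the diff above (the explicit node-type guard is an illustrative addition, and the whole function assumes TC's isl build, which exposes isl_ast_node_for_is_coincident):

// A coincident band member has no loop-carried dependences, so its
// iterations may run concurrently; under Tapir the loop body is then
// wrapped in detach/reattach instead of being emitted sequentially.
bool shouldEmitParallelLoop(isl::ast_node node) {
  if (isl_ast_node_get_type(node.get()) != isl_ast_node_for) {
    return false; // only "for" nodes carry the coincidence marker
  }
  return isl_ast_node_for_is_coincident(node.get());
}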

src/core/polyhedral/llvm_jit.cc

Lines changed: 7 additions & 5 deletions
@@ -82,18 +82,20 @@ Jit::Jit()
   }
 }
 
-void Jit::codegenScop(
+std::shared_ptr<Module> Jit::codegenScop(
     const std::string& specializedName,
     const polyhedral::Scop& scop) {
-  addModule(emitLLVMKernel(
-      specializedName, scop, getTargetMachine().createDataLayout()));
+  std::shared_ptr<Module> mod = emitLLVMKernel(
+      specializedName, scop, getTargetMachine().createDataLayout());
+  addModule(mod);
+  return mod;
 }
 
 TargetMachine& Jit::getTargetMachine() {
   return *TM_;
 }
 
-Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
+Jit::ModuleHandle Jit::addModule(std::shared_ptr<Module> M) {
   M->setTargetTriple(TM_->getTargetTriple().str());
   auto Resolver = orc::createLambdaResolver(
       [&](const std::string& Name) {
@@ -107,7 +109,7 @@ Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
         return JITSymbol(nullptr);
       });
 
-  auto res = compileLayer_.addModule(std::move(M), std::move(Resolver));
+  auto res = compileLayer_.addModule(M, std::move(Resolver));
   CHECK(res) << "Failed to jit compile.";
   return *res;
 }
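
The unique_ptr-to-shared_ptr switch is what makes returning the module possible: the ORC compile layer keeps a reference to the module for execution while codegenScop hands the same module back to its caller, so ownership is genuinely shared rather than transferred into the JIT.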

test/CMakeLists.txt

Lines changed: 13 additions & 1 deletion
@@ -100,7 +100,6 @@ foreach(i ${CORE_TEST_FILES})
   target_link_libraries(${i} ${GOOGLE_LIBS} tc_core_cuda_no_sdk)
 endforeach()
 
-
 add_executable(test_mapper_llvm test_mapper_llvm.cc)
 add_test(test_mapper_llvm test_mapper_llvm)
 target_link_libraries(
@@ -112,6 +111,19 @@ target_link_libraries(
 
   tc_core_cpu tc_lang)
 
+if (WITH_TAPIR)
+  add_executable(test_mapper_tapir test_mapper_tapir.cc)
+  add_test(test_mapper_tapir test_mapper_tapir)
+  target_link_libraries(
+    test_mapper_tapir
+
+    ${GOOGLE_LIBS}
+    ${ATEN_LIBRARIES}
+    -lLLVM
+
+    tc_core_cpu tc_lang)
+endif()
+
 ################################################################################
 # TensorComprehensions tests
 # No real need for NVCC if we only use NVRTC

test/test_mapper_tapir.cc

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+/**
+ * Copyright (c) 2017-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+#include <llvm/IR/InstIterator.h>
+#include <llvm/IR/Instructions.h>
+
+#include "tc/aten/utils.h"
+#include "tc/core/cpu/cpu_tc_executor.h"
+#include "tc/core/execution_engine.h"
+#include "tc/core/mapping_options.h"
+#include "tc/core/polyhedral/codegen_llvm.h"
+#include "tc/core/polyhedral/llvm_jit.h"
+#include "tc/core/polyhedral/scop.h"
+#include "tc/core/scope_guard.h"
+
+#include "test_harness_aten.h"
+
+using namespace std;
+
+using namespace tc;
+using namespace tc::polyhedral;
+using namespace tc::polyhedral::detail;
+
+TEST(TapirCodegen, BasicParallel) {
+  string tc = R"TC(
+def fun(float(N, M) A, float(N, M) B) -> (C) {
+  C(n, m) = A(n, m) + B(n, m)
+}
+)TC";
+  auto N = 40;
+  auto M = 24;
+
+  auto ctx = isl::with_exceptions::globalIslCtx();
+  auto scop = polyhedral::Scop::makeScop(ctx, tc);
+  auto context = scop->makeContext(
+      std::unordered_map<std::string, int>{{"N", N}, {"M", M}});
+  scop = Scop::makeSpecializedScop(*scop, context);
+  SchedulerOptionsProto sop;
+  SchedulerOptionsView sov(sop);
+  scop = Scop::makeScheduled(*scop, sov);
+  Jit jit;
+  auto mod = jit.codegenScop("kernel_anon", *scop);
+  auto fn = mod->getFunction("kernel_anon");
+
+  std::set<string> calledFunctions;
+  for (llvm::inst_iterator I = llvm::inst_begin(fn), E = llvm::inst_end(fn);
+       I != E;
+       ++I) {
+    if (llvm::CallInst* c = llvm::dyn_cast<llvm::CallInst>(&*I)) {
+      if (auto called = c->getCalledFunction()) {
+        calledFunctions.insert(called->getName());
+      }
+    }
+  }
+
+  ASSERT_NE(0, calledFunctions.count("__cilkrts_get_tls_worker"));
+  ASSERT_NE(0, calledFunctions.count("__cilkrts_bind_thread_1"));
+  ASSERT_NE(0, calledFunctions.count("llvm.stacksave"));
+  ASSERT_NE(0, calledFunctions.count("__cilkrts_sync"));
+
+  auto fptr =
+      (void (*)(float*, float*, float*))jit.getSymbolAddress("kernel_anon");
+
+  at::Tensor A = at::CPU(at::kFloat).rand({N, M});
+  at::Tensor B = at::CPU(at::kFloat).rand({N, M});
+  at::Tensor C = at::CPU(at::kFloat).rand({N, M});
+  at::Tensor Cc = A + B;
+  fptr(A.data<float>(), B.data<float>(), C.data<float>());
+
+  checkRtol(Cc - C, {A, B}, N * M);
+}
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  ::gflags::ParseCommandLineFlags(&argc, &argv, true);
+  ::google::InitGoogleLogging(argv[0]);
+  initialize_llvm();
+  return RUN_ALL_TESTS();
+}
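
The __cilkrts_* assertions are the actual Tapir check: Tapir lowers its parallel-loop constructs to Cilk runtime entry points, so finding __cilkrts_get_tls_worker, __cilkrts_bind_thread_1, and __cilkrts_sync among the kernel's callees confirms the loop was emitted in parallel form, while the final checkRtol verifies that the parallel kernel still computes C = A + B correctly.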
