Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6e822ac

Browse files
committed
Parallel CPU mapper
1 parent aa3293f commit 6e822ac

File tree

8 files changed

+224
-30
lines changed

8 files changed

+224
-30
lines changed

include/tc/core/mapping_options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class SchedulerOptionsView {
192192
/// Construct a view that refers to a protocol buffers message.
193193
SchedulerOptionsView(const SchedulerOptionsView&) = default;
194194
SchedulerOptionsView(SchedulerOptionsProto& buf) : proto(buf) {}
195+
/// Construct a view from an rvalue protocol buffers message.
/// NOTE(review): `buf` is a named parameter, so `proto(buf)` uses it as an
/// lvalue. If `proto` is a reference member (as the "view that refers to"
/// wording on the lvalue constructor suggests), the view dangles once the
/// temporary argument is destroyed at the end of the caller's full-expression.
/// TODO confirm, and avoid storing a view built this way.
SchedulerOptionsView(SchedulerOptionsProto&& buf) : proto(buf) {}
195196

196197
/// Assign the values from another view.
197198
inline SchedulerOptionsView& operator=(const SchedulerOptionsView&);

include/tc/core/polyhedral/codegen_llvm.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,29 @@
2020

2121
#include "llvm/IR/LLVMContext.h"
2222
#include "llvm/IR/Module.h"
23+
#include "llvm/Support/raw_ostream.h"
2324
#include "llvm/Target/TargetMachine.h"
2425

2526
#include "Halide.h"
2627

2728
namespace tc {
29+
30+
/// Returns the textual form of an LLVM value, obtained by printing it into a
/// string-backed stream. `llvmObject` is dereferenced unconditionally, so the
/// caller must pass a non-null pointer.
static inline std::string toString(llvm::Value* llvmObject) {
31+
std::string output;
32+
llvm::raw_string_ostream rso(output);
33+
llvmObject->print(rso);
34+
// The return value is deliberately unused: per LLVM's raw_string_ostream
// documentation, str() flushes pending output into `output`.
rso.str();
35+
return output;
36+
}
37+
38+
/// Returns the textual form of an LLVM module, obtained by printing it into a
/// string-backed stream. `llvmObject` is dereferenced unconditionally, so the
/// caller must pass a non-null pointer.
static inline std::string toString(llvm::Module* llvmObject) {
39+
std::string output;
40+
llvm::raw_string_ostream rso(output);
41+
// NOTE(review): the trailing arguments presumably map onto LLVM's
// Module::print(OS, AAW, ShouldPreserveUseListOrder, IsForDebug), i.e. no
// annotation writer, use-list order not preserved, IsForDebug = true.
// Confirm against the LLVM version in use.
llvmObject->print(rso, nullptr, false, true);
42+
// The return value is deliberately unused: per LLVM's raw_string_ostream
// documentation, str() flushes pending output into `output`.
rso.str();
43+
return output;
44+
}
45+
2846
namespace polyhedral {
2947
struct Scop;
3048

include/tc/core/polyhedral/llvm_jit.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@ class Jit {
3838
public:
3939
Jit();
4040

41-
void codegenScop(
41+
using ModuleHandle = decltype(compileLayer_)::ModuleHandleT;
42+
std::shared_ptr<llvm::Module> codegenScop(
4243
const std::string& specializedName,
4344
const polyhedral::Scop& scop);
44-
45-
using ModuleHandle = decltype(compileLayer_)::ModuleHandleT;
46-
ModuleHandle addModule(std::unique_ptr<llvm::Module> M);
45+
ModuleHandle addModule(std::shared_ptr<llvm::Module> M);
4746
void removeModule(ModuleHandle H);
4847

4948
llvm::JITSymbol findSymbol(const std::string name);

include/tc/core/polyhedral/scop.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,12 @@ struct Scop {
340340
// Create a Scop scheduled with a given scheduling strategy.
341341
static std::unique_ptr<Scop> makeScheduled(
342342
const Scop& scop,
343-
const SchedulerOptionsView& schedulerOptions);
343+
const SchedulerOptionsView&& schedulerOptions);
344344

345+
// Create a Scop scheduled with a given scheduling strategy.
346+
static std::unique_ptr<Scop> makeScheduled(
347+
const Scop& scop,
348+
const SchedulerOptionsView& schedulerOptions);
345349
// Tile the outermost band.
346350
// Splits the band into tile loop band and point loop band where point loops
347351
// have fixed trip counts specified in "tiling", and returns a pointer to the

src/core/polyhedral/codegen_llvm.cc

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,6 @@ using namespace Halide;
5555

5656
namespace tc {
5757

58-
namespace {
59-
template <typename T>
60-
std::string toString(T* llvmObject) {
61-
std::string output;
62-
llvm::raw_string_ostream rso(output);
63-
llvmObject->print(rso, nullptr, false, true);
64-
rso.str();
65-
return output;
66-
}
67-
} // namespace
68-
6958
namespace halide2isl {
7059
isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
7160
}
@@ -217,6 +206,9 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
217206
using CodeGen_X86::sym_push;
218207

219208
void init_module() override {
209+
const char* llvm_args[] = {"tc (LLVM argument parsing)", nullptr};
210+
llvm::cl::ParseCommandLineOptions(
211+
sizeof(llvm_args) / sizeof(*llvm_args) - 1, llvm_args);
220212
init_context();
221213
module =
222214
llvm::make_unique<llvm::Module>("TensorComprehensionsModule", *context);
@@ -311,14 +303,13 @@ class CodeGen_TC : public Halide::Internal::CodeGen_X86 {
311303
functionPassManager.doInitialization();
312304
for (llvm::Module::iterator i = module->begin(); i != module->end(); i++) {
313305
functionPassManager.run(*i);
306+
}
314307

315-
functionPassManager.doFinalization();
316-
modulePassManager.run(*module);
308+
functionPassManager.doFinalization();
309+
modulePassManager.run(*module);
317310

318-
LOG_IF(INFO, FLAGS_llvm_dump_after_opt)
319-
<< "[LLVM-IR] After optimization:\n"
320-
<< toString(module.get());
321-
}
311+
LOG_IF(INFO, FLAGS_llvm_dump_after_opt) << "[LLVM-IR] After optimization:\n"
312+
<< toString(module.get());
322313
}
323314
};
324315

@@ -492,8 +483,7 @@ class LLVMCodegen {
492483

493484
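// TODO: integrate query ISL as to whether the relevant loop ought to be
494485
// parallelized. NOTE(review): isl_ast_node_for_is_coincident returns an isl_bool, whose error value (-1) converts to `true` when assigned to a bool; confirm and consider comparing against isl_bool_true.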
495-
bool parallel = false;
496-
486+
bool parallel = isl_ast_node_for_is_coincident(node.get());
497487
llvm::Value* SyncRegion = nullptr;
498488

499489
#ifdef TAPIR_VERSION_MAJOR

src/core/polyhedral/llvm_jit.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,20 @@ Jit::Jit()
8282
}
8383
}
8484

85-
void Jit::codegenScop(
85+
// Emits the LLVM module for `scop` under the name `specializedName` using the
// target machine's data layout, hands that module to addModule, and returns
// the same shared module to the caller.
// (The lines under `-` markers are the pre-change body from the diff; the
// lines under `+` markers are the current body.)
std::shared_ptr<Module> Jit::codegenScop(
8686
const std::string& specializedName,
8787
const polyhedral::Scop& scop) {
88-
addModule(emitLLVMKernel(
89-
specializedName, scop, getTargetMachine().createDataLayout()));
88+
// Keep the module in a named shared_ptr so that it can both be handed to
// addModule and returned.
std::shared_ptr<Module> mod = emitLLVMKernel(
89+
specializedName, scop, getTargetMachine().createDataLayout());
90+
addModule(mod);
91+
return mod;
9092
}
9193

9294
TargetMachine& Jit::getTargetMachine() {
9395
return *TM_;
9496
}
9597

96-
Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
98+
Jit::ModuleHandle Jit::addModule(std::shared_ptr<Module> M) {
9799
M->setTargetTriple(TM_->getTargetTriple().str());
98100
auto Resolver = orc::createLambdaResolver(
99101
[&](const std::string& Name) {
@@ -107,7 +109,7 @@ Jit::ModuleHandle Jit::addModule(std::unique_ptr<Module> M) {
107109
return JITSymbol(nullptr);
108110
});
109111

110-
auto res = compileLayer_.addModule(std::move(M), std::move(Resolver));
112+
auto res = compileLayer_.addModule(M, std::move(Resolver));
111113
CHECK(res) << "Failed to jit compile.";
112114
return *res;
113115
}

src/core/polyhedral/scop.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,17 @@ std::unique_ptr<Scop> Scop::makeScheduled(
463463
return s;
464464
}
465465

466+
// Overload of makeScheduled taking the scheduler options by const rvalue
// reference. Creates a Scop `s` from `scop` via makeScop, derives schedule
// constraints from `s` and the options, stores the computed schedule in
// s->scheduleTreeUPtr, logs it when FLAGS_debug_tc_mapper is set, and
// returns `s`.
// NOTE(review): a `const T&&` parameter cannot be moved from, and
// `schedulerOptions` is only used as an lvalue below. The
// `const SchedulerOptionsView&` overload declared in scop.h can already bind
// temporaries, so this overload looks redundant. Confirm before removing it
// or delegating to the other overload.
std::unique_ptr<Scop> Scop::makeScheduled(
467+
const Scop& scop,
468+
const SchedulerOptionsView&& schedulerOptions) {
469+
auto s = makeScop(scop);
470+
auto constraints = makeScheduleConstraints(*s, schedulerOptions);
471+
s->scheduleTreeUPtr = computeSchedule(constraints, schedulerOptions);
472+
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After scheduling:" << std::endl
473+
<< *s->scheduleTreeUPtr;
474+
return s;
475+
}
476+
466477
namespace {
467478

468479
/*

test/test_mapper_llvm.cc

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
5151
auto context = scop->makeContext(
5252
std::unordered_map<std::string, int>{{"N", N}, {"M", M}});
5353
scop = Scop::makeSpecializedScop(*scop, context);
54-
5554
Jit jit;
5655
jit.codegenScop("kernel_anon", *scop);
5756
auto fptr =
@@ -66,6 +65,176 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
6665
checkRtol(Cc - C, {A, B}, N * M);
6766
}
6867

68+
TEST(LLVMCodegen, BasicParallel) {
69+
string tc = R"TC(
70+
def fun(float(N, M) A, float(N, M) B) -> (C) {
71+
C(i, j) = A(i, j) + B(i, j)
72+
}
73+
)TC";
74+
auto N = 40;
75+
auto M = 24;
76+
77+
auto ctx = isl::with_exceptions::globalIslCtx();
78+
auto scop = polyhedral::Scop::makeScop(ctx, tc);
79+
auto context = scop->makeContext(
80+
std::unordered_map<std::string, int>{{"N", N}, {"M", M}});
81+
scop = Scop::makeSpecializedScop(*scop, context);
82+
scop =
83+
Scop::makeScheduled(*scop, SchedulerOptionsView(SchedulerOptionsProto()));
84+
Jit jit;
85+
auto mod = jit.codegenScop("kernel_anon", *scop);
86+
auto correct_llvm = R"LLVM(
87+
; Function Attrs: nounwind
88+
define void @kernel_anon([24 x float]* noalias nocapture nonnull readonly %A, [24 x float]* noalias nocapture nonnull readonly %B, [24 x float]* noalias nocapture nonnull %C) local_unnamed_addr #0 {
89+
entry:
90+
%__cilkrts_sf = alloca %struct.__cilkrts_stack_frame, align 8
91+
%0 = call %struct.__cilkrts_worker* @__cilkrts_get_tls_worker() #0
92+
%1 = icmp eq %struct.__cilkrts_worker* %0, null
93+
br i1 %1, label %slowpath.i, label %__cilkrts_enter_frame_1.exit
94+
95+
slowpath.i: ; preds = %entry
96+
%2 = call %struct.__cilkrts_worker* @__cilkrts_bind_thread_1() #0
97+
br label %__cilkrts_enter_frame_1.exit
98+
99+
__cilkrts_enter_frame_1.exit: ; preds = %entry, %slowpath.i
100+
%.sink = phi i32 [ 16777344, %slowpath.i ], [ 16777216, %entry ]
101+
%3 = phi %struct.__cilkrts_worker* [ %2, %slowpath.i ], [ %0, %entry ]
102+
%4 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
103+
store volatile i32 %.sink, i32* %4, align 8
104+
%5 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %3, i64 0, i32 9
105+
%6 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %5, align 8
106+
%7 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 2
107+
store volatile %struct.__cilkrts_stack_frame* %6, %struct.__cilkrts_stack_frame** %7, align 8
108+
%8 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 3
109+
store volatile %struct.__cilkrts_worker* %3, %struct.__cilkrts_worker** %8, align 8
110+
store volatile %struct.__cilkrts_stack_frame* %__cilkrts_sf, %struct.__cilkrts_stack_frame** %5, align 8
111+
%9 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 5
112+
br label %loop_body
113+
114+
loop_body: ; preds = %loop_latch, %__cilkrts_enter_frame_1.exit
115+
%c09 = phi i64 [ 0, %__cilkrts_enter_frame_1.exit ], [ %23, %loop_latch ]
116+
%10 = bitcast [5 x i8*]* %9 to i8*
117+
%11 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
118+
%sunkaddr = getelementptr i8, i8* %11, i64 72
119+
%12 = bitcast i8* %sunkaddr to i32*
120+
%13 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
121+
%sunkaddr16 = getelementptr i8, i8* %13, i64 76
122+
%14 = bitcast i8* %sunkaddr16 to i16*
123+
call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* %12, i16* %14) #0
124+
%15 = call i8* @llvm.frameaddress(i32 0)
125+
%16 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
126+
%sunkaddr17 = getelementptr i8, i8* %16, i64 32
127+
%17 = bitcast i8* %sunkaddr17 to i8**
128+
store volatile i8* %15, i8** %17, align 8
129+
%18 = call i8* @llvm.stacksave()
130+
%19 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
131+
%sunkaddr18 = getelementptr i8, i8* %19, i64 48
132+
%20 = bitcast i8* %sunkaddr18 to i8**
133+
store volatile i8* %18, i8** %20, align 8
134+
%21 = call i32 @llvm.eh.sjlj.setjmp(i8* %10) #3
135+
%22 = icmp eq i32 %21, 0
136+
br i1 %22, label %loop_body.split, label %loop_latch
137+
138+
loop_body.split: ; preds = %loop_body
139+
call fastcc void @kernel_anon_loop_body2.cilk([24 x float]* %C, i64 %c09, [24 x float]* %B, [24 x float]* %A)
140+
br label %loop_latch
141+
142+
loop_latch: ; preds = %loop_body.split, %loop_body
143+
%23 = add nuw nsw i64 %c09, 1
144+
%exitcond = icmp eq i64 %23, 40
145+
br i1 %exitcond, label %loop_exit, label %loop_body
146+
147+
loop_exit: ; preds = %loop_latch
148+
%24 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
149+
%25 = load volatile i32, i32* %24, align 8
150+
%26 = and i32 %25, 2
151+
%27 = icmp eq i32 %26, 0
152+
br i1 %27, label %__cilk_sync.exit, label %cilk.sync.savestate.i
153+
154+
cilk.sync.savestate.i: ; preds = %loop_exit
155+
%28 = bitcast [5 x i8*]* %9 to i8*
156+
%29 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
157+
%sunkaddr19 = getelementptr i8, i8* %29, i64 16
158+
%30 = bitcast i8* %sunkaddr19 to %struct.__cilkrts_worker**
159+
%31 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %30, align 8
160+
%32 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
161+
%sunkaddr20 = getelementptr i8, i8* %32, i64 72
162+
%33 = bitcast i8* %sunkaddr20 to i32*
163+
%34 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
164+
%sunkaddr21 = getelementptr i8, i8* %34, i64 76
165+
%35 = bitcast i8* %sunkaddr21 to i16*
166+
call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* nonnull %33, i16* nonnull %35) #0
167+
%36 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
168+
%sunkaddr22 = getelementptr i8, i8* %36, i64 32
169+
%37 = bitcast i8* %sunkaddr22 to i8**
170+
store volatile i8* %15, i8** %37, align 8
171+
%38 = call i8* @llvm.stacksave()
172+
%39 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
173+
%sunkaddr23 = getelementptr i8, i8* %39, i64 48
174+
%40 = bitcast i8* %sunkaddr23 to i8**
175+
store volatile i8* %38, i8** %40, align 8
176+
%41 = call i32 @llvm.eh.sjlj.setjmp(i8* nonnull %28) #3
177+
%42 = icmp eq i32 %41, 0
178+
br i1 %42, label %cilk.sync.runtimecall.i, label %cilk.sync.excepting.i
179+
180+
cilk.sync.runtimecall.i: ; preds = %cilk.sync.savestate.i
181+
call void @__cilkrts_sync(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
182+
br label %__cilk_sync.exit
183+
184+
cilk.sync.excepting.i: ; preds = %cilk.sync.savestate.i
185+
%43 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
186+
%44 = load volatile i32, i32* %43, align 8
187+
%45 = and i32 %44, 16
188+
%46 = icmp eq i32 %45, 0
189+
br i1 %46, label %__cilk_sync.exit, label %cilk.sync.rethrow.i
190+
191+
cilk.sync.rethrow.i: ; preds = %cilk.sync.excepting.i
192+
call void @__cilkrts_rethrow(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #4
193+
unreachable
194+
195+
__cilk_sync.exit: ; preds = %loop_exit, %cilk.sync.runtimecall.i, %cilk.sync.excepting.i
196+
%47 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
197+
%48 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
198+
%sunkaddr24 = getelementptr i8, i8* %48, i64 16
199+
%49 = bitcast i8* %sunkaddr24 to %struct.__cilkrts_worker**
200+
%50 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
201+
%51 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %50, i64 0, i32 12, i32 0
202+
%52 = load i64, i64* %51, align 8
203+
%53 = add i64 %52, 1
204+
store i64 %53, i64* %51, align 8
205+
%54 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
206+
%55 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
207+
%sunkaddr25 = getelementptr i8, i8* %55, i64 8
208+
%56 = bitcast i8* %sunkaddr25 to %struct.__cilkrts_stack_frame**
209+
%57 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %56, align 8
210+
%58 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %54, i64 0, i32 9
211+
store volatile %struct.__cilkrts_stack_frame* %57, %struct.__cilkrts_stack_frame** %58, align 8
212+
store volatile %struct.__cilkrts_stack_frame* null, %struct.__cilkrts_stack_frame** %56, align 8
213+
%59 = load volatile i32, i32* %47, align 8
214+
%60 = icmp eq i32 %59, 16777216
215+
br i1 %60, label %__cilk_parent_epilogue.exit, label %body.i
216+
217+
body.i: ; preds = %__cilk_sync.exit
218+
call void @__cilkrts_leave_frame(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
219+
br label %__cilk_parent_epilogue.exit
220+
221+
__cilk_parent_epilogue.exit: ; preds = %__cilk_sync.exit, %body.i
222+
ret void
223+
}
224+
)LLVM";
225+
EXPECT_EQ(correct_llvm, toString(mod->getFunction("kernel_anon")));
226+
auto fptr =
227+
(void (*)(float*, float*, float*))jit.getSymbolAddress("kernel_anon");
228+
229+
at::Tensor A = at::CPU(at::kFloat).rand({N, M});
230+
at::Tensor B = at::CPU(at::kFloat).rand({N, M});
231+
at::Tensor C = at::CPU(at::kFloat).rand({N, M});
232+
at::Tensor Cc = A + B;
233+
fptr(A.data<float>(), B.data<float>(), C.data<float>());
234+
235+
checkRtol(Cc - C, {A, B}, N * M);
236+
}
237+
69238
TEST(LLVMCodegen, DISABLED_BasicExecutionEngine) {
70239
string tc = R"TC(
71240
def fun(float(N, M) A, float(N, M) B) -> (C) {

0 commit comments

Comments
 (0)