@@ -51,7 +51,6 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
51
51
auto context = scop->makeContext (
52
52
std::unordered_map<std::string, int >{{" N" , N}, {" M" , M}});
53
53
scop = Scop::makeSpecializedScop (*scop, context);
54
-
55
54
Jit jit;
56
55
jit.codegenScop (" kernel_anon" , *scop);
57
56
auto fptr =
@@ -66,6 +65,176 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
66
65
checkRtol (Cc - C, {A, B}, N * M);
67
66
}
68
67
68
+ TEST (LLVMCodegen, BasicParallel) {
69
+ string tc = R"TC(
70
+ def fun(float(N, M) A, float(N, M) B) -> (C) {
71
+ C(i, j) = A(i, j) + B(i, j)
72
+ }
73
+ )TC" ;
74
+ auto N = 40 ;
75
+ auto M = 24 ;
76
+
77
+ auto ctx = isl::with_exceptions::globalIslCtx ();
78
+ auto scop = polyhedral::Scop::makeScop (ctx, tc);
79
+ auto context = scop->makeContext (
80
+ std::unordered_map<std::string, int >{{" N" , N}, {" M" , M}});
81
+ scop = Scop::makeSpecializedScop (*scop, context);
82
+ scop =
83
+ Scop::makeScheduled (*scop, SchedulerOptionsView (SchedulerOptionsProto ()));
84
+ Jit jit;
85
+ auto mod = jit.codegenScop (" kernel_anon" , *scop);
86
+ auto correct_llvm = R"LLVM(
87
+ ; Function Attrs: nounwind
88
+ define void @kernel_anon([24 x float]* noalias nocapture nonnull readonly %A, [24 x float]* noalias nocapture nonnull readonly %B, [24 x float]* noalias nocapture nonnull %C) local_unnamed_addr #0 {
89
+ entry:
90
+ %__cilkrts_sf = alloca %struct.__cilkrts_stack_frame, align 8
91
+ %0 = call %struct.__cilkrts_worker* @__cilkrts_get_tls_worker() #0
92
+ %1 = icmp eq %struct.__cilkrts_worker* %0, null
93
+ br i1 %1, label %slowpath.i, label %__cilkrts_enter_frame_1.exit
94
+
95
+ slowpath.i: ; preds = %entry
96
+ %2 = call %struct.__cilkrts_worker* @__cilkrts_bind_thread_1() #0
97
+ br label %__cilkrts_enter_frame_1.exit
98
+
99
+ __cilkrts_enter_frame_1.exit: ; preds = %entry, %slowpath.i
100
+ %.sink = phi i32 [ 16777344, %slowpath.i ], [ 16777216, %entry ]
101
+ %3 = phi %struct.__cilkrts_worker* [ %2, %slowpath.i ], [ %0, %entry ]
102
+ %4 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
103
+ store volatile i32 %.sink, i32* %4, align 8
104
+ %5 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %3, i64 0, i32 9
105
+ %6 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %5, align 8
106
+ %7 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 2
107
+ store volatile %struct.__cilkrts_stack_frame* %6, %struct.__cilkrts_stack_frame** %7, align 8
108
+ %8 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 3
109
+ store volatile %struct.__cilkrts_worker* %3, %struct.__cilkrts_worker** %8, align 8
110
+ store volatile %struct.__cilkrts_stack_frame* %__cilkrts_sf, %struct.__cilkrts_stack_frame** %5, align 8
111
+ %9 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 5
112
+ br label %loop_body
113
+
114
+ loop_body: ; preds = %loop_latch, %__cilkrts_enter_frame_1.exit
115
+ %c09 = phi i64 [ 0, %__cilkrts_enter_frame_1.exit ], [ %23, %loop_latch ]
116
+ %10 = bitcast [5 x i8*]* %9 to i8*
117
+ %11 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
118
+ %sunkaddr = getelementptr i8, i8* %11, i64 72
119
+ %12 = bitcast i8* %sunkaddr to i32*
120
+ %13 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
121
+ %sunkaddr16 = getelementptr i8, i8* %13, i64 76
122
+ %14 = bitcast i8* %sunkaddr16 to i16*
123
+ call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* %12, i16* %14) #0
124
+ %15 = call i8* @llvm.frameaddress(i32 0)
125
+ %16 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
126
+ %sunkaddr17 = getelementptr i8, i8* %16, i64 32
127
+ %17 = bitcast i8* %sunkaddr17 to i8**
128
+ store volatile i8* %15, i8** %17, align 8
129
+ %18 = call i8* @llvm.stacksave()
130
+ %19 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
131
+ %sunkaddr18 = getelementptr i8, i8* %19, i64 48
132
+ %20 = bitcast i8* %sunkaddr18 to i8**
133
+ store volatile i8* %18, i8** %20, align 8
134
+ %21 = call i32 @llvm.eh.sjlj.setjmp(i8* %10) #3
135
+ %22 = icmp eq i32 %21, 0
136
+ br i1 %22, label %loop_body.split, label %loop_latch
137
+
138
+ loop_body.split: ; preds = %loop_body
139
+ call fastcc void @kernel_anon_loop_body2.cilk([24 x float]* %C, i64 %c09, [24 x float]* %B, [24 x float]* %A)
140
+ br label %loop_latch
141
+
142
+ loop_latch: ; preds = %loop_body.split, %loop_body
143
+ %23 = add nuw nsw i64 %c09, 1
144
+ %exitcond = icmp eq i64 %23, 40
145
+ br i1 %exitcond, label %loop_exit, label %loop_body
146
+
147
+ loop_exit: ; preds = %loop_latch
148
+ %24 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
149
+ %25 = load volatile i32, i32* %24, align 8
150
+ %26 = and i32 %25, 2
151
+ %27 = icmp eq i32 %26, 0
152
+ br i1 %27, label %__cilk_sync.exit, label %cilk.sync.savestate.i
153
+
154
+ cilk.sync.savestate.i: ; preds = %loop_exit
155
+ %28 = bitcast [5 x i8*]* %9 to i8*
156
+ %29 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
157
+ %sunkaddr19 = getelementptr i8, i8* %29, i64 16
158
+ %30 = bitcast i8* %sunkaddr19 to %struct.__cilkrts_worker**
159
+ %31 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %30, align 8
160
+ %32 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
161
+ %sunkaddr20 = getelementptr i8, i8* %32, i64 72
162
+ %33 = bitcast i8* %sunkaddr20 to i32*
163
+ %34 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
164
+ %sunkaddr21 = getelementptr i8, i8* %34, i64 76
165
+ %35 = bitcast i8* %sunkaddr21 to i16*
166
+ call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* nonnull %33, i16* nonnull %35) #0
167
+ %36 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
168
+ %sunkaddr22 = getelementptr i8, i8* %36, i64 32
169
+ %37 = bitcast i8* %sunkaddr22 to i8**
170
+ store volatile i8* %15, i8** %37, align 8
171
+ %38 = call i8* @llvm.stacksave()
172
+ %39 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
173
+ %sunkaddr23 = getelementptr i8, i8* %39, i64 48
174
+ %40 = bitcast i8* %sunkaddr23 to i8**
175
+ store volatile i8* %38, i8** %40, align 8
176
+ %41 = call i32 @llvm.eh.sjlj.setjmp(i8* nonnull %28) #3
177
+ %42 = icmp eq i32 %41, 0
178
+ br i1 %42, label %cilk.sync.runtimecall.i, label %cilk.sync.excepting.i
179
+
180
+ cilk.sync.runtimecall.i: ; preds = %cilk.sync.savestate.i
181
+ call void @__cilkrts_sync(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
182
+ br label %__cilk_sync.exit
183
+
184
+ cilk.sync.excepting.i: ; preds = %cilk.sync.savestate.i
185
+ %43 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
186
+ %44 = load volatile i32, i32* %43, align 8
187
+ %45 = and i32 %44, 16
188
+ %46 = icmp eq i32 %45, 0
189
+ br i1 %46, label %__cilk_sync.exit, label %cilk.sync.rethrow.i
190
+
191
+ cilk.sync.rethrow.i: ; preds = %cilk.sync.excepting.i
192
+ call void @__cilkrts_rethrow(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #4
193
+ unreachable
194
+
195
+ __cilk_sync.exit: ; preds = %loop_exit, %cilk.sync.runtimecall.i, %cilk.sync.excepting.i
196
+ %47 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
197
+ %48 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
198
+ %sunkaddr24 = getelementptr i8, i8* %48, i64 16
199
+ %49 = bitcast i8* %sunkaddr24 to %struct.__cilkrts_worker**
200
+ %50 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
201
+ %51 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %50, i64 0, i32 12, i32 0
202
+ %52 = load i64, i64* %51, align 8
203
+ %53 = add i64 %52, 1
204
+ store i64 %53, i64* %51, align 8
205
+ %54 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
206
+ %55 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
207
+ %sunkaddr25 = getelementptr i8, i8* %55, i64 8
208
+ %56 = bitcast i8* %sunkaddr25 to %struct.__cilkrts_stack_frame**
209
+ %57 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %56, align 8
210
+ %58 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %54, i64 0, i32 9
211
+ store volatile %struct.__cilkrts_stack_frame* %57, %struct.__cilkrts_stack_frame** %58, align 8
212
+ store volatile %struct.__cilkrts_stack_frame* null, %struct.__cilkrts_stack_frame** %56, align 8
213
+ %59 = load volatile i32, i32* %47, align 8
214
+ %60 = icmp eq i32 %59, 16777216
215
+ br i1 %60, label %__cilk_parent_epilogue.exit, label %body.i
216
+
217
+ body.i: ; preds = %__cilk_sync.exit
218
+ call void @__cilkrts_leave_frame(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
219
+ br label %__cilk_parent_epilogue.exit
220
+
221
+ __cilk_parent_epilogue.exit: ; preds = %__cilk_sync.exit, %body.i
222
+ ret void
223
+ }
224
+ )LLVM" ;
225
+ EXPECT_EQ (correct_llvm, toString (mod->getFunction (" kernel_anon" )));
226
+ auto fptr =
227
+ (void (*)(float *, float *, float *))jit.getSymbolAddress (" kernel_anon" );
228
+
229
+ at::Tensor A = at::CPU (at::kFloat ).rand ({N, M});
230
+ at::Tensor B = at::CPU (at::kFloat ).rand ({N, M});
231
+ at::Tensor C = at::CPU (at::kFloat ).rand ({N, M});
232
+ at::Tensor Cc = A + B;
233
+ fptr (A.data <float >(), B.data <float >(), C.data <float >());
234
+
235
+ checkRtol (Cc - C, {A, B}, N * M);
236
+ }
237
+
69
238
TEST (LLVMCodegen, DISABLED_BasicExecutionEngine) {
70
239
string tc = R"TC(
71
240
def fun(float(N, M) A, float(N, M) B) -> (C) {
0 commit comments