@@ -51,179 +51,9 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
51
51
auto context = scop->makeContext (
52
52
std::unordered_map<std::string, int >{{" N" , N}, {" M" , M}});
53
53
scop = Scop::makeSpecializedScop (*scop, context);
54
- Jit jit;
55
- jit.codegenScop (" kernel_anon" , *scop);
56
- auto fptr =
57
- (void (*)(float *, float *, float *))jit.getSymbolAddress (" kernel_anon" );
58
-
59
- at::Tensor A = at::CPU (at::kFloat ).rand ({N, M});
60
- at::Tensor B = at::CPU (at::kFloat ).rand ({N, M});
61
- at::Tensor C = at::CPU (at::kFloat ).rand ({N, M});
62
- at::Tensor Cc = A + B;
63
- fptr (A.data <float >(), B.data <float >(), C.data <float >());
64
54
65
- checkRtol (Cc - C, {A, B}, N * M);
66
- }
67
-
68
- TEST (LLVMCodegen, BasicParallel) {
69
- string tc = R"TC(
70
- def fun(float(N, M) A, float(N, M) B) -> (C) {
71
- C(n, m) = A(n, m) + B(n, m)
72
- }
73
- )TC" ;
74
- auto N = 40 ;
75
- auto M = 24 ;
76
-
77
- auto ctx = isl::with_exceptions::globalIslCtx ();
78
- auto scop = polyhedral::Scop::makeScop (ctx, tc);
79
- auto context = scop->makeContext (
80
- std::unordered_map<std::string, int >{{" N" , N}, {" M" , M}});
81
- scop = Scop::makeSpecializedScop (*scop, context);
82
- SchedulerOptionsProto sop;
83
- SchedulerOptionsView sov (sop);
84
- scop = Scop::makeScheduled (*scop, sov);
85
55
Jit jit;
86
- auto mod = jit.codegenScop (" kernel_anon" , *scop);
87
- auto correct_llvm = R"LLVM(
88
- ; Function Attrs: nounwind
89
- define void @kernel_anon([24 x float]* noalias nocapture nonnull readonly %A, [24 x float]* noalias nocapture nonnull readonly %B, [24 x float]* noalias nocapture nonnull %C) local_unnamed_addr #0 {
90
- entry:
91
- %__cilkrts_sf = alloca %struct.__cilkrts_stack_frame, align 8
92
- %0 = call %struct.__cilkrts_worker* @__cilkrts_get_tls_worker() #0
93
- %1 = icmp eq %struct.__cilkrts_worker* %0, null
94
- br i1 %1, label %slowpath.i, label %__cilkrts_enter_frame_1.exit
95
-
96
- slowpath.i: ; preds = %entry
97
- %2 = call %struct.__cilkrts_worker* @__cilkrts_bind_thread_1() #0
98
- br label %__cilkrts_enter_frame_1.exit
99
-
100
- __cilkrts_enter_frame_1.exit: ; preds = %entry, %slowpath.i
101
- %.sink = phi i32 [ 16777344, %slowpath.i ], [ 16777216, %entry ]
102
- %3 = phi %struct.__cilkrts_worker* [ %2, %slowpath.i ], [ %0, %entry ]
103
- %4 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
104
- store volatile i32 %.sink, i32* %4, align 8
105
- %5 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %3, i64 0, i32 9
106
- %6 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %5, align 8
107
- %7 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 2
108
- store volatile %struct.__cilkrts_stack_frame* %6, %struct.__cilkrts_stack_frame** %7, align 8
109
- %8 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 3
110
- store volatile %struct.__cilkrts_worker* %3, %struct.__cilkrts_worker** %8, align 8
111
- store volatile %struct.__cilkrts_stack_frame* %__cilkrts_sf, %struct.__cilkrts_stack_frame** %5, align 8
112
- %9 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 5
113
- br label %loop_body
114
-
115
- loop_body: ; preds = %loop_latch, %__cilkrts_enter_frame_1.exit
116
- %c09 = phi i64 [ 0, %__cilkrts_enter_frame_1.exit ], [ %23, %loop_latch ]
117
- %10 = bitcast [5 x i8*]* %9 to i8*
118
- %11 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
119
- %sunkaddr = getelementptr i8, i8* %11, i64 72
120
- %12 = bitcast i8* %sunkaddr to i32*
121
- %13 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
122
- %sunkaddr16 = getelementptr i8, i8* %13, i64 76
123
- %14 = bitcast i8* %sunkaddr16 to i16*
124
- call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* %12, i16* %14) #0
125
- %15 = call i8* @llvm.frameaddress(i32 0)
126
- %16 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
127
- %sunkaddr17 = getelementptr i8, i8* %16, i64 32
128
- %17 = bitcast i8* %sunkaddr17 to i8**
129
- store volatile i8* %15, i8** %17, align 8
130
- %18 = call i8* @llvm.stacksave()
131
- %19 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
132
- %sunkaddr18 = getelementptr i8, i8* %19, i64 48
133
- %20 = bitcast i8* %sunkaddr18 to i8**
134
- store volatile i8* %18, i8** %20, align 8
135
- %21 = call i32 @llvm.eh.sjlj.setjmp(i8* %10) #3
136
- %22 = icmp eq i32 %21, 0
137
- br i1 %22, label %loop_body.split, label %loop_latch
138
-
139
- loop_body.split: ; preds = %loop_body
140
- call fastcc void @kernel_anon_loop_body2.cilk([24 x float]* %C, i64 %c09, [24 x float]* %B, [24 x float]* %A)
141
- br label %loop_latch
142
-
143
- loop_latch: ; preds = %loop_body.split, %loop_body
144
- %23 = add nuw nsw i64 %c09, 1
145
- %exitcond = icmp eq i64 %23, 40
146
- br i1 %exitcond, label %loop_exit, label %loop_body
147
-
148
- loop_exit: ; preds = %loop_latch
149
- %24 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
150
- %25 = load volatile i32, i32* %24, align 8
151
- %26 = and i32 %25, 2
152
- %27 = icmp eq i32 %26, 0
153
- br i1 %27, label %__cilk_sync.exit, label %cilk.sync.savestate.i
154
-
155
- cilk.sync.savestate.i: ; preds = %loop_exit
156
- %28 = bitcast [5 x i8*]* %9 to i8*
157
- %29 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
158
- %sunkaddr19 = getelementptr i8, i8* %29, i64 16
159
- %30 = bitcast i8* %sunkaddr19 to %struct.__cilkrts_worker**
160
- %31 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %30, align 8
161
- %32 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
162
- %sunkaddr20 = getelementptr i8, i8* %32, i64 72
163
- %33 = bitcast i8* %sunkaddr20 to i32*
164
- %34 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
165
- %sunkaddr21 = getelementptr i8, i8* %34, i64 76
166
- %35 = bitcast i8* %sunkaddr21 to i16*
167
- call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* nonnull %33, i16* nonnull %35) #0
168
- %36 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
169
- %sunkaddr22 = getelementptr i8, i8* %36, i64 32
170
- %37 = bitcast i8* %sunkaddr22 to i8**
171
- store volatile i8* %15, i8** %37, align 8
172
- %38 = call i8* @llvm.stacksave()
173
- %39 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
174
- %sunkaddr23 = getelementptr i8, i8* %39, i64 48
175
- %40 = bitcast i8* %sunkaddr23 to i8**
176
- store volatile i8* %38, i8** %40, align 8
177
- %41 = call i32 @llvm.eh.sjlj.setjmp(i8* nonnull %28) #3
178
- %42 = icmp eq i32 %41, 0
179
- br i1 %42, label %cilk.sync.runtimecall.i, label %cilk.sync.excepting.i
180
-
181
- cilk.sync.runtimecall.i: ; preds = %cilk.sync.savestate.i
182
- call void @__cilkrts_sync(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
183
- br label %__cilk_sync.exit
184
-
185
- cilk.sync.excepting.i: ; preds = %cilk.sync.savestate.i
186
- %43 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
187
- %44 = load volatile i32, i32* %43, align 8
188
- %45 = and i32 %44, 16
189
- %46 = icmp eq i32 %45, 0
190
- br i1 %46, label %__cilk_sync.exit, label %cilk.sync.rethrow.i
191
-
192
- cilk.sync.rethrow.i: ; preds = %cilk.sync.excepting.i
193
- call void @__cilkrts_rethrow(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #4
194
- unreachable
195
-
196
- __cilk_sync.exit: ; preds = %loop_exit, %cilk.sync.runtimecall.i, %cilk.sync.excepting.i
197
- %47 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
198
- %48 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
199
- %sunkaddr24 = getelementptr i8, i8* %48, i64 16
200
- %49 = bitcast i8* %sunkaddr24 to %struct.__cilkrts_worker**
201
- %50 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
202
- %51 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %50, i64 0, i32 12, i32 0
203
- %52 = load i64, i64* %51, align 8
204
- %53 = add i64 %52, 1
205
- store i64 %53, i64* %51, align 8
206
- %54 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
207
- %55 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
208
- %sunkaddr25 = getelementptr i8, i8* %55, i64 8
209
- %56 = bitcast i8* %sunkaddr25 to %struct.__cilkrts_stack_frame**
210
- %57 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %56, align 8
211
- %58 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %54, i64 0, i32 9
212
- store volatile %struct.__cilkrts_stack_frame* %57, %struct.__cilkrts_stack_frame** %58, align 8
213
- store volatile %struct.__cilkrts_stack_frame* null, %struct.__cilkrts_stack_frame** %56, align 8
214
- %59 = load volatile i32, i32* %47, align 8
215
- %60 = icmp eq i32 %59, 16777216
216
- br i1 %60, label %__cilk_parent_epilogue.exit, label %body.i
217
-
218
- body.i: ; preds = %__cilk_sync.exit
219
- call void @__cilkrts_leave_frame(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
220
- br label %__cilk_parent_epilogue.exit
221
-
222
- __cilk_parent_epilogue.exit: ; preds = %__cilk_sync.exit, %body.i
223
- ret void
224
- }
225
- )LLVM" ;
226
- EXPECT_EQ (correct_llvm, toString (mod->getFunction (" kernel_anon" )));
56
+ jit.codegenScop (" kernel_anon" , *scop);
227
57
auto fptr =
228
58
(void (*)(float *, float *, float *))jit.getSymbolAddress (" kernel_anon" );
229
59
0 commit comments