Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 6be0ca0

Browse files
committed
Resolve comments
1 parent f690087 commit 6be0ca0

File tree

4 files changed

+114
-172
lines changed

4 files changed

+114
-172
lines changed

build.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ else
1717
fi
1818
WITH_PYTHON_C2=${WITH_PYTHON_C2:=OFF}
1919
WITH_NNPACK=${WITH_NNPACK:=OFF}
20+
WITH_TAPIR=${WITH_TAPIR:=ON}
2021
PYTHON=${PYTHON:="`which python3`"}
2122
PROTOC=${PROTOC:="`which protoc`"}
2223
CORES=${CORES:=32}
@@ -401,6 +402,7 @@ function install_tc() {
401402
rm -rf *
402403
VERBOSE=${VERBOSE} ${CMAKE_VERSION} -DWITH_CAFFE2=${WITH_CAFFE2} \
403404
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
405+
-DWITH_TAPIR=${WITH_TAPIR} \
404406
-DPYTHON_EXECUTABLE=${PYTHON} \
405407
-DHALIDE_PREFIX=${INSTALL_PREFIX} \
406408
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \

test/CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ foreach(i ${CORE_TEST_FILES})
100100
target_link_libraries(${i} ${GOOGLE_LIBS} tc_core_cuda_no_sdk)
101101
endforeach()
102102

103-
104103
add_executable(test_mapper_llvm test_mapper_llvm.cc)
105104
add_test(test_mapper_llvm test_mapper_llvm)
106105
target_link_libraries(
@@ -112,6 +111,19 @@ target_link_libraries(
112111

113112
tc_core_cpu tc_lang)
114113

114+
# The Tapir mapper test is only built when the WITH_TAPIR option is enabled.
if (WITH_TAPIR)
  add_executable(test_mapper_tapir test_mapper_tapir.cc)
  add_test(test_mapper_tapir test_mapper_tapir)
  # Same link line as test_mapper_tapir's sources require: gtest, ATen, LLVM
  # and the TC CPU core / language libraries, in that order.
  target_link_libraries(
    test_mapper_tapir
    ${GOOGLE_LIBS} ${ATEN_LIBRARIES} -lLLVM
    tc_core_cpu tc_lang)
endif()
126+
115127
################################################################################
116128
# TensorComprehensions tests
117129
# No real need for NVCC if we only use NVRTC

test/test_mapper_llvm.cc

Lines changed: 1 addition & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -51,179 +51,9 @@ def fun(float(N, M) A, float(N, M) B) -> (C) {
5151
auto context = scop->makeContext(
5252
std::unordered_map<std::string, int>{{"N", N}, {"M", M}});
5353
scop = Scop::makeSpecializedScop(*scop, context);
54-
Jit jit;
55-
jit.codegenScop("kernel_anon", *scop);
56-
auto fptr =
57-
(void (*)(float*, float*, float*))jit.getSymbolAddress("kernel_anon");
58-
59-
at::Tensor A = at::CPU(at::kFloat).rand({N, M});
60-
at::Tensor B = at::CPU(at::kFloat).rand({N, M});
61-
at::Tensor C = at::CPU(at::kFloat).rand({N, M});
62-
at::Tensor Cc = A + B;
63-
fptr(A.data<float>(), B.data<float>(), C.data<float>());
6454

65-
checkRtol(Cc - C, {A, B}, N * M);
66-
}
67-
68-
TEST(LLVMCodegen, BasicParallel) {
69-
string tc = R"TC(
70-
def fun(float(N, M) A, float(N, M) B) -> (C) {
71-
C(n, m) = A(n, m) + B(n, m)
72-
}
73-
)TC";
74-
auto N = 40;
75-
auto M = 24;
76-
77-
auto ctx = isl::with_exceptions::globalIslCtx();
78-
auto scop = polyhedral::Scop::makeScop(ctx, tc);
79-
auto context = scop->makeContext(
80-
std::unordered_map<std::string, int>{{"N", N}, {"M", M}});
81-
scop = Scop::makeSpecializedScop(*scop, context);
82-
SchedulerOptionsProto sop;
83-
SchedulerOptionsView sov(sop);
84-
scop = Scop::makeScheduled(*scop, sov);
8555
Jit jit;
86-
auto mod = jit.codegenScop("kernel_anon", *scop);
87-
auto correct_llvm = R"LLVM(
88-
; Function Attrs: nounwind
89-
define void @kernel_anon([24 x float]* noalias nocapture nonnull readonly %A, [24 x float]* noalias nocapture nonnull readonly %B, [24 x float]* noalias nocapture nonnull %C) local_unnamed_addr #0 {
90-
entry:
91-
%__cilkrts_sf = alloca %struct.__cilkrts_stack_frame, align 8
92-
%0 = call %struct.__cilkrts_worker* @__cilkrts_get_tls_worker() #0
93-
%1 = icmp eq %struct.__cilkrts_worker* %0, null
94-
br i1 %1, label %slowpath.i, label %__cilkrts_enter_frame_1.exit
95-
96-
slowpath.i: ; preds = %entry
97-
%2 = call %struct.__cilkrts_worker* @__cilkrts_bind_thread_1() #0
98-
br label %__cilkrts_enter_frame_1.exit
99-
100-
__cilkrts_enter_frame_1.exit: ; preds = %entry, %slowpath.i
101-
%.sink = phi i32 [ 16777344, %slowpath.i ], [ 16777216, %entry ]
102-
%3 = phi %struct.__cilkrts_worker* [ %2, %slowpath.i ], [ %0, %entry ]
103-
%4 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
104-
store volatile i32 %.sink, i32* %4, align 8
105-
%5 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %3, i64 0, i32 9
106-
%6 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %5, align 8
107-
%7 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 2
108-
store volatile %struct.__cilkrts_stack_frame* %6, %struct.__cilkrts_stack_frame** %7, align 8
109-
%8 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 3
110-
store volatile %struct.__cilkrts_worker* %3, %struct.__cilkrts_worker** %8, align 8
111-
store volatile %struct.__cilkrts_stack_frame* %__cilkrts_sf, %struct.__cilkrts_stack_frame** %5, align 8
112-
%9 = getelementptr inbounds %struct.__cilkrts_stack_frame, %struct.__cilkrts_stack_frame* %__cilkrts_sf, i64 0, i32 5
113-
br label %loop_body
114-
115-
loop_body: ; preds = %loop_latch, %__cilkrts_enter_frame_1.exit
116-
%c09 = phi i64 [ 0, %__cilkrts_enter_frame_1.exit ], [ %23, %loop_latch ]
117-
%10 = bitcast [5 x i8*]* %9 to i8*
118-
%11 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
119-
%sunkaddr = getelementptr i8, i8* %11, i64 72
120-
%12 = bitcast i8* %sunkaddr to i32*
121-
%13 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
122-
%sunkaddr16 = getelementptr i8, i8* %13, i64 76
123-
%14 = bitcast i8* %sunkaddr16 to i16*
124-
call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* %12, i16* %14) #0
125-
%15 = call i8* @llvm.frameaddress(i32 0)
126-
%16 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
127-
%sunkaddr17 = getelementptr i8, i8* %16, i64 32
128-
%17 = bitcast i8* %sunkaddr17 to i8**
129-
store volatile i8* %15, i8** %17, align 8
130-
%18 = call i8* @llvm.stacksave()
131-
%19 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
132-
%sunkaddr18 = getelementptr i8, i8* %19, i64 48
133-
%20 = bitcast i8* %sunkaddr18 to i8**
134-
store volatile i8* %18, i8** %20, align 8
135-
%21 = call i32 @llvm.eh.sjlj.setjmp(i8* %10) #3
136-
%22 = icmp eq i32 %21, 0
137-
br i1 %22, label %loop_body.split, label %loop_latch
138-
139-
loop_body.split: ; preds = %loop_body
140-
call fastcc void @kernel_anon_loop_body2.cilk([24 x float]* %C, i64 %c09, [24 x float]* %B, [24 x float]* %A)
141-
br label %loop_latch
142-
143-
loop_latch: ; preds = %loop_body.split, %loop_body
144-
%23 = add nuw nsw i64 %c09, 1
145-
%exitcond = icmp eq i64 %23, 40
146-
br i1 %exitcond, label %loop_exit, label %loop_body
147-
148-
loop_exit: ; preds = %loop_latch
149-
%24 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
150-
%25 = load volatile i32, i32* %24, align 8
151-
%26 = and i32 %25, 2
152-
%27 = icmp eq i32 %26, 0
153-
br i1 %27, label %__cilk_sync.exit, label %cilk.sync.savestate.i
154-
155-
cilk.sync.savestate.i: ; preds = %loop_exit
156-
%28 = bitcast [5 x i8*]* %9 to i8*
157-
%29 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
158-
%sunkaddr19 = getelementptr i8, i8* %29, i64 16
159-
%30 = bitcast i8* %sunkaddr19 to %struct.__cilkrts_worker**
160-
%31 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %30, align 8
161-
%32 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
162-
%sunkaddr20 = getelementptr i8, i8* %32, i64 72
163-
%33 = bitcast i8* %sunkaddr20 to i32*
164-
%34 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
165-
%sunkaddr21 = getelementptr i8, i8* %34, i64 76
166-
%35 = bitcast i8* %sunkaddr21 to i16*
167-
call void asm sideeffect "stmxcsr $0\0A\09fnstcw $1", "*m,*m,~{dirflag},~{fpsr},~{flags}"(i32* nonnull %33, i16* nonnull %35) #0
168-
%36 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
169-
%sunkaddr22 = getelementptr i8, i8* %36, i64 32
170-
%37 = bitcast i8* %sunkaddr22 to i8**
171-
store volatile i8* %15, i8** %37, align 8
172-
%38 = call i8* @llvm.stacksave()
173-
%39 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
174-
%sunkaddr23 = getelementptr i8, i8* %39, i64 48
175-
%40 = bitcast i8* %sunkaddr23 to i8**
176-
store volatile i8* %38, i8** %40, align 8
177-
%41 = call i32 @llvm.eh.sjlj.setjmp(i8* nonnull %28) #3
178-
%42 = icmp eq i32 %41, 0
179-
br i1 %42, label %cilk.sync.runtimecall.i, label %cilk.sync.excepting.i
180-
181-
cilk.sync.runtimecall.i: ; preds = %cilk.sync.savestate.i
182-
call void @__cilkrts_sync(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
183-
br label %__cilk_sync.exit
184-
185-
cilk.sync.excepting.i: ; preds = %cilk.sync.savestate.i
186-
%43 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
187-
%44 = load volatile i32, i32* %43, align 8
188-
%45 = and i32 %44, 16
189-
%46 = icmp eq i32 %45, 0
190-
br i1 %46, label %__cilk_sync.exit, label %cilk.sync.rethrow.i
191-
192-
cilk.sync.rethrow.i: ; preds = %cilk.sync.excepting.i
193-
call void @__cilkrts_rethrow(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #4
194-
unreachable
195-
196-
__cilk_sync.exit: ; preds = %loop_exit, %cilk.sync.runtimecall.i, %cilk.sync.excepting.i
197-
%47 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i32*
198-
%48 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
199-
%sunkaddr24 = getelementptr i8, i8* %48, i64 16
200-
%49 = bitcast i8* %sunkaddr24 to %struct.__cilkrts_worker**
201-
%50 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
202-
%51 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %50, i64 0, i32 12, i32 0
203-
%52 = load i64, i64* %51, align 8
204-
%53 = add i64 %52, 1
205-
store i64 %53, i64* %51, align 8
206-
%54 = load volatile %struct.__cilkrts_worker*, %struct.__cilkrts_worker** %49, align 8
207-
%55 = bitcast %struct.__cilkrts_stack_frame* %__cilkrts_sf to i8*
208-
%sunkaddr25 = getelementptr i8, i8* %55, i64 8
209-
%56 = bitcast i8* %sunkaddr25 to %struct.__cilkrts_stack_frame**
210-
%57 = load volatile %struct.__cilkrts_stack_frame*, %struct.__cilkrts_stack_frame** %56, align 8
211-
%58 = getelementptr inbounds %struct.__cilkrts_worker, %struct.__cilkrts_worker* %54, i64 0, i32 9
212-
store volatile %struct.__cilkrts_stack_frame* %57, %struct.__cilkrts_stack_frame** %58, align 8
213-
store volatile %struct.__cilkrts_stack_frame* null, %struct.__cilkrts_stack_frame** %56, align 8
214-
%59 = load volatile i32, i32* %47, align 8
215-
%60 = icmp eq i32 %59, 16777216
216-
br i1 %60, label %__cilk_parent_epilogue.exit, label %body.i
217-
218-
body.i: ; preds = %__cilk_sync.exit
219-
call void @__cilkrts_leave_frame(%struct.__cilkrts_stack_frame* nonnull %__cilkrts_sf) #0
220-
br label %__cilk_parent_epilogue.exit
221-
222-
__cilk_parent_epilogue.exit: ; preds = %__cilk_sync.exit, %body.i
223-
ret void
224-
}
225-
)LLVM";
226-
EXPECT_EQ(correct_llvm, toString(mod->getFunction("kernel_anon")));
56+
jit.codegenScop("kernel_anon", *scop);
22757
auto fptr =
22858
(void (*)(float*, float*, float*))jit.getSymbolAddress("kernel_anon");
22959

test/test_mapper_tapir.cc

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <gflags/gflags.h>
18+
#include <glog/logging.h>
19+
#include <gtest/gtest.h>
20+
21+
#include <ATen/ATen.h>
22+
23+
#include <llvm/IR/InstIterator.h>
24+
#include <llvm/IR/Instructions.h>
25+
26+
#include "tc/aten/utils.h"
27+
#include "tc/core/cpu/cpu_tc_executor.h"
28+
#include "tc/core/execution_engine.h"
29+
#include "tc/core/mapping_options.h"
30+
#include "tc/core/polyhedral/codegen_llvm.h"
31+
#include "tc/core/polyhedral/llvm_jit.h"
32+
#include "tc/core/polyhedral/scop.h"
33+
#include "tc/core/scope_guard.h"
34+
35+
#include "test_harness_aten.h"
36+
37+
using namespace std;
38+
39+
using namespace tc;
40+
using namespace tc::polyhedral;
41+
using namespace tc::polyhedral::detail;
42+
43+
/**
 * Schedules an elementwise-add TC, JIT-compiles it and verifies two things:
 *   1. the generated "kernel_anon" function contains direct calls to the
 *      expected Cilk runtime entry points (structural check on the IR), and
 *   2. running the compiled kernel on random inputs yields A + B.
 */
TEST(TapirCodegen, BasicParallel) {
  string tc = R"TC(
def fun(float(N, M) A, float(N, M) B) -> (C) {
C(n, m) = A(n, m) + B(n, m)
}
)TC";
  auto N = 40;
  auto M = 24;

  // Build the polyhedral representation, fix the sizes N and M, and schedule
  // it with default-constructed scheduler options.
  auto ctx = isl::with_exceptions::globalIslCtx();
  auto scop = polyhedral::Scop::makeScop(ctx, tc);
  auto context = scop->makeContext(
      std::unordered_map<std::string, int>{{"N", N}, {"M", M}});
  scop = Scop::makeSpecializedScop(*scop, context);
  SchedulerOptionsProto sop;
  SchedulerOptionsView sov(sop);
  scop = Scop::makeScheduled(*scop, sov);

  Jit jit;
  auto mod = jit.codegenScop("kernel_anon", *scop);
  auto fn = mod->getFunction("kernel_anon");
  // Fail with a readable message instead of dereferencing a null function
  // in the instruction walk below.
  ASSERT_TRUE(fn != nullptr) << "kernel_anon missing from generated module";

  // Collect the names of all directly called functions in the kernel.
  std::set<string> calledFunctions;
  for (auto I = llvm::inst_begin(fn), E = llvm::inst_end(fn); I != E; ++I) {
    if (auto* call = llvm::dyn_cast<llvm::CallInst>(&*I)) {
      // getCalledFunction() is null for indirect calls; those have no name.
      if (auto* called = call->getCalledFunction()) {
        // Explicit .str(): StringRef does not convert to std::string
        // implicitly in every LLVM release.
        calledFunctions.insert(called->getName().str());
      }
    }
  }

  // Unsigned literal keeps the comparison with size_t free of sign-compare
  // warnings.
  ASSERT_NE(0u, calledFunctions.count("__cilkrts_get_tls_worker"));
  ASSERT_NE(0u, calledFunctions.count("__cilkrts_bind_thread_1"));
  ASSERT_NE(0u, calledFunctions.count("llvm.stacksave"));
  ASSERT_NE(0u, calledFunctions.count("__cilkrts_sync"));

  auto fptr =
      (void (*)(float*, float*, float*))jit.getSymbolAddress("kernel_anon");

  at::Tensor A = at::CPU(at::kFloat).rand({N, M});
  at::Tensor B = at::CPU(at::kFloat).rand({N, M});
  at::Tensor C = at::CPU(at::kFloat).rand({N, M});
  // Reference result computed by ATen.
  at::Tensor Cc = A + B;
  fptr(A.data<float>(), B.data<float>(), C.data<float>());

  checkRtol(Cc - C, {A, B}, N * M);
}
91+
92+
int main(int argc, char** argv) {
93+
::testing::InitGoogleTest(&argc, argv);
94+
::gflags::ParseCommandLineFlags(&argc, &argv, true);
95+
::google::InitGoogleLogging(argv[0]);
96+
initialize_llvm();
97+
return RUN_ALL_TESTS();
98+
}

0 commit comments

Comments
 (0)