Commit 1a56ee6

[SYCLomatic] Add more rules for pytorch API migration (#2740)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent 980ac0e commit 1a56ee6

9 files changed, +95 -26 lines changed

clang/lib/DPCT/RulesInclude/InclusionHeaders.cpp

Lines changed: 21 additions & 12 deletions

@@ -8,6 +8,8 @@
 
 #include "InclusionHeaders.h"
 #include "PreProcessor.h"
+#include "UserDefinedRules/UserDefinedRules.h"
+#include <optional>
 
 namespace clang {
 namespace dpct {
@@ -33,11 +35,10 @@ class LastInclusionLocationUpdater {
   bool UpdateNeeded;
 };
 
-std::string applyUserDefinedHeader(const std::string &FileName) {
-  // Apply user-defined rule if needed
+std::optional<std::pair<std::string, RulePriority>>
+getUserDefinedHeader(const std::string &FileName) {
   auto It = MapNames::HeaderRuleMap.find(FileName);
-  if (It != MapNames::HeaderRuleMap.end() &&
-      It->second.Priority == RulePriority::Takeover) {
+  if (It != MapNames::HeaderRuleMap.end()) {
     auto &Rule = It->second;
     std::string ReplHeaderStr = Rule.Prefix;
     llvm::raw_string_ostream OS(ReplHeaderStr);
@@ -54,11 +55,12 @@ std::string applyUserDefinedHeader(const std::string &FileName) {
     for (auto &Header : Rule.Includes) {
       PrintHeader(Header);
     }
-    PrintHeader(Rule.Out);
+    if (!Rule.Out.empty())
+      PrintHeader(Rule.Out);
     OS << Rule.Postfix;
-    return ReplHeaderStr;
+    return std::make_pair(ReplHeaderStr, Rule.Priority);
   }
-  return "";
+  return std::nullopt;
 }
 
 void insertHeaders(std::shared_ptr<DpctFileInfo> File,
@@ -150,6 +152,15 @@ void IncludesCallbacks::InclusionDirective(
     Updater.give_up();
   };
 
+  // Apply user-defined rule if needed
+  auto UserDefinedInfo = getUserDefinedHeader(FileName.str());
+  if (UserDefinedInfo.has_value()) {
+    if (UserDefinedInfo.value().second == RulePriority::Takeover) {
+      EmplaceReplacement(std::move(UserDefinedInfo.value().first));
+      return;
+    }
+  }
+
   if (Global.isInAnalysisScope(IncludedFile)) {
     IncludeFileMap[IncludedFile] = false;
     Global.getIncludingFileSet().insert(IncludedFile);
@@ -199,6 +210,7 @@ void IncludesCallbacks::InclusionDirective(
           .getReplacement(DpctGlobalInfo::getContext());
       DpctGlobalInfo::getIncludeMapSet().push_back({IncludedFile, Repl});
     }
+    UserDefinedInfo.reset();
   }
   if (Global.isInRoot(IncludedFile))
     return;
@@ -208,11 +220,8 @@
       !Global.getSourceManager().isWrittenInMainFile(HashLoc))
     return;
 
-
-  // Apply user-defined rule if needed
-  if (auto ReplacedStr = applyUserDefinedHeader(FileName.str());
-      !ReplacedStr.empty()) {
-    EmplaceReplacement(std::move(ReplacedStr));
+  if (UserDefinedInfo.has_value()) {
+    EmplaceReplacement(std::move(UserDefinedInfo.value().first));
     return;
   }
 

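The refactor above changes the shape of the user-defined header lookup: instead of returning a replacement string only for Takeover rules, getUserDefinedHeader() now returns the replacement text together with the rule's priority, so the caller decides when to apply it. Takeover rules short-circuit the whole InclusionDirective callback up front, while any remaining rule is applied only after the built-in handling has had its chance. A minimal standalone sketch of that control flow (RulePriority's second enumerator, HeaderRule, HeaderRuleMap's contents, and processInclusionDirective are illustrative stand-ins, not the actual DPCT types):

#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Illustrative stand-ins for DPCT's RulePriority and HeaderRuleMap;
// "Fallback" is just a placeholder for any non-Takeover priority.
enum class RulePriority { Takeover, Fallback };

struct HeaderRule {
  std::string ReplacementText;
  RulePriority Priority;
};

std::map<std::string, HeaderRule> HeaderRuleMap = {
    // A Takeover header rule with empty output deletes the include.
    {"THC/THCAtomics.cuh", {"", RulePriority::Takeover}},
};

// Mirrors the new getUserDefinedHeader(): return the replacement text
// together with the rule's priority, or nullopt when no rule matches.
std::optional<std::pair<std::string, RulePriority>>
getUserDefinedHeader(const std::string &FileName) {
  auto It = HeaderRuleMap.find(FileName);
  if (It != HeaderRuleMap.end())
    return std::make_pair(It->second.ReplacementText, It->second.Priority);
  return std::nullopt;
}

void processInclusionDirective(const std::string &FileName) {
  auto UserDefinedInfo = getUserDefinedHeader(FileName);
  // Takeover rules short-circuit before any built-in handling.
  if (UserDefinedInfo.has_value() &&
      UserDefinedInfo->second == RulePriority::Takeover) {
    std::cout << FileName << " -> \"" << UserDefinedInfo->first << "\"\n";
    return;
  }
  // ... built-in migration handling runs here; a lower-priority rule is
  // applied only if the directive was not already claimed above ...
  if (UserDefinedInfo.has_value()) {
    std::cout << FileName << " -> \"" << UserDefinedInfo->first << "\"\n";
    return;
  }
  std::cout << FileName << " -> unchanged\n";
}

int main() {
  processInclusionDirective("THC/THCAtomics.cuh"); // rule hit: emptied
  processInclusionDirective("vector");             // no rule: unchanged
  return 0;
}
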
clang/test/dpct/pytorch/ATen.cu

Lines changed: 28 additions & 13 deletions

@@ -4,7 +4,7 @@
 // RUN: cp -r %S/pytorch_inc %T/pytorch/ATen/
 // RUN: cd %T/pytorch/ATen
 // RUN: mkdir dpct_out
-// RUN: dpct --out-root dpct_out %T/pytorch/ATen/src/ATen.cu --extra-arg="-I%T/pytorch/ATen/pytorch_inc" --cuda-include-path="%cuda-path/include" --rule-file=%S/../../../tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml --analysis-scope-path %T/pytorch/ATen/pytorch_inc --analysis-scope-path %T/pytorch/ATen/src --in-root %T/pytorch/ATen/src
+// RUN: dpct --format-range=none --out-root dpct_out %T/pytorch/ATen/src/ATen.cu --extra-arg="-I%T/pytorch/ATen/pytorch_inc" --cuda-include-path="%cuda-path/include" --rule-file=%S/../../../tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml --analysis-scope-path %T/pytorch/ATen/pytorch_inc --analysis-scope-path %T/pytorch/ATen/src --in-root %T/pytorch/ATen/src
 // RUN: FileCheck --input-file %T/pytorch/ATen/dpct_out/ATen.dp.cpp --match-full-lines %T/pytorch/ATen/src/ATen.cu
 
 // CHECK: #include <c10/xpu/XPUStream.h>
@@ -18,6 +18,18 @@
 // CHECK-NEXT: #include <c10/util/Half.h>
 #include <ATen/cuda/CUDATensorMethods.cuh>
 
+// CHECK: // BEGIN_1
+// CHECK-EMPTY:
+// CHECK-EMPTY:
+// CHECK-NEXT: // END_1
+// BEGIN_1
+#include <ATen/cuda/Exceptions.h>
+#include <THC/THCAtomics.cuh>
+// END_1
+
+// CHECK: #include <c10/xpu/XPUMacros.h>
+#include <c10/cuda/CUDAMacros.h>
+
 #define AT_CUDA_CHECK(stmt) (stmt)
 
 // CHECK: #define BE_AT_CHECK
@@ -31,20 +43,19 @@ void test_CUDAStream_as_arg() {
   dim3 blockSize(8, 8, 1);
   void *args[] = {nullptr};
 
-  // CHECK: ([&]() {
-  // CHECK-NEXT: ((sycl::queue *)(c10::xpu::getCurrentXPUStream()))
-  // CHECK-NEXT: ->parallel_for(sycl::nd_range<3>(gridSize * blockSize, blockSize),
-  // CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
-  // CHECK-NEXT: kernel();
-  // CHECK-NEXT: });
+  // CHECK: ([&](){
+  // CHECK-NEXT: ((sycl::queue*)(c10::xpu::getCurrentXPUStream()))->parallel_for(
+  // CHECK-NEXT: sycl::nd_range<3>(gridSize * blockSize, blockSize),
+  // CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
+  // CHECK-NEXT: kernel();
+  // CHECK-NEXT: });
   // CHECK-NEXT: return 0;
   // CHECK-NEXT: }());
   AT_CUDA_CHECK(cudaLaunchKernel((const void *)kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
 }
 
 int main() {
-  // CHECK: dpct::queue_ptr st =
-  // CHECK-NEXT: &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream());
+  // CHECK: dpct::queue_ptr st = &static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream());
   cudaStream_t st = 0;
 
   // stream APIs
@@ -55,14 +66,18 @@
   // CHECK: auto deviceStream = c10::xpu::getCurrentXPUStream(devInd);
   auto deviceStream = at::cuda::getCurrentCUDAStream(devInd);
 
-  // CHECK: dpct::queue_ptr curr_cuda_st =
-  // CHECK-NEXT: &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream().queue());
+  // CHECK: dpct::queue_ptr curr_cuda_st = &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream(). queue());
   cudaStream_t curr_cuda_st = at::cuda::getCurrentCUDAStream().stream();
-  // CHECK: dpct::queue_ptr dev_cuda_st = &static_cast<sycl::queue &>(
-  // CHECK-NEXT: c10::xpu::getCurrentXPUStream(devInd).queue());
+  // CHECK: dpct::queue_ptr dev_cuda_st = &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream(devInd). queue());
   cudaStream_t dev_cuda_st = at::cuda::getCurrentCUDAStream(devInd).stream();
 
   test_CUDAStream_as_arg();
 
   return 0;
 }
+
+// CHECK: void foo2(c10::DeviceGuard device_guard, float *f) try {
+// CHECK-NEXT: (DPCT_CHECK_ERROR(f = (float *)sycl::malloc_device(4, static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()))));
+void foo2(at::cuda::CUDAGuard device_guard, float *f) {
+  C10_CUDA_CHECK(cudaMalloc(&f, 4));
+}

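Taken together, the new expectations exercise the added rules: the BEGIN_1/END_1 block checks that the two Takeover header rules with an empty output delete the #include directives outright (the CHECK-EMPTY: lines match the blanks left behind), the CUDAMacros.h include is verified to map to its XPU counterpart, and foo2 covers the type and macro rules. In before/after form (the "after" is assembled from the CHECK lines above; the remainder of the function-try-block that dpct emits is not shown by the test):

// Before: CUDA source as written in the test.
void foo2(at::cuda::CUDAGuard device_guard, float *f) {
  C10_CUDA_CHECK(cudaMalloc(&f, 4));
}

// After: at::cuda::CUDAGuard becomes c10::DeviceGuard, the
// C10_CUDA_CHECK wrapper is erased, and cudaMalloc is migrated to a
// sycl::malloc_device on the current XPU stream's queue.
void foo2(c10::DeviceGuard device_guard, float *f) try {
  (DPCT_CHECK_ERROR(f = (float *)sycl::malloc_device(4, static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()))));
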
clang/test/dpct/pytorch/pytorch_inc/ATen/Tensor.h

Lines changed: 2 additions & 1 deletion

@@ -2,6 +2,7 @@
 namespace at {
 class Tensor {
 public:
-  bool is_cuda();
+  int get_device() const { return 0; }
+  bool is_cuda() const { return true; };
 };
 } // namespace at

clang/test/dpct/pytorch/pytorch_inc/ATen/cuda/CUDAContext.h

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 #pragma once
 
 #include <c10/cuda/CUDAStream.h>
+#include <c10/cuda/CUDAGuard.h>
 
 namespace at {
 using namespace c10;

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+#pragma once
+

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+// RUN: echo "empty command"

clang/test/dpct/pytorch/pytorch_inc/c10/cuda/CUDAGuard.h

Lines changed: 11 additions & 0 deletions

@@ -9,6 +9,7 @@ class optional {
 } // namespace std
 
 namespace c10 {
+using DeviceIndex = int8_t;
 class Device {
 public:
   Device(std::string str) {}
@@ -19,5 +20,15 @@ class OptionalCUDAGuard {
 public:
   OptionalCUDAGuard(std::optional<c10::Device> device) {}
 };
+struct CUDAGuard {
+  explicit CUDAGuard() = delete;
+  explicit CUDAGuard(DeviceIndex device_index) {}
+  explicit CUDAGuard(Device device) {}
+  CUDAGuard(const CUDAGuard&) = delete;
+  CUDAGuard& operator=(const CUDAGuard&) = delete;
+  CUDAGuard(CUDAGuard&& other) = delete;
+  CUDAGuard& operator=(CUDAGuard&& other) = delete;
+  ~CUDAGuard() = default;
+};
 } // namespace cuda
 } // namespace c10

clang/test/dpct/pytorch/pytorch_inc/c10/cuda/CUDAMacros.h

Lines changed: 4 additions & 0 deletions

@@ -3,3 +3,7 @@
 #define C10_CUDA_IMPORT
 #define C10_CUDA_API
 #define C10_CUDA_BUILD_MAIN_LIB
+#define C10_CUDA_CHECK(EXPR) \
+  do { \
+    const cudaError_t __err = EXPR; \
+  } while (0)

clang/tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml

Lines changed: 25 additions & 0 deletions

@@ -168,3 +168,28 @@
   In: get_in_order_queue
   Out: static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream())
   Includes: [<c10/xpu/XPUStream.h>]
+
+- Rule: rule_THC_THCAtomics_cuh
+  Kind: Header
+  Priority: Takeover
+  In: THC/THCAtomics.cuh
+  Out: |
+
+- Rule: rule_ATen_cuda_Exceptions_h
+  Kind: Header
+  Priority: Takeover
+  In: ATen/cuda/Exceptions.h
+  Out: |
+
+- Rule: rule_remove_C10_CUDA_CHECK
+  Kind: Macro
+  Priority: Takeover
+  In: C10_CUDA_CHECK
+  Out: |
+
+- Rule: rule_at_cuda_CUDAGuard
+  Kind: Type
+  Priority: Takeover
+  In: c10::cuda::CUDAGuard
+  Out: c10::DeviceGuard
+  Includes: [<c10/core/DeviceGuard.h>]

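Note that the two header rules and the macro rule all end with an empty "Out: |" block, i.e. "replace the match with nothing": that is what deletes the two includes between BEGIN_1/END_1 and erases the C10_CUDA_CHECK wrapper in the tests above, and it is the case the updated InclusionHeaders.cpp now handles by printing Rule.Out only when it is non-empty. The rules are consumed through dpct's --rule-file option, as in the ATen.cu RUN line; a minimal invocation against user code might look like this (all paths are illustrative):

dpct --rule-file=<dpct-install>/extensions/pytorch_api_rules/pytorch_api.yaml \
     --in-root src --out-root dpct_out \
     --extra-arg="-I<pytorch>/include" --cuda-include-path=<cuda>/include \
     src/my_extension.cu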