
Commit 2ab702b

[SYCLomatic] Add a rule to migrate CUDAGuard declaration (#2869)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
Parent: 4f1f330

File tree

3 files changed: +98 -83 lines

clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp
clang/test/dpct/pytorch/ATen.cu
clang/tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml

clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp

3 additions, 0 deletions

@@ -1243,6 +1243,9 @@ void ExprAnalysis::analyzeType(TypeLoc TL, const Expr *CSCE,
   if (Iter != MapNames::TypeNamesMap.end()) {
     HelperFeatureSet.insert(Iter->second->RequestFeature);
     requestHelperFeatureForTypeNames(TyName);
+    for (const auto &Include : Iter->second->Includes) {
+      DpctGlobalInfo::getInstance().insertHeader(SR.getBegin(), Include);
+    }
   } else {
     Iter = MapNamesDNN::CuDNNTypeNamesMap.find(TyName);
     if (Iter != MapNamesDNN::CuDNNTypeNamesMap.end()) {
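
The effect of this hunk: when a type name hits an entry in TypeNamesMap, any headers listed in that rule's Includes field are now requested at the use site, which is what lets the CUDAGuard rule below pull in <c10/core/DeviceGuard.h> automatically. A minimal self-contained sketch of the idea; TypeNameRule, Migrator, and HeadersToInsert are simplified, hypothetical stand-ins for the real SYCLomatic classes (MapNames, DpctGlobalInfo::insertHeader):

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for SYCLomatic's internal types.
struct TypeNameRule {
  std::string Out;                    // replacement type name
  std::vector<std::string> Includes;  // headers the replacement needs
};

struct Migrator {
  std::map<std::string, TypeNameRule> TypeNamesMap;
  std::set<std::string> HeadersToInsert;  // deduplicated, like insertHeader()

  // Mirrors the patched logic: on a type-name hit, also record every
  // header the rule declares, so it is emitted into the migrated file.
  std::string rewriteType(const std::string &TyName) {
    auto Iter = TypeNamesMap.find(TyName);
    if (Iter == TypeNamesMap.end())
      return TyName;  // no rule: leave the type untouched
    for (const auto &Include : Iter->second.Includes)
      HeadersToInsert.insert(Include);
    return Iter->second.Out;
  }
};

int main() {
  Migrator M;
  M.TypeNamesMap["at::cuda::CUDAGuard"] = {"c10::DeviceGuard",
                                           {"<c10/core/DeviceGuard.h>"}};
  std::cout << M.rewriteType("at::cuda::CUDAGuard") << "\n";  // c10::DeviceGuard
  for (const auto &H : M.HeadersToInsert)
    std::cout << "insert header: " << H << "\n";
}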

clang/test/dpct/pytorch/ATen.cu

89 additions, 83 deletions

@@ -1,83 +1,89 @@
-// RUN: rm -rf %T/pytorch/ATen
-// RUN: mkdir -p %T/pytorch/ATen/src
-// RUN: cp %S/ATen.cu %T/pytorch/ATen/src/
-// RUN: cp -r %S/pytorch_inc %T/pytorch/ATen/
-// RUN: cd %T/pytorch/ATen
-// RUN: mkdir dpct_out
-// RUN: dpct --format-range=none --out-root dpct_out %T/pytorch/ATen/src/ATen.cu --extra-arg="-I%T/pytorch/ATen/pytorch_inc" --cuda-include-path="%cuda-path/include" --rule-file=%S/../../../tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml --analysis-scope-path %T/pytorch/ATen/pytorch_inc --analysis-scope-path %T/pytorch/ATen/src --in-root %T/pytorch/ATen/src
-// RUN: FileCheck --input-file %T/pytorch/ATen/dpct_out/ATen.dp.cpp --match-full-lines %T/pytorch/ATen/src/ATen.cu
-
-// CHECK: #include <c10/xpu/XPUStream.h>
-#include <iostream>
-// CHECK: #include <ATen/xpu/XPUContext.h>
-#include <ATen/cuda/CUDAContext.h>
-// CHECK: #include <ATen/core/Tensor.h>
-#include <ATen/core/Tensor.h>
-
-// CHECK: #include <ATen/Tensor.h>
-// CHECK-NEXT: #include <c10/util/Half.h>
-#include <ATen/cuda/CUDATensorMethods.cuh>
-
-// CHECK: // BEGIN_1
-// CHECK-EMPTY:
-// CHECK-EMPTY:
-// CHECK-NEXT: // END_1
-// BEGIN_1
-#include <ATen/cuda/Exceptions.h>
-#include <THC/THCAtomics.cuh>
-// END_1
-
-// CHECK: #include <c10/xpu/XPUMacros.h>
-#include <c10/cuda/CUDAMacros.h>
-
-#define AT_CUDA_CHECK(stmt) (stmt)
-
-// CHECK: #define BE_AT_CHECK
-#define BE_AT_CHECK AT_CUDA_CHECK
-
-
-__global__ void kernel() {}
-
-void test_CUDAStream_as_arg() {
-  dim3 gridSize(2, 2, 1);
-  dim3 blockSize(8, 8, 1);
-  void *args[] = {nullptr};
-
-  // CHECK: ([&](){
-  // CHECK-NEXT: ((sycl::queue*)(c10::xpu::getCurrentXPUStream()))->parallel_for(
-  // CHECK-NEXT: sycl::nd_range<3>(gridSize * blockSize, blockSize),
-  // CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
-  // CHECK-NEXT: kernel();
-  // CHECK-NEXT: });
-  // CHECK-NEXT: return 0;
-  // CHECK-NEXT: }());
-  AT_CUDA_CHECK(cudaLaunchKernel((const void *)kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
-}
-
-int main() {
-  // CHECK: dpct::queue_ptr st = &c10::xpu::getCurrentXPUStream().queue();
-  cudaStream_t st = 0;
-
-  // stream APIs
-  at::DeviceIndex devInd = 1;
-
-  // CHECK: auto currentStream = c10::xpu::getCurrentXPUStream();
-  auto currentStream = at::cuda::getCurrentCUDAStream();
-  // CHECK: auto deviceStream = c10::xpu::getCurrentXPUStream(devInd);
-  auto deviceStream = at::cuda::getCurrentCUDAStream(devInd);
-
-  // CHECK: dpct::queue_ptr curr_cuda_st = &(c10::xpu::getCurrentXPUStream(). queue());
-  cudaStream_t curr_cuda_st = at::cuda::getCurrentCUDAStream().stream();
-  // CHECK: dpct::queue_ptr dev_cuda_st = &(c10::xpu::getCurrentXPUStream(devInd). queue());
-  cudaStream_t dev_cuda_st = at::cuda::getCurrentCUDAStream(devInd).stream();
-
-  test_CUDAStream_as_arg();
-
-  return 0;
-}
-
-// CHECK: void foo2(c10::DeviceGuard device_guard, float *f) try {
-// CHECK-NEXT: (DPCT_CHECK_ERROR(f = (float *)sycl::malloc_device(4, c10::xpu::getCurrentXPUStream().queue())));
-void foo2(at::cuda::CUDAGuard device_guard, float *f) {
-  C10_CUDA_CHECK(cudaMalloc(&f, 4));
-}
+// RUN: rm -rf %T/pytorch/ATen
+// RUN: mkdir -p %T/pytorch/ATen/src
+// RUN: cp %S/ATen.cu %T/pytorch/ATen/src/
+// RUN: cp -r %S/pytorch_inc %T/pytorch/ATen/
+// RUN: cd %T/pytorch/ATen
+// RUN: mkdir dpct_out
+// RUN: dpct --format-range=none --out-root dpct_out %T/pytorch/ATen/src/ATen.cu --extra-arg="-I%T/pytorch/ATen/pytorch_inc" --cuda-include-path="%cuda-path/include" --rule-file=%S/../../../tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml --analysis-scope-path %T/pytorch/ATen/pytorch_inc --analysis-scope-path %T/pytorch/ATen/src --in-root %T/pytorch/ATen/src
+// RUN: FileCheck --input-file %T/pytorch/ATen/dpct_out/ATen.dp.cpp --match-full-lines %T/pytorch/ATen/src/ATen.cu
+
+// CHECK: #include <c10/xpu/XPUStream.h>
+#include <iostream>
+// CHECK: #include <ATen/xpu/XPUContext.h>
+#include <ATen/cuda/CUDAContext.h>
+// CHECK: #include <ATen/core/Tensor.h>
+#include <ATen/core/Tensor.h>
+
+// CHECK: #include <ATen/Tensor.h>
+// CHECK-NEXT: #include <c10/util/Half.h>
+#include <ATen/cuda/CUDATensorMethods.cuh>
+
+// CHECK: // BEGIN_1
+// CHECK-EMPTY:
+// CHECK-EMPTY:
+// CHECK-NEXT: // END_1
+// BEGIN_1
+#include <ATen/cuda/Exceptions.h>
+#include <THC/THCAtomics.cuh>
+// END_1
+
+// CHECK: #include <c10/xpu/XPUMacros.h>
+// CHECK: #include <c10/core/DeviceGuard.h>
+#include <c10/cuda/CUDAMacros.h>
+
+#define AT_CUDA_CHECK(stmt) (stmt)
+
+// CHECK: #define BE_AT_CHECK
+#define BE_AT_CHECK AT_CUDA_CHECK
+
+
+__global__ void kernel() {}
+
+void test_CUDAStream_as_arg() {
+  dim3 gridSize(2, 2, 1);
+  dim3 blockSize(8, 8, 1);
+  void *args[] = {nullptr};
+
+  // CHECK: ([&](){
+  // CHECK-NEXT: ((sycl::queue*)(c10::xpu::getCurrentXPUStream()))->parallel_for(
+  // CHECK-NEXT: sycl::nd_range<3>(gridSize * blockSize, blockSize),
+  // CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
+  // CHECK-NEXT: kernel();
+  // CHECK-NEXT: });
+  // CHECK-NEXT: return 0;
+  // CHECK-NEXT: }());
+  AT_CUDA_CHECK(cudaLaunchKernel((const void *)kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+}
+
+int main() {
+  // CHECK: dpct::queue_ptr st = &c10::xpu::getCurrentXPUStream().queue();
+  cudaStream_t st = 0;
+
+  // stream APIs
+  at::DeviceIndex devInd = 1;
+
+  // CHECK: auto currentStream = c10::xpu::getCurrentXPUStream();
+  auto currentStream = at::cuda::getCurrentCUDAStream();
+  // CHECK: auto deviceStream = c10::xpu::getCurrentXPUStream(devInd);
+  auto deviceStream = at::cuda::getCurrentCUDAStream(devInd);
+
+  // CHECK: dpct::queue_ptr curr_cuda_st = &(c10::xpu::getCurrentXPUStream(). queue());
+  cudaStream_t curr_cuda_st = at::cuda::getCurrentCUDAStream().stream();
+  // CHECK: dpct::queue_ptr dev_cuda_st = &(c10::xpu::getCurrentXPUStream(devInd). queue());
+  cudaStream_t dev_cuda_st = at::cuda::getCurrentCUDAStream(devInd).stream();
+
+  test_CUDAStream_as_arg();
+
+  return 0;
+}
+
+// CHECK: void foo2(c10::DeviceGuard device_guard, float *f) try {
+// CHECK-NEXT: (DPCT_CHECK_ERROR(f = (float *)sycl::malloc_device(4, c10::xpu::getCurrentXPUStream().queue())));
+void foo2(at::cuda::CUDAGuard device_guard, float *f) {
+  C10_CUDA_CHECK(cudaMalloc(&f, 4));
+}
+
+void foo3(at::Tensor x) {
+  // CHECK: c10::DeviceGuard device_guard{c10::Device(at::kXPU, (char)x.get_device())};
+  at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+}
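
The new foo3 case is exactly the behavior this commit adds: a CUDAGuard declared directly from a tensor's device index. Read as a before/after pair, the migration the test pins down looks like the following sketch. It is illustrative only: function names are hypothetical, the header choices follow the rule's Includes field and the usual PyTorch layout rather than anything in the commit, and compiling either side requires a PyTorch (resp. XPU-enabled PyTorch) build.

// Before migration (CUDA PyTorch; <c10/cuda/CUDAGuard.h> is the header
// extension code conventionally uses for at::cuda::CUDAGuard):
#include <ATen/core/Tensor.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_tensor_device(at::Tensor x) {
  // RAII guard: makes x's CUDA device current until end of scope.
  at::cuda::CUDAGuard device_guard{(char)x.get_device()};
}

// After migration (expected SYCLomatic output, per the CHECK line above):
#include <ATen/core/Tensor.h>
#include <c10/core/DeviceGuard.h>

void run_on_tensor_device_migrated(at::Tensor x) {
  // c10::DeviceGuard takes a full c10::Device, so the raw index is
  // wrapped in c10::Device(at::kXPU, ...) to target the XPU backend.
  c10::DeviceGuard device_guard{c10::Device(at::kXPU, (char)x.get_device())};
}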

clang/tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml

6 additions, 0 deletions

@@ -192,3 +192,9 @@
   In: c10::cuda::CUDAGuard
   Out: c10::DeviceGuard
   Includes: [<c10/core/DeviceGuard.h>]
+
+- Rule: rule_decl_CUDAGuard_with_tenosr
+  Kind: PatternRewriter
+  Priority: Takeover
+  In: c10::DeviceGuard device_guard{${args}};
+  Out: c10::DeviceGuard device_guard{c10::Device(at::kXPU, ${args})};
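
One reading of how the two rules compose (an interpretation, not stated in the commit): the preceding type rule, whose In/Out/Includes lines appear as context above, first rewrites the type name c10::cuda::CUDAGuard to c10::DeviceGuard; this new PatternRewriter rule then matches the resulting declaration and wraps whatever the ${args} placeholder captured in c10::Device(at::kXPU, ...). A minimal self-contained sketch of that capture-and-substitute step, using std::regex as a toy stand-in for the real PatternRewriter engine:

#include <iostream>
#include <regex>
#include <string>

// Toy model of a PatternRewriter rule with one ${args} placeholder:
// the "In" pattern captures the constructor arguments, the "Out"
// template reuses the capture.
std::string applyRule(const std::string &Line) {
  // ${args} modeled as a greedy "anything" capture between the braces.
  static const std::regex In(R"(c10::DeviceGuard device_guard\{(.*)\};)");
  return std::regex_replace(
      Line, In, "c10::DeviceGuard device_guard{c10::Device(at::kXPU, $1)};");
}

int main() {
  // The declaration as it looks after the CUDAGuard type rule has run.
  std::string Line = "c10::DeviceGuard device_guard{(char)x.get_device()};";
  std::cout << applyRule(Line) << "\n";
  // Prints:
  // c10::DeviceGuard device_guard{c10::Device(at::kXPU, (char)x.get_device())};
}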
