Skip to content

Commit b88844d

Browse files
authored
[SYCLomatic] Update two the YMAL migration rules in pytorch_api.yaml on stream migration(#2772)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent 5c69355 commit b88844d

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

clang/test/dpct/pytorch/ATen.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void test_CUDAStream_as_arg() {
5555
}
5656

5757
int main() {
58-
// CHECK: dpct::queue_ptr st = &static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream());
58+
// CHECK: dpct::queue_ptr st = &c10::xpu::getCurrentXPUStream().queue();
5959
cudaStream_t st = 0;
6060

6161
// stream APIs
@@ -66,9 +66,9 @@ int main() {
6666
// CHECK: auto deviceStream = c10::xpu::getCurrentXPUStream(devInd);
6767
auto deviceStream = at::cuda::getCurrentCUDAStream(devInd);
6868

69-
// CHECK: dpct::queue_ptr curr_cuda_st = &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream(). queue());
69+
// CHECK: dpct::queue_ptr curr_cuda_st = &(c10::xpu::getCurrentXPUStream(). queue());
7070
cudaStream_t curr_cuda_st = at::cuda::getCurrentCUDAStream().stream();
71-
// CHECK: dpct::queue_ptr dev_cuda_st = &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream(devInd). queue());
71+
// CHECK: dpct::queue_ptr dev_cuda_st = &(c10::xpu::getCurrentXPUStream(devInd). queue());
7272
cudaStream_t dev_cuda_st = at::cuda::getCurrentCUDAStream(devInd).stream();
7373

7474
test_CUDAStream_as_arg();
@@ -77,7 +77,7 @@ int main() {
7777
}
7878

7979
// CHECK: void foo2(c10::DeviceGuard device_guard, float *f) try {
80-
// CHECK-NEXT: (DPCT_CHECK_ERROR(f = (float *)sycl::malloc_device(4, static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream()))));
80+
// CHECK-NEXT: (DPCT_CHECK_ERROR(f = (float *)sycl::malloc_device(4, c10::xpu::getCurrentXPUStream().queue())));
8181
void foo2(at::cuda::CUDAGuard device_guard, float *f) {
8282
C10_CUDA_CHECK(cudaMalloc(&f, 4));
8383
}

clang/test/dpct/pytorch/c10.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,9 @@ int main() {
6363
// CHECK: auto currentStream = c10::xpu::getCurrentXPUStream();
6464
auto currentStream = c10::cuda::getCurrentCUDAStream();
6565

66-
// CHECK: dpct::queue_ptr curr_cuda_st =
67-
// CHECK-NEXT: &static_cast<sycl::queue &>(currentStream.queue());
66+
// CHECK: dpct::queue_ptr curr_cuda_st = &(currentStream.queue());
6867
cudaStream_t curr_cuda_st = currentStream.stream();
69-
// CHECK: curr_cuda_st =
70-
// CHECK-NEXT: &static_cast<sycl::queue &>(c10::xpu::getCurrentXPUStream().queue());
68+
// CHECK: curr_cuda_st = &(c10::xpu::getCurrentXPUStream().queue());
7169
curr_cuda_st = c10::cuda::getCurrentCUDAStream().stream();
7270

7371
// CHECK: auto deviceStream = c10::xpu::getCurrentXPUStream(0);

clang/tools/dpct/extensions/pytorch_api_rules/pytorch_api.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
Out: c10::xpu::XPUStream
8888
Methods:
8989
- In: stream
90-
Out: \&static_cast<sycl::queue &>($method_base queue())
90+
Out: \&($method_base queue())
9191

9292
- Rule: rule_c10_cuda_getCurrentCUDAStream
9393
Kind: API
@@ -165,7 +165,7 @@
165165
Kind: HelperFunction
166166
Priority: Takeover
167167
In: get_in_order_queue
168-
Out: static_cast<sycl::queue&>(c10::xpu::getCurrentXPUStream())
168+
Out: c10::xpu::getCurrentXPUStream().queue()
169169
Includes: [<c10/xpu/XPUStream.h>]
170170

171171
- Rule: rule_THC_THCAtomics_cuh

0 commit comments

Comments
 (0)