Skip to content

Commit 5e83299

Browse files
pemeliyacopybara-github
authored andcommitted
PR #10521: [ROCM] fixing gpu_pjrt_client crash and improved debug in rocm_driver
Imported from GitHub PR #10521 In this PR I fix xla/pjrt/gpu/se_gpu_pjrt_client.cc which causes a crash during the execution of se_gpu_pjrt_client_test.cc, namely the subtest **StreamExecutorGpuClientTest.CopyRawToHostFullBuffer**. The crash was caused by taking a pointer of a unique_ptr in one function's argument and releasing that pointer in another one (not sure why it worked before: maybe on CUDA the compiler has different behaviour?). But I guess the order of evaluating the function arguments is not defined.. Besides, I also improve gpu_driver debug output. @xla-rotation: can you please have a look ? Copybara import of the project: -- fc43f8e by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>: fixing gpu_pjrt_client crash and added improved debug output in rocm_driver -- 68e52b5 by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>: passing promise by value since the callback function is copied -- 65ab1d4 by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>: fixing buildifier warnings Merging this change closes #10521 COPYBARA_INTEGRATE_REVIEW=#10521 from ROCm:ci_pjrt_gpu_client_fix 65ab1d4 PiperOrigin-RevId: 615973610
1 parent e1e2266 commit 5e83299

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

xla/pjrt/gpu/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
55
load("@tsl//tsl:tsl.bzl", "internal_visibility")
66
load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library")
77
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
8-
load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
98

109
package(
1110
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -127,7 +126,6 @@ cc_library(
127126
xla_cc_test(
128127
name = "se_gpu_pjrt_client_test",
129128
srcs = if_gpu_is_configured(["se_gpu_pjrt_client_test.cc"]),
130-
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
131129
tags = [
132130
"gpu",
133131
"no_oss",

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,11 +594,11 @@ PjRtFuture<absl::Status> StreamExecutorGpuClient::CopyRawSubBufferToHost(
594594
/*reference_held=*/false);
595595

596596
auto promise = PjRtFuture<absl::Status>::CreatePromise();
597+
auto stream_ptr = stream.get();
597598
auto callback_status = local_device->ThenExecuteCallback(
598-
stream.get(), [promise, free_sub_range = sub_buffer.release(),
599-
free_stream = stream.release(), local_device]() mutable {
599+
stream_ptr,
600+
[promise, free_stream = stream.release(), local_device]() mutable {
600601
auto stream = std::unique_ptr<se::Stream>(free_stream);
601-
auto sub_range = std::unique_ptr<se::DeviceMemoryBase>(free_sub_range);
602602
local_device->ReturnStreamToPool(std::move(stream));
603603
promise.Set(OkStatus());
604604
});

xla/stream_executor/rocm/rocm_driver.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,7 @@ struct BitPatternToValue {
15801580
"failed to synchronous memcpy from host to device: Gpu dst: %p;"
15811581
" host src: %p; size: %llu=0x%llx",
15821582
absl::bit_cast<void*>(gpu_dst), host_src, size, size));
1583-
VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes";
1583+
VLOG(2) << "successfully sync memcpy'd h2d of " << size << " bytes";
15841584
return absl::OkStatus();
15851585
}
15861586

@@ -1616,7 +1616,8 @@ struct BitPatternToValue {
16161616
}
16171617
VLOG(2) << "successfully enqueued async memcpy d2h of " << size
16181618
<< " bytes from " << absl::bit_cast<void*>(gpu_src) << " to "
1619-
<< host_dst << " on stream " << stream;
1619+
<< host_dst << " on stream " << stream
1620+
<< " device: " << context->device_ordinal();
16201621
return true;
16211622
}
16221623

@@ -1636,8 +1637,10 @@ struct BitPatternToValue {
16361637
size);
16371638
return false;
16381639
}
1639-
VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes"
1640-
<< " on stream " << stream;
1640+
VLOG(2) << "successfully enqueued async memcpy h2d of " << size
1641+
<< " bytes from " << host_src << " to "
1642+
<< absl::bit_cast<void*>(gpu_dst) << " on stream " << stream
1643+
<< " device: " << context->device_ordinal();
16411644
return true;
16421645
}
16431646

@@ -1664,7 +1667,11 @@ struct BitPatternToValue {
16641667

16651668
return false;
16661669
}
1667-
VLOG(2) << "successfully enqueued async memcpy d2d of " << size << " bytes";
1670+
1671+
VLOG(2) << "successfully enqueued async memcpy d2d of " << size
1672+
<< " bytes from " << absl::bit_cast<void*>(gpu_src) << " to "
1673+
<< absl::bit_cast<void*>(gpu_dst) << " on stream " << stream
1674+
<< " device: " << context->device_ordinal();
16681675
return true;
16691676
}
16701677

0 commit comments

Comments
 (0)