
Commit 5174ed6

tyb0807 authored and copybara-github committed

[xla:gpu] Add AddressComputationThunk

This thunk wraps the logic to compute dynamic offsets/sizes from dynamic-slice
and DUS (dynamic-update-slice) around some original thunks (e.g. custom call
or NCCL thunks).

PiperOrigin-RevId: 6159208

1 parent 5f69010 commit 5174ed6
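
The core of the thunk is a small piece of address arithmetic: at execution time it copies the slice start indices back from device memory, folds them into a byte offset using the byte strides of the unsliced operand shape, and rebases the operand's device buffer at that offset instead of copying data. A minimal sketch of that computation follows; the function name and parameters are illustrative, not part of the commit.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch (not the commit's API): rebase a contiguous slice by
// folding the dynamic start indices into a byte offset. `byte_strides` are
// the byte strides of the *unsliced* operand shape, and `slice_starts` are
// the start indices copied back from the device.
int64_t SlicedByteOffset(int64_t orig_offset,
                         const std::vector<int64_t>& slice_starts,
                         const std::vector<int64_t>& byte_strides) {
  int64_t offset = orig_offset;
  for (size_t d = 0; d < slice_starts.size(); ++d) {
    // Same accumulation as the llvm::zip loop in ExecuteOnStream below.
    offset += slice_starts[d] * byte_strides[d];
  }
  return offset;
}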

File tree

6 files changed: +601 -0 lines changed


xla/service/gpu/runtime/BUILD

Lines changed: 67 additions & 0 deletions
@@ -188,6 +188,73 @@ xla_test(
 # XLA Thunks Runtime
 #===-------------------------------------------------------------------------------------------===//

+cc_library(
+    name = "address_computation_thunk",
+    srcs = ["address_computation_thunk.cc"],
+    hdrs = ["address_computation_thunk.h"],
+    deps = [
+        ":sequential_thunk",
+        "//xla:shape_util",
+        "//xla:status",
+        "//xla:status_macros",
+        "//xla/service:buffer_assignment",
+        "//xla/service/gpu:buffer_allocations",
+        "//xla/service/gpu:ir_emission_utils",
+        "//xla/service/gpu:thunk",
+        "//xla/stream_executor",
+        "//xla/stream_executor:memory_allocation",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/synchronization",
+        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:Support",
+        "@tsl//tsl/platform:errors",
+        "@tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_test(
+    name = "address_computation_thunk_test",
+    srcs = if_gpu_is_configured(["address_computation_thunk_test.cc"]),
+    backend_tags = {
+        "gpu_a100": ["config-cuda-only"],
+        "gpu_v100": ["config-cuda-only"],
+    },
+    backends = [
+        "gpu_a100",
+        "gpu_v100",
+    ],
+    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
+    deps = [
+        ":address_computation_thunk",
+        ":gemm_thunk",
+        "//xla:shape_util",
+        "//xla:types",
+        "//xla/service:buffer_assignment",
+        "//xla/service:executable",
+        "//xla/service:platform_util",
+        "//xla/service/gpu:buffer_allocations",
+        "//xla/service/gpu:launch_dimensions",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:thunk",
+        "//xla/stream_executor",
+        "//xla/stream_executor:platform",
+        "//xla/stream_executor:platform_manager",
+        "//xla/stream_executor/gpu:gpu_test_kernels",
+        "//xla/stream_executor/gpu:gpu_types_header",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@tsl//tsl/lib/core:status_test_util",
+        "@tsl//tsl/platform:statusor",
+        "@tsl//tsl/platform:test",
+        "@tsl//tsl/platform:test_main",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+    ]),
+)
+
 cc_library(
     name = "cholesky_thunk",
     srcs = if_gpu_is_configured(["cholesky_thunk.cc"]),
xla/service/gpu/runtime/address_computation_thunk.cc

Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/runtime/address_computation_thunk.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h"
#include "llvm/ADT/STLExtras.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/thunk.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/memory_allocation.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {

AddressComputationThunk::AddressComputationThunk(
    ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
    std::vector<std::optional<const BufferAllocation::Slice>> operands,
    std::vector<std::optional<const BufferAllocation::Slice>> results,
    std::vector<std::optional<const BufferAllocation::Slice>>
        offset_buffer_indices,
    std::vector<std::optional<const Shape>> orig_shapes,
    std::vector<std::optional<const Shape>> sliced_shapes)
    : Thunk(Kind::kAddressComputation, thunk_info),
      embedded_thunk_(std::make_unique<SequentialThunk>(
          ThunkInfo(thunk_info.op), std::move(*embedded_thunk))),
      embedded_thunk_operands_(std::move(operands)),
      embedded_thunk_results_(std::move(results)),
      offset_buffer_indices_(std::move(offset_buffer_indices)),
      orig_shapes_(std::move(orig_shapes)),
      sliced_shapes_(std::move(sliced_shapes)) {}

absl::Status AddressComputationThunk::Prepare(
    const PrepareParams& params, ResourceRequests& resource_requests) {
  auto num_operands = embedded_thunk_operands_.size();
  TF_RET_CHECK(num_operands == offset_buffer_indices_.size());
  TF_RET_CHECK(num_operands == orig_shapes_.size());
  TF_RET_CHECK(num_operands == sliced_shapes_.size());
  for (unsigned i = 0; i < num_operands; ++i) {
    if (sliced_shapes_[i].has_value()) {
      TF_RET_CHECK(embedded_thunk_operands_[i].has_value());
      TF_RET_CHECK(offset_buffer_indices_[i].has_value());
      TF_RET_CHECK(sliced_shapes_[i]->IsArray());
      TF_RET_CHECK(orig_shapes_[i].has_value() && orig_shapes_[i]->IsArray());
    }
  }
  TF_RETURN_IF_ERROR(embedded_thunk_->Prepare(params, resource_requests));
  return absl::OkStatus();
}

absl::Status AddressComputationThunk::Initialize(
    const InitializeParams& params) {
  TF_RETURN_IF_ERROR(embedded_thunk_->Initialize(params));

  unsigned num_offsets = 0;
  for (auto maybe_shape : sliced_shapes_) {
    num_offsets += (maybe_shape == std::nullopt) ? 1 : maybe_shape->rank();
  }
  absl::MutexLock lock(&mutex_);
  if (auto it = offsets_.find(params.executor); it == offsets_.end()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<se::MemoryAllocation> allocation,
        params.executor->HostMemoryAllocate(num_offsets * sizeof(int64_t)));
    offsets_.emplace(params.executor, std::move(allocation));
  }

  return absl::OkStatus();
}

absl::Status AddressComputationThunk::ExecuteOnStream(
    const ExecuteParams& params) {
  auto& stream = *params.stream;

  // Get memory allocation for copying offsets from device.
  int64_t* offsets_base = [&] {
    absl::MutexLock lock(&mutex_);
    return reinterpret_cast<int64_t*>(offsets_.at(stream.parent())->opaque());
  }();

  std::vector<se::DeviceMemoryBase> new_buffers;
  const BufferAllocations& orig_allocations = *params.buffer_allocations;
  for (unsigned i = 0; i < offset_buffer_indices_.size(); ++i) {
    if (embedded_thunk_operands_[i] == std::nullopt) {
      new_buffers.push_back(se::DeviceMemoryBase());
      continue;
    }

    se::DeviceMemoryBase orig_operand =
        orig_allocations.GetDeviceAddress(*embedded_thunk_operands_[i]);
    if (offset_buffer_indices_[i] == std::nullopt) {
      new_buffers.push_back(orig_operand);
      continue;
    }

    se::DeviceMemoryBase offset_src =
        orig_allocations.GetDeviceAddress(*offset_buffer_indices_[i]);

    // Copy the ith offset from device to host.
    const Shape& src_shape = *orig_shapes_[i];
    const Shape& dst_shape = *sliced_shapes_[i];
    int64_t* offset_dst = &offsets_base[i];
    TF_RETURN_IF_ERROR(stream.Memcpy(offset_dst, offset_src,
                                     dst_shape.rank() * sizeof(int64_t)));

    if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) {
      return absl::InternalError(absl::StrFormat(
          "Failed to retrieve all slice offset values on stream %p: %s",
          &stream, blocked.message()));
    }

    // Compute new slice. No need to copy the content to new buffers as we
    // can reuse the original buffers since slices are contiguous.
    TF_RET_CHECK(IsContiguousSlice(src_shape, dst_shape));

    int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape);
    BufferAllocation::Slice orig_slice = *embedded_thunk_operands_[i];

    int64_t new_offset = orig_slice.offset();
    std::vector<int64_t> slice_starts(offset_dst,
                                      offset_dst + dst_shape.rank());
    for (auto [start, stride] :
         llvm::zip(slice_starts, *ShapeUtil::ByteStrides(src_shape))) {
      new_offset += start * stride;
    }

    new_buffers.push_back(orig_operand.GetByteSlice(new_offset, new_size));
  }

  // TODO(vuson): handle DUS too. For now just copy the results over.
  for (auto result : embedded_thunk_results_) {
    if (result == std::nullopt) {
      new_buffers.push_back(se::DeviceMemoryBase());
    } else {
      se::DeviceMemoryBase orig_result =
          orig_allocations.GetDeviceAddress(*result);
      new_buffers.push_back(orig_result);
    }
  }

  // Safe to create a local BufferAllocations here since buffers are only
  // slices of bigger ones allocated elsewhere.
  BufferAllocations new_allocations(new_buffers,
                                    orig_allocations.device_ordinal(),
                                    orig_allocations.memory_allocator(),
                                    orig_allocations.external_allocations());

  Thunk::ExecuteParams new_params =
      Thunk::ExecuteParams::CloneWithNewAllocations(params, new_allocations);

  // Execute the underlying custom call thunk with the new buffers.
  TF_RETURN_IF_ERROR(embedded_thunk_->ExecuteOnStream(new_params));

  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace xla
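
As a concrete illustration of the offset loop above (with hypothetical shapes, not taken from the commit): for an operand of shape f32[8,8] in row-major layout, ShapeUtil::ByteStrides yields {32, 4}; a contiguous slice down to f32[2,8] (two whole rows) with dynamic starts {2, 0} gives new_offset = orig_slice.offset() + 2*32 + 0*4 = orig_slice.offset() + 64 and new_size = ShapeUtil::ByteSizeOf(f32[2,8]) = 64 bytes, so GetByteSlice returns a 64-byte view into the original allocation and no device data is copied.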
xla/service/gpu/runtime/address_computation_thunk.h

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_RUNTIME_ADDRESS_COMPUTATION_THUNK_H_
#define XLA_SERVICE_GPU_RUNTIME_ADDRESS_COMPUTATION_THUNK_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/runtime/sequential_thunk.h"
#include "xla/service/gpu/thunk.h"
#include "xla/status.h"
#include "xla/stream_executor/memory_allocation.h"
#include "xla/stream_executor/stream_executor.h"

namespace xla {
namespace gpu {

// AddressComputationThunk wraps the logic to compute dynamic offsets/sizes
// from dynamic-slice or DUS around some original thunks (e.g. custom call or
// NCCL thunks).
//
// AddressComputationThunk assumes that the slices are contiguous.
class AddressComputationThunk : public Thunk {
 public:
  AddressComputationThunk(
      ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
      std::vector<std::optional<const BufferAllocation::Slice>> operands,
      std::vector<std::optional<const BufferAllocation::Slice>> results,
      std::vector<std::optional<const BufferAllocation::Slice>>
          offset_buffer_indices,
      std::vector<std::optional<const Shape>> orig_shapes,
      std::vector<std::optional<const Shape>> sliced_shapes);

  AddressComputationThunk(const AddressComputationThunk&) = delete;
  AddressComputationThunk& operator=(const AddressComputationThunk&) = delete;

  absl::Status Prepare(const PrepareParams& params,
                       ResourceRequests& resource_requests) override;
  absl::Status Initialize(const InitializeParams& params) override;
  absl::Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  std::unique_ptr<SequentialThunk> embedded_thunk_;
  std::vector<std::optional<const BufferAllocation::Slice>>
      embedded_thunk_operands_;
  std::vector<std::optional<const BufferAllocation::Slice>>
      embedded_thunk_results_;
  std::vector<std::optional<const BufferAllocation::Slice>>
      offset_buffer_indices_;

  std::vector<std::optional<const Shape>> orig_shapes_;
  std::vector<std::optional<const Shape>> sliced_shapes_;

  // Pinned host memory for transferring offset values from device to host.
  absl::Mutex mutex_;
  absl::flat_hash_map<se::StreamExecutor*,
                      std::unique_ptr<se::MemoryAllocation>>
      offsets_ ABSL_GUARDED_BY(mutex_);
};

}  // namespace gpu
}  // namespace xla

#endif  // XLA_SERVICE_GPU_RUNTIME_ADDRESS_COMPUTATION_THUNK_H_
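
For orientation, a hypothetical construction sketch, modeled loosely on the accompanying test; none of the variable names below come from the commit, and a real caller would obtain the slices from buffer assignment.

// Hypothetical wiring sketch (assumptions, not the commit's test code):
// `embedded` is a ThunkSequence holding the thunk to wrap (e.g. a GemmThunk),
// and `lhs`, `rhs`, `out`, `lhs_offsets` are BufferAllocation::Slice values
// produced by buffer assignment. Operand 0 is dynamically sliced from
// f32[8,8] down to f32[2,8]; operand 1 and the result pass through unchanged
// (std::nullopt entries, so ExecuteOnStream forwards their buffers as-is).
AddressComputationThunk thunk(
    Thunk::ThunkInfo(/*op=*/nullptr),
    std::make_unique<ThunkSequence>(std::move(embedded)),
    /*operands=*/{lhs, rhs},
    /*results=*/{out},
    /*offset_buffer_indices=*/{lhs_offsets, std::nullopt},
    /*orig_shapes=*/{ShapeUtil::MakeShape(F32, {8, 8}), std::nullopt},
    /*sliced_shapes=*/{ShapeUtil::MakeShape(F32, {2, 8}), std::nullopt});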
