Skip to content

Commit e1e2266

Browse files
ezhulenevcopybara-github
authored andcommitted
[xla:ffi] Add support for passing scratch allocator to FFI handlers
PiperOrigin-RevId: 615959071
1 parent a93aaab commit e1e2266

File tree

5 files changed

+65
-14
lines changed

5 files changed

+65
-14
lines changed

xla/ffi/api/c_api_internal.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616
#ifndef XLA_FFI_API_C_API_INTERNAL_H_
1717
#define XLA_FFI_API_C_API_INTERNAL_H_
1818

19+
#include <cstdint>
20+
1921
#include "xla/ffi/api/c_api.h"
2022

2123
// Internal XLA FFI API that gives access to XLA implementation details that
@@ -40,12 +42,23 @@ extern "C" {
4042
// caller.
4143
typedef XLA_FFI_Error* XLA_FFI_INTERNAL_Error_Forward(void* status);
4244

43-
// Returns a pointer to main compute stream (pointer to `se::Stream`). In
45+
// Returns a pointer to main compute stream (`se::Stream` pointer). In
4446
// contrast to public C API which returns a pointer to underlying platform
4547
// stream (i.e. cudaStream_t for CUDA backend), this API returns a pointer to
4648
// StreamExecutor stream which is unsafe to use across dynamic library boundary.
4749
typedef void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx);
4850

51+
// Returns the device ordinal of the device associated with the execution
52+
// context.
53+
typedef int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get(
54+
XLA_FFI_ExecutionContext* ctx);
55+
56+
// Returns a pointer to device memory allocator (`se::DeviceMemoryAllocator`
57+
// pointer) which allows to allocate memory inside a custom call from the same
58+
// allocator as XLA (i.e. it allows to construct scratch memory allocator).
59+
typedef void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(
60+
XLA_FFI_ExecutionContext* ctx);
61+
4962
// Returns a pointer to `xla::HloComputation` if FFI handler has a called
5063
// computation attached to it.
5164
typedef void* XLA_FFI_INTERNAL_CalledComputation_Get(
@@ -60,6 +73,9 @@ typedef void* XLA_FFI_INTERNAL_CalledComputation_Get(
6073
struct XLA_FFI_InternalApi {
6174
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_Error_Forward);
6275
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_Stream_Get);
76+
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_DeviceOrdinal_Get);
77+
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(
78+
XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get);
6379
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_CalledComputation_Get);
6480
};
6581

xla/ffi/ffi.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ limitations under the License.
3737
#include "xla/runtime/memref_view.h"
3838
#include "xla/status.h"
3939
#include "xla/stream_executor/device_memory.h"
40+
#include "xla/stream_executor/device_memory_allocator.h"
41+
#include "xla/stream_executor/scratch_allocator.h"
4042
#include "xla/stream_executor/stream.h"
4143
#include "xla/types.h" // IWYU pragma: keep
4244
#include "xla/xla_data.pb.h"
@@ -197,6 +199,24 @@ struct CtxDecoding<se::Stream> {
197199
}
198200
};
199201

202+
template <size_t n>
203+
struct CtxDecoding<se::OwningScratchAllocator<n>> {
204+
using Type = se::OwningScratchAllocator<n>;
205+
206+
static std::optional<Type> Decode(const XLA_FFI_Api* api,
207+
XLA_FFI_ExecutionContext* ctx,
208+
DiagnosticEngine&) {
209+
int32_t device_ordinal =
210+
api->internal_api->XLA_FFI_INTERNAL_DeviceOrdinal_Get(ctx);
211+
void* device_allocator =
212+
api->internal_api->XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(ctx);
213+
214+
return se::OwningScratchAllocator<n>(
215+
device_ordinal,
216+
reinterpret_cast<se::DeviceMemoryAllocator*>(device_allocator));
217+
}
218+
};
219+
200220
template <>
201221
struct CtxDecoding<CalledComputation> {
202222
using Type = const HloComputation*;

xla/ffi/ffi_api.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "xla/ffi/ffi_api.h"
1717

1818
#include <cstddef>
19+
#include <cstdint>
1920
#include <string>
2021
#include <string_view>
2122
#include <utility>
@@ -265,6 +266,16 @@ static void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx) {
265266
return ctx->run_options->stream();
266267
}
267268

269+
static int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get(
270+
XLA_FFI_ExecutionContext* ctx) {
271+
return ctx->run_options->device_ordinal();
272+
}
273+
274+
static void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(
275+
XLA_FFI_ExecutionContext* ctx) {
276+
return ctx->run_options->allocator();
277+
}
278+
268279
static void* XLA_FFI_INTERNAL_CalledComputation_Get(
269280
XLA_FFI_ExecutionContext* ctx) {
270281
return const_cast<HloComputation*>(ctx->called_computation);
@@ -277,6 +288,8 @@ static void* XLA_FFI_INTERNAL_CalledComputation_Get(
277288
static XLA_FFI_InternalApi internal_api = {
278289
XLA_FFI_INTERNAL_Error_Forward,
279290
XLA_FFI_INTERNAL_Stream_Get,
291+
XLA_FFI_INTERNAL_DeviceOrdinal_Get,
292+
XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get,
280293
XLA_FFI_INTERNAL_CalledComputation_Get,
281294
};
282295

xla/service/gpu/custom_call_test.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,6 @@ limitations under the License.
2121
#include <string>
2222
#include <vector>
2323

24-
#include <gmock/gmock.h>
25-
#include <gtest/gtest.h>
26-
#include "absl/algorithm/container.h"
27-
#include "absl/strings/str_format.h"
28-
#include "xla/shape.h"
29-
#include "tsl/platform/statusor.h"
30-
3124
#if GOOGLE_CUDA
3225
#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep
3326
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
@@ -38,8 +31,12 @@ limitations under the License.
3831
#define PLATFORM "ROCM"
3932
#endif
4033

34+
#include <gmock/gmock.h>
35+
#include <gtest/gtest.h>
36+
#include "absl/algorithm/container.h"
4137
#include "absl/status/status.h"
4238
#include "absl/strings/str_cat.h"
39+
#include "absl/strings/str_format.h"
4340
#include "xla/client/lib/constants.h"
4441
#include "xla/client/xla_builder.h"
4542
#include "xla/ffi/ffi.h"
@@ -49,12 +46,15 @@ limitations under the License.
4946
#include "xla/hlo/ir/hlo_instructions.h"
5047
#include "xla/service/custom_call_status.h"
5148
#include "xla/service/custom_call_target_registry.h"
49+
#include "xla/shape.h"
5250
#include "xla/shape_util.h"
5351
#include "xla/stream_executor/gpu/gpu_types.h"
52+
#include "xla/stream_executor/scratch_allocator.h"
5453
#include "xla/stream_executor/stream.h"
5554
#include "xla/test_helpers.h"
5655
#include "xla/tests/client_library_test_base.h"
5756
#include "tsl/lib/core/status_test_util.h"
57+
#include "tsl/platform/statusor.h"
5858

5959
#if GOOGLE_CUDA
6060
#define gpuSuccess cudaSuccess
@@ -619,7 +619,8 @@ TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) {
619619
//===----------------------------------------------------------------------===//
620620

621621
static absl::Status MemcpyWithCalledComputation(
622-
se::Stream* stream, ffi::BufferBase src, ffi::BufferBase dst,
622+
se::Stream* stream, se::OwningScratchAllocator<> scratch_allocator,
623+
ffi::BufferBase src, ffi::BufferBase dst,
623624
const HloComputation* called_computation) {
624625
if (called_computation == nullptr)
625626
return absl::InternalError("Called computation is not defined");
@@ -637,8 +638,9 @@ XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation,
637638
MemcpyWithCalledComputation,
638639
ffi::Ffi::Bind()
639640
.Ctx<se::Stream>()
640-
.Arg<ffi::BufferBase>() // src
641-
.Arg<ffi::BufferBase>() // dst
641+
.Ctx<se::OwningScratchAllocator<>>() // scratch
642+
.Arg<ffi::BufferBase>() // src
643+
.Arg<ffi::BufferBase>() // dst
642644
.Ctx<ffi::CalledComputation>());
643645

644646
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),

xla/stream_executor/scratch_allocator.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class OwningScratchAllocator : public ScratchAllocator {
6565
OwningScratchAllocator(int device_ordinal, DeviceMemoryAllocator* allocator)
6666
: device_ordinal_(device_ordinal), allocator_(allocator) {}
6767

68+
OwningScratchAllocator(OwningScratchAllocator&&) = default;
69+
OwningScratchAllocator& operator=(OwningScratchAllocator&&) = default;
70+
6871
int64_t GetMemoryLimitInBytes() override { return -1; }
6972

7073
absl::StatusOr<DeviceMemory<uint8_t>> AllocateBytes(
@@ -80,9 +83,6 @@ class OwningScratchAllocator : public ScratchAllocator {
8083
int device_ordinal_;
8184
DeviceMemoryAllocator* allocator_;
8285
absl::InlinedVector<OwningDeviceMemory, N> buffers_;
83-
84-
OwningScratchAllocator(const OwningScratchAllocator&) = delete;
85-
void operator=(const OwningScratchAllocator&) = delete;
8686
};
8787

8888
} // namespace stream_executor

0 commit comments

Comments
 (0)