Skip to content

Commit a93aaab

Browse files
ezhulenev and copybara-github
authored and committed
[xla:ffi] Add access to se::Stream to internal FFI API and remove ServiceExecutableOptions
PiperOrigin-RevId: 615956789
1 parent 5174ed6 commit a93aaab

File tree

10 files changed: +78 -72 lines changed

xla/ffi/BUILD

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ cc_library(
4343
"//xla/ffi/api:c_api_internal",
4444
"//xla/hlo/ir:hlo",
4545
"//xla/runtime:memref_view",
46-
"//xla/service:executable",
46+
"//xla/stream_executor",
4747
"//xla/stream_executor:device_memory",
4848
"@com_google_absl//absl/types:span",
4949
],
@@ -78,6 +78,7 @@ xla_cc_test(
7878
":ffi_api",
7979
"//xla:xla_data_proto_cc",
8080
"//xla/service:executable",
81+
"//xla/stream_executor",
8182
"//xla/stream_executor:device_memory",
8283
"@com_google_absl//absl/status",
8384
"@tsl//tsl/lib/core:status_test_util",

xla/ffi/api/c_api_internal.h

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -38,15 +38,18 @@ extern "C" {
3838
// Forwards `absl::Status` object pointed to by `status` to XLA FFI error
3939
// (status left in moved-from state). Pointer ownership stays with the
4040
// caller.
41-
typedef XLA_FFI_Error* XLA_FFI_Error_Forward(void* status);
41+
typedef XLA_FFI_Error* XLA_FFI_INTERNAL_Error_Forward(void* status);
4242

43-
// Returns a pointer to `xla::ServiceExecutableRunOptions`.
44-
typedef void* XLA_FFI_ServiceExecutableRunOptions_Get(
45-
XLA_FFI_ExecutionContext* ctx);
43+
// Returns a pointer to main compute stream (pointer to `se::Stream`). In
44+
// contrast to public C API which returns a pointer to underlying platform
45+
// stream (i.e. cudaStream_t for CUDA backend), this API returns a pointer to
46+
// StreamExecutor stream which is unsafe to use across dynamic library boundary.
47+
typedef void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx);
4648

4749
// Returns a pointer to `xla::HloComputation` if FFI handler has a called
4850
// computation attached to it.
49-
typedef void* XLA_FFI_CalledComputation_Get(XLA_FFI_ExecutionContext* ctx);
51+
typedef void* XLA_FFI_INTERNAL_CalledComputation_Get(
52+
XLA_FFI_ExecutionContext* ctx);
5053

5154
//===----------------------------------------------------------------------===//
5255
// API access
@@ -55,9 +58,9 @@ typedef void* XLA_FFI_CalledComputation_Get(XLA_FFI_ExecutionContext* ctx);
5558
#define _XLA_FFI_INTERNAL_API_STRUCT_FIELD(fn_type) fn_type* fn_type
5659

5760
struct XLA_FFI_InternalApi {
58-
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_Error_Forward);
59-
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_ServiceExecutableRunOptions_Get);
60-
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_CalledComputation_Get);
61+
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_Error_Forward);
62+
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_Stream_Get);
63+
_XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_CalledComputation_Get);
6164
};
6265

6366
#undef _XLA_FFI_INTERNAL_API_STRUCT_FIELD

xla/ffi/ffi.h

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,9 @@ limitations under the License.
3535
#include "xla/hlo/ir/hlo_computation.h"
3636
#include "xla/primitive_util.h"
3737
#include "xla/runtime/memref_view.h"
38-
#include "xla/service/service_executable_run_options.h"
3938
#include "xla/status.h"
4039
#include "xla/stream_executor/device_memory.h"
40+
#include "xla/stream_executor/stream.h"
4141
#include "xla/types.h" // IWYU pragma: keep
4242
#include "xla/xla_data.pb.h"
4343

@@ -185,16 +185,14 @@ struct AttrDecoding<Pointer<T>> {
185185
// Context decoding
186186
//===----------------------------------------------------------------------===//
187187

188-
// TODO(ezhulenev): We should remove `ServiceExecutableRunOptions` context and
189-
// pass only se::Stream to FFI handlers.
190188
template <>
191-
struct CtxDecoding<ServiceExecutableRunOptions> {
192-
using Type = const ServiceExecutableRunOptions*;
189+
struct CtxDecoding<se::Stream> {
190+
using Type = se::Stream*;
193191

194192
static std::optional<Type> Decode(const XLA_FFI_Api* api,
195193
XLA_FFI_ExecutionContext* ctx,
196194
DiagnosticEngine&) {
197-
void* ptr = api->internal_api->XLA_FFI_ServiceExecutableRunOptions_Get(ctx);
195+
void* ptr = api->internal_api->XLA_FFI_INTERNAL_Stream_Get(ctx);
198196
return reinterpret_cast<Type>(ptr);
199197
}
200198
};
@@ -206,7 +204,7 @@ struct CtxDecoding<CalledComputation> {
206204
static std::optional<Type> Decode(const XLA_FFI_Api* api,
207205
XLA_FFI_ExecutionContext* ctx,
208206
DiagnosticEngine&) {
209-
void* ptr = api->internal_api->XLA_FFI_CalledComputation_Get(ctx);
207+
void* ptr = api->internal_api->XLA_FFI_INTERNAL_CalledComputation_Get(ctx);
210208
return reinterpret_cast<Type>(ptr);
211209
}
212210
};
@@ -218,7 +216,7 @@ struct CtxDecoding<CalledComputation> {
218216
template <>
219217
struct ResultEncoding<Status> {
220218
static XLA_FFI_Error* Encode(XLA_FFI_Api* api, Status status) {
221-
return api->internal_api->XLA_FFI_Error_Forward(&status);
219+
return api->internal_api->XLA_FFI_INTERNAL_Error_Forward(&status);
222220
}
223221
};
224222

xla/ffi/ffi_api.cc

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,6 @@ limitations under the License.
3131
#include "xla/hlo/ir/hlo_computation.h"
3232
#include "xla/service/service_executable_run_options.h"
3333
#include "xla/status.h"
34-
#include "xla/statusor.h"
35-
#include "tsl/platform/logging.h"
3634

3735
//===----------------------------------------------------------------------===//
3836
// XLA FFI C structs definition
@@ -259,16 +257,16 @@ static XLA_FFI_Error* XLA_FFI_Stream_Get(XLA_FFI_Stream_Get_Args* args) {
259257
// XLA FFI Internal Api Implementation
260258
//===----------------------------------------------------------------------===//
261259

262-
static XLA_FFI_Error* XLA_FFI_Error_Forward(void* status) {
260+
static XLA_FFI_Error* XLA_FFI_INTERNAL_Error_Forward(void* status) {
263261
return new XLA_FFI_Error{std::move(*reinterpret_cast<Status*>(status))};
264262
}
265263

266-
static void* XLA_FFI_ServiceExecutableRunOptions_Get(
267-
XLA_FFI_ExecutionContext* ctx) {
268-
return const_cast<ServiceExecutableRunOptions*>(ctx->run_options);
264+
static void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx) {
265+
return ctx->run_options->stream();
269266
}
270267

271-
static void* XLA_FFI_CalledComputation_Get(XLA_FFI_ExecutionContext* ctx) {
268+
static void* XLA_FFI_INTERNAL_CalledComputation_Get(
269+
XLA_FFI_ExecutionContext* ctx) {
272270
return const_cast<HloComputation*>(ctx->called_computation);
273271
}
274272

@@ -277,9 +275,9 @@ static void* XLA_FFI_CalledComputation_Get(XLA_FFI_ExecutionContext* ctx) {
277275
//===----------------------------------------------------------------------===//
278276

279277
static XLA_FFI_InternalApi internal_api = {
280-
XLA_FFI_Error_Forward,
281-
XLA_FFI_ServiceExecutableRunOptions_Get,
282-
XLA_FFI_CalledComputation_Get,
278+
XLA_FFI_INTERNAL_Error_Forward,
279+
XLA_FFI_INTERNAL_Stream_Get,
280+
XLA_FFI_INTERNAL_CalledComputation_Get,
283281
};
284282

285283
static XLA_FFI_Api api = {

xla/ffi/ffi_test.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "xla/ffi/ffi_api.h"
2626
#include "xla/service/service_executable_run_options.h"
2727
#include "xla/stream_executor/device_memory.h"
28+
#include "xla/stream_executor/stream.h"
2829
#include "xla/xla_data.pb.h"
2930
#include "tsl/lib/core/status_test_util.h"
3031
#include "tsl/platform/status_matchers.h"
@@ -466,15 +467,18 @@ TEST(FfiTest, RemainingArgs) {
466467

467468
TEST(FfiTest, RunOptionsCtx) {
468469
auto call_frame = CallFrameBuilder().Build();
469-
auto* expected = reinterpret_cast<ServiceExecutableRunOptions*>(0x01234567);
470+
auto* expected = reinterpret_cast<se::Stream*>(0x01234567);
470471

471-
auto fn = [&](const ServiceExecutableRunOptions* run_options) {
472+
ServiceExecutableRunOptions opts;
473+
opts.mutable_run_options()->set_stream(expected);
474+
475+
auto fn = [&](const se::Stream* run_options) {
472476
EXPECT_EQ(run_options, expected);
473477
return absl::OkStatus();
474478
};
475479

476-
auto handler = Ffi::Bind().Ctx<ServiceExecutableRunOptions>().To(fn);
477-
auto status = Call(*handler, call_frame, {expected});
480+
auto handler = Ffi::Bind().Ctx<se::Stream>().To(fn);
481+
auto status = Call(*handler, call_frame, {&opts});
478482

479483
TF_ASSERT_OK(status);
480484
}

xla/service/gpu/BUILD

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,7 @@ xla_cc_test(
177177
"//xla/service:custom_call_target_registry",
178178
"//xla/service:executable",
179179
"//xla/service:gpu_plugin",
180+
"//xla/stream_executor",
180181
"//xla/stream_executor/gpu:gpu_types_header",
181182
"//xla/tests:client_library_test_base",
182183
"//xla/tests:xla_internal_test_main", # fixdeps: keep
@@ -3552,6 +3553,7 @@ xla_cc_test(
35523553
"//xla/service:executable",
35533554
"//xla/service:hlo_memory_scheduler",
35543555
"//xla/service:hlo_module_config",
3556+
"//xla/stream_executor",
35553557
"//xla/stream_executor/gpu:gpu_types_header",
35563558
"//xla/tests:hlo_test_base",
35573559
"@com_google_absl//absl/algorithm:container",

xla/service/gpu/address_computation_fusion_rewriter_test.cc

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ limitations under the License.
3838
#include "xla/shape.h"
3939
#include "xla/shape_util.h"
4040
#include "xla/stream_executor/gpu/gpu_types.h"
41+
#include "xla/stream_executor/stream.h"
4142
#include "xla/tests/hlo_test_base.h"
4243
#include "tsl/platform/status.h"
4344
#include "tsl/platform/statusor.h"
@@ -869,17 +870,17 @@ TEST_F(AddressComputationFusionRewriterTest, SimpleGemmOperandsFromSameSlice) {
869870
});
870871
}
871872

872-
static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options,
873-
ffi::BufferBase src, ffi::BufferBase dst) {
874-
return run_options->stream()->MemcpyD2D(
873+
static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src,
874+
ffi::BufferBase dst) {
875+
return stream->MemcpyD2D(
875876
&dst.data, src.data,
876877
absl::c_accumulate(src.dimensions, 1.0, std::multiplies<int64_t>()) *
877878
sizeof(float));
878879
}
879880

880881
XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
881882
ffi::Ffi::Bind()
882-
.Ctx<ServiceExecutableRunOptions>()
883+
.Ctx<se::Stream>()
883884
.Arg<ffi::BufferBase>() // src
884885
.Arg<ffi::BufferBase>() // dst
885886
);

xla/service/gpu/custom_call_test.cc

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -49,10 +49,9 @@ limitations under the License.
4949
#include "xla/hlo/ir/hlo_instructions.h"
5050
#include "xla/service/custom_call_status.h"
5151
#include "xla/service/custom_call_target_registry.h"
52-
#include "xla/service/service_executable_run_options.h"
5352
#include "xla/shape_util.h"
54-
#include "xla/status.h"
5553
#include "xla/stream_executor/gpu/gpu_types.h"
54+
#include "xla/stream_executor/stream.h"
5655
#include "xla/test_helpers.h"
5756
#include "xla/tests/client_library_test_base.h"
5857
#include "tsl/lib/core/status_test_util.h"
@@ -375,17 +374,17 @@ TEST_F(CustomCallTest, RuntimeCustomCallAlwaysFail) {
375374
EXPECT_THAT(status.message(), ::testing::HasSubstr("Uh oh, wrong value: 42"));
376375
}
377376

378-
static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options,
379-
ffi::BufferBase src, ffi::BufferBase dst) {
380-
return run_options->stream()->MemcpyD2D(
377+
static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src,
378+
ffi::BufferBase dst) {
379+
return stream->MemcpyD2D(
381380
&dst.data, src.data,
382381
absl::c_accumulate(src.dimensions, 1.0, std::multiplies<int64_t>()) *
383382
sizeof(float));
384383
}
385384

386385
XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
387386
ffi::Ffi::Bind()
388-
.Ctx<ServiceExecutableRunOptions>()
387+
.Ctx<se::Stream>()
389388
.Arg<ffi::BufferBase>() // src
390389
.Arg<ffi::BufferBase>() // dst
391390
);
@@ -620,8 +619,8 @@ TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) {
620619
//===----------------------------------------------------------------------===//
621620

622621
static absl::Status MemcpyWithCalledComputation(
623-
const ServiceExecutableRunOptions* run_options, ffi::BufferBase src,
624-
ffi::BufferBase dst, const HloComputation* called_computation) {
622+
se::Stream* stream, ffi::BufferBase src, ffi::BufferBase dst,
623+
const HloComputation* called_computation) {
625624
if (called_computation == nullptr)
626625
return absl::InternalError("Called computation is not defined");
627626

@@ -631,13 +630,13 @@ static absl::Status MemcpyWithCalledComputation(
631630
if (!DynCast<HloParameterInstruction>(called_computation->root_instruction()))
632631
return absl::InternalError("ROOT must be a paremeter");
633632

634-
return Memcpy(run_options, src, dst);
633+
return Memcpy(stream, src, dst);
635634
}
636635

637636
XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation,
638637
MemcpyWithCalledComputation,
639638
ffi::Ffi::Bind()
640-
.Ctx<ServiceExecutableRunOptions>()
639+
.Ctx<se::Stream>()
641640
.Arg<ffi::BufferBase>() // src
642641
.Arg<ffi::BufferBase>() // dst
643642
.Ctx<ffi::CalledComputation>());

xla/service/gpu/fusions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,7 @@ xla_test(
107107
"//xla/service:custom_call_target_registry",
108108
"//xla/service:executable",
109109
"//xla/service:hlo_module_config",
110+
"//xla/stream_executor",
110111
"//xla/stream_executor:device_description",
111112
"//xla/stream_executor/gpu:gpu_types_header",
112113
"//xla/tests:hlo_test_base",

xla/service/gpu/fusions/address_computation_fusion_test.cc

Lines changed: 26 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ limitations under the License.
3232
#include "xla/shape.h"
3333
#include "xla/shape_util.h"
3434
#include "xla/stream_executor/gpu/gpu_types.h"
35+
#include "xla/stream_executor/stream.h"
3536
#include "xla/tests/hlo_test_base.h"
3637
#include "tsl/platform/statusor.h"
3738
#include "tsl/platform/test.h"
@@ -844,17 +845,17 @@ TEST_F(AddressComputationFusionTest, SlicedOperandAliasingOutput) {
844845
/*run_hlo_passes=*/false));
845846
}
846847

847-
static absl::Status Memcpy(const ServiceExecutableRunOptions* run_options,
848-
ffi::BufferBase src, ffi::BufferBase dst) {
849-
return run_options->stream()->MemcpyD2D(
848+
static absl::Status Memcpy(se::Stream* stream, ffi::BufferBase src,
849+
ffi::BufferBase dst) {
850+
return stream->MemcpyD2D(
850851
&dst.data, src.data,
851852
absl::c_accumulate(src.dimensions, 1.0, std::multiplies<int64_t>()) *
852853
sizeof(float));
853854
}
854855

855856
XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy,
856857
ffi::Ffi::Bind()
857-
.Ctx<ServiceExecutableRunOptions>()
858+
.Ctx<se::Stream>()
858859
.Arg<ffi::BufferBase>() // src
859860
.Arg<ffi::BufferBase>() // dst
860861
);
@@ -894,12 +895,12 @@ TEST_F(AddressComputationFusionTest, CustomCallSimple) {
894895
/*run_hlo_passes=*/false));
895896
}
896897

897-
static absl::Status SubBuffers(const ServiceExecutableRunOptions* run_options,
898-
ffi::BufferBase src0, ffi::BufferBase src1,
899-
ffi::BufferBase src2, ffi::BufferBase src3,
900-
ffi::BufferBase src4, ffi::BufferBase dst0,
901-
ffi::BufferBase dst1, ffi::BufferBase dst2,
902-
ffi::BufferBase dst3, ffi::BufferBase dst4) {
898+
static absl::Status SubBuffers(se::Stream* stream, ffi::BufferBase src0,
899+
ffi::BufferBase src1, ffi::BufferBase src2,
900+
ffi::BufferBase src3, ffi::BufferBase src4,
901+
ffi::BufferBase dst0, ffi::BufferBase dst1,
902+
ffi::BufferBase dst2, ffi::BufferBase dst3,
903+
ffi::BufferBase dst4) {
903904
// src0: param 0 at tuple index {0}, shape f32[128]
904905
// src1: param 0 at tuple index {1}, shape f32[256]
905906
// src2: param 1 at tuple index {0}, shape f32[1024]
@@ -912,22 +913,22 @@ static absl::Status SubBuffers(const ServiceExecutableRunOptions* run_options,
912913
// dst3: result at tuple index {2}, shape f32[1024]
913914
// dst4: result at tuple index {3}, shape f32[4,8]
914915

915-
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst0.data, src3.data,
916-
8 * sizeof(float)));
917-
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst1.data, src0.data,
918-
128 * sizeof(float)));
919-
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst2.data, src1.data,
920-
256 * sizeof(float)));
921-
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst3.data, src2.data,
922-
1024 * sizeof(float)));
923-
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(&dst4.data, src4.data,
924-
4 * 8 * sizeof(float)));
916+
TF_RETURN_IF_ERROR(
917+
stream->MemcpyD2D(&dst0.data, src3.data, 8 * sizeof(float)));
918+
TF_RETURN_IF_ERROR(
919+
stream->MemcpyD2D(&dst1.data, src0.data, 128 * sizeof(float)));
920+
TF_RETURN_IF_ERROR(
921+
stream->MemcpyD2D(&dst2.data, src1.data, 256 * sizeof(float)));
922+
TF_RETURN_IF_ERROR(
923+
stream->MemcpyD2D(&dst3.data, src2.data, 1024 * sizeof(float)));
924+
TF_RETURN_IF_ERROR(
925+
stream->MemcpyD2D(&dst4.data, src4.data, 4 * 8 * sizeof(float)));
925926
return absl::OkStatus();
926927
}
927928

928929
XLA_FFI_DEFINE_HANDLER(kSubBuffers, SubBuffers,
929930
ffi::Ffi::Bind()
930-
.Ctx<ServiceExecutableRunOptions>()
931+
.Ctx<se::Stream>()
931932
.Arg<ffi::BufferBase>() // src0
932933
.Arg<ffi::BufferBase>() // src1
933934
.Arg<ffi::BufferBase>() // src2
@@ -995,15 +996,13 @@ TEST_F(AddressComputationFusionTest, CustomCallWithTuple) {
995996
/*run_hlo_passes=*/false));
996997
}
997998

998-
static absl::Status NoOp(const ServiceExecutableRunOptions* run_options,
999-
ffi::BufferBase operand) {
999+
static absl::Status NoOp(se::Stream* stream, ffi::BufferBase operand) {
10001000
return absl::OkStatus();
10011001
}
10021002

1003-
XLA_FFI_DEFINE_HANDLER(kNoOp, NoOp,
1004-
ffi::Ffi::Bind()
1005-
.Ctx<ServiceExecutableRunOptions>()
1006-
.Arg<ffi::BufferBase>() // operand
1003+
XLA_FFI_DEFINE_HANDLER(
1004+
kNoOp, NoOp,
1005+
ffi::Ffi::Bind().Ctx<se::Stream>().Arg<ffi::BufferBase>() // operand
10071006
);
10081007
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$noop", PLATFORM,
10091008
kNoOp);

0 commit comments

Comments
 (0)