Skip to content

Commit 2848cda

Browse files
author
jax authors
committed
Merge pull request #20341 from ROCm:rocm_add_hipStreamWaitEvent
PiperOrigin-RevId: 617893634
2 parents 291a5cd + 8575055 commit 2848cda

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

jaxlib/gpu/triton_kernels.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "jaxlib/gpu/vendor.h"
3232
#include "xla/service/custom_call_status.h"
3333
#include "xla/stream_executor/gpu/asm_compiler.h"
34+
#include "tsl/platform/env.h"
3435

3536
#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))
3637

jaxlib/gpu/vendor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
496496
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
497497
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
498498
#define gpuStreamSynchronize hipStreamSynchronize
499+
#define gpuStreamWaitEvent hipStreamWaitEvent
499500
#define gpuSuccess hipSuccess
500501

501502
#define gpuCtxGetDevice hipCtxGetDevice

0 commit comments

Comments
 (0)