From aa215e23acf70cb5abf178571a058d737729d6e5 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Fri, 15 Mar 2024 14:09:25 +0000
Subject: [PATCH] Use the strides of the memref descriptor to construct the
 TMA descriptor

The previous version of the code assumed that the tensor was contiguous,
which is not required and can cause surprising miscompiles.
---
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   | 32 +++++++++++--------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b9a3429e37b88..c76f8d77dff55 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -427,13 +427,21 @@ namespace {
 
 template <int rank>
 void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
-                               uint64_t *globalDim) {
+                               uint64_t *globalDim, uint64_t *globalStrides,
+                               const CUtensorMapDataType tensorDataType) {
   auto descriptor =
       reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
   *addr = descriptor->data;
   for (int i = 0; i < rank; ++i) {
     globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
   }
+  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
+                                               4, 8, 2, 4, 4, 4};
+  // TODO(grypp): Check that the minormost stride is equal to the element size.
+  for (int i = 0; i < rank - 1; ++i) {
+    globalStrides[i] = static_cast<uint64_t>(
+        descriptor->strides[rank - i - 2] * elementSizeInBytes[tensorDataType]);
+  }
 }
 } // namespace
 
@@ -457,19 +465,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
   char *globalAddress = nullptr;
   switch (tensorRank) {
   case 1:
-    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 2:
-    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 3:
-    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 4:
-    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 5:
-    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   default:
     fprintf(
@@ -478,17 +491,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     return NULL;
   }
 
-  static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
-                                           4, 8, 2, 4, 4, 4};
   for (int64_t r = 0; r < tensorRank; ++r) {
-    elementStrides[r] = uint32_t(1);
     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
   }
 
-  globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
-  for (int r = 1; r < tensorRank - 1; r++)
-    globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
   ScopedContext scopedContext;
   mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                            globalAddress, globalDim, globalStrides, boxDim,
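
Illustration (not part of the patch): the standalone sketch below mirrors the stride computation the patch introduces, using a simplified stand-in for MLIR's StridedMemRefType. The SimpleMemRefDescriptor struct and the concrete numbers (a 64x128 fp16 view with a row stride of 256 elements) are assumptions made purely for this example; it only shows why deriving strides from the sizes, as the old code did, goes wrong for non-contiguous views.

// Standalone sketch, assuming a simplified rank-2 descriptor layout; the
// struct below is a hypothetical stand-in for MLIR's StridedMemRefType.
#include <cstdint>
#include <cstdio>

template <int rank>
struct SimpleMemRefDescriptor {
  char *data;
  int64_t offset;
  int64_t sizes[rank];   // extents in elements, slowest-varying dim first
  int64_t strides[rank]; // strides in elements, slowest-varying dim first
};

int main() {
  // A 64x128 fp16 view carved out of a 64x256 buffer: the sizes say 128
  // elements per row, but consecutive rows are 256 elements apart, so the
  // view is not contiguous.
  SimpleMemRefDescriptor<2> desc = {nullptr, 0, {64, 128}, {256, 1}};
  constexpr uint64_t elementSizeInBytes = 2; // fp16

  // Old computation: assume contiguity and rebuild the row stride from the
  // innermost size: 128 * 2 = 256 bytes. Wrong for this view.
  uint64_t assumedContiguous = uint64_t(desc.sizes[1]) * elementSizeInBytes;

  // New computation (what the patch does): convert the stride stored in the
  // descriptor from elements to bytes: 256 * 2 = 512 bytes.
  uint64_t fromDescriptor = uint64_t(desc.strides[0]) * elementSizeInBytes;

  printf("row stride assuming contiguity: %llu bytes\n",
         (unsigned long long)assumedContiguous);
  printf("row stride from the descriptor: %llu bytes\n",
         (unsigned long long)fromDescriptor);
}

Feeding the contiguity-based value into the TMA descriptor would make the hardware step through rows at the wrong byte offsets, which is the kind of surprising miscompile the commit message refers to.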