From b35a0704de3324ef028f0f0280450420e624fa48 Mon Sep 17 00:00:00 2001
From: SimonDanisch
Date: Wed, 6 Dec 2017 12:32:42 +0100
Subject: [PATCH] use cutlass for a pure julia fallback

---
 src/blasfallback.jl | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)
 create mode 100644 src/blasfallback.jl

diff --git a/src/blasfallback.jl b/src/blasfallback.jl
new file mode 100644
index 00000000..464f5084
--- /dev/null
+++ b/src/blasfallback.jl
@@ -0,0 +1,52 @@
+"""
+    block_matrix_product(state, A, B, C)
+
+Sketch of a tiled GEMM kernel: nested loops over K-dimension blocks, warp tiles
+and per-thread tiles. The innermost loops accumulate the outer product of the
+fragments `frag_a` and `frag_b` into `accumulator`.
+
+Loading `A`/`B` into the fragments and writing `accumulator` back are not
+implemented here, so `A`, `B` and `C` are currently unused.
+
+NOTE(review): `T`, `K_dim`, `BlockItemsK` and `WarpItemsK` are not defined in this
+function, and the static parameters `ThreadItemsY`/`ThreadItemsX` do not occur in
+the argument list. Presumably they are meant to come from the caller or from
+globals - confirm before using this kernel.
+"""
+function block_matrix_product(state, A, B, C) where {ThreadItemsY, ThreadItemsX}
+
+    # Fragments used to store data fetched from SMEM
+    frag_a = @LocalMemory(state, T, ThreadItemsY)
+    frag_b = @LocalMemory(state, T, ThreadItemsX)
+
+    # Accumulator storage, indexed as [thread_x, thread_y]
+    accumulator = @LocalMemory(state, T, ThreadItemsX, ThreadItemsY)
+
+    # GEMM main loop - iterates over the entire K dimension in steps of
+    # BlockItemsK - not unrolled
+    for kblock in Int32(1):BlockItemsK:K_dim
+        # Load A and B tiles from global memory and store to SMEM
+        #
+        # (not shown for brevity - see the CUTLASS source for more detail)
+
+        synchronize_threads(state)
+        # Warp tile structure - iterates over the Thread Block tile
+        # (unroll candidate; `#pragma unroll` is only a comment in Julia)
+        for warp_k in Int32(1):WarpItemsK:BlockItemsK
+            # Fetch frag_a and frag_b from SMEM corresponding to k-index
+            #
+            # (not shown for brevity - see CUTLASS source for more detail)
+
+            # Thread tile structure - accumulate an outer product
+            # (both loops are unroll candidates)
+            for thread_x in Int32(1):ThreadItemsX
+                for thread_y in Int32(1):ThreadItemsY
+                    # frag_a holds ThreadItemsY entries and frag_b holds
+                    # ThreadItemsX entries, so index them by thread_y / thread_x.
+                    accumulator[thread_x, thread_y] += frag_a[thread_y] * frag_b[thread_x]
+                end
+            end
+        end
+        synchronize_threads(state)
+    end
+end