**Description**

**Which component has the problem?**
CuTe DSL

**Bug Report**

**Describe the bug**
I'll preface this bug report by noting that it's entirely possible I'm missing something here and have a race condition of some kind. But basically, I'm doing three things:
- I'm allocating smem:

```python
sA_layout = cute.make_layout((blk, blk), stride=(blk, 1))
sB_layout = cute.make_layout((blk, blk), stride=(1, blk))
smem = cutlass.utils.SmemAllocator()
sA = smem.allocate_tensor(gA.element_type, sA_layout)
sB = smem.allocate_tensor(gB.element_type, sB_layout)
```
- I'm assigning each thread to compute one of the elements in the output tile:

```python
local_a = sA[(tidx, None)]
local_b = sB[(None, tidy)]
# alternatively, using local_tile:
local_a = cute.local_tile(sA, (1, blk), coord=(tidx, 0))
local_b = cute.local_tile(sB, (blk, 1), coord=(0, tidy))
```
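For intuition, the per-thread slicing above corresponds to taking a row of `sA` and a column of `sB`. A NumPy analogue (illustration only, with hypothetical `blk = 4` and thread indices; not CuTe semantics):

```python
import numpy as np

blk = 4
sA = np.ones((blk, blk), dtype=np.float32)
sB = np.ones((blk, blk), dtype=np.float32)

tidx, tidy = 1, 2
local_a = sA[tidx:tidx + 1, :]   # one row, shape (1, blk)
local_b = sB[:, tidy:tidy + 1]   # one column, shape (blk, 1)
print(local_a.shape, local_b.shape)  # (1, 4) (4, 1)
```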
- I'm computing a GEMM on `local_a` and `local_b` with TensorSSA and "broadcasting":

```python
def vector_mm(a, b):
    broadcast_a = cute.make_tensor(a.iterator, cute.append(a.layout, cute.make_layout(b.shape[1], stride=0)))
    broadcast_b = cute.make_tensor(b.iterator, cute.prepend(b.layout, cute.make_layout(a.shape[0], stride=0)))
    return (broadcast_a.load().to(cute.Float32) * broadcast_b.load().to(cute.Float32)).reduce(
        cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1, None)
    )
```
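For reference, the stride-0 broadcast plus reduction is meant to behave like the following NumPy analogue (NumPy used purely for illustration; shapes `(M, K)` and `(K, N)` are assumed):

```python
import numpy as np

def vector_mm_np(a, b):
    # append a stride-0 N mode to a, prepend a stride-0 M mode to b,
    # multiply elementwise over (M, K, N), then reduce over the K mode.
    broadcast_a = a[:, :, None]   # (M, K, 1)
    broadcast_b = b[None, :, :]   # (1, K, N)
    return (broadcast_a * broadcast_b).sum(axis=1)  # (M, N)

a = np.ones((1, 4), dtype=np.float32)
b = np.ones((4, 1), dtype=np.float32)
print(vector_mm_np(a, b))  # [[4.]]
```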
This gives me the wrong result (all zeros for an input tensor with all ones).
However, if I instead allocate smem as

```python
sA_layout = cute.make_layout((blk, blk), stride=(1, blk))
sB_layout = cute.make_layout((blk, blk), stride=(blk, 1))
```

it passes.
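Since both layouts describe the same (blk, blk) logical tensor, swapping the strides should only change which linear smem offset each element lands at, not the computed values. A minimal sketch of the two address mappings (plain Python, hypothetical `blk = 4`):

```python
# Map a logical (i, j) coordinate to a linear offset under a given stride.
def offset(i, j, stride):
    return i * stride[0] + j * stride[1]

blk = 4
row_major = (blk, 1)   # the failing sA_layout: stride=(blk, 1)
col_major = (1, blk)   # the passing sA_layout: stride=(1, blk)

# Same logical element, different linear offsets:
print(offset(1, 2, row_major))  # 6
print(offset(1, 2, col_major))  # 9
```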
Adding some intermediate logging, `broadcast_a` and `broadcast_b` are all 1s (as expected), but somehow after they're loaded and multiplied, the result is all zeros 🤔
Of note, if I replace the TensorSSA implementation with a naive for-loop version, it also passes successfully.

```python
@cute.jit
def loop_mm(a: cute.Tensor, b: cute.Tensor, tile_K_idx):
    tmp_val = cute.make_fragment((a.shape[0], b.shape[1]), dtype=cute.Float32)
    for m in cutlass.range_constexpr(a.shape[0]):
        for n in cutlass.range_constexpr(b.shape[1]):
            tmp_val[(m, n)] = 0.0
            for k in cutlass.range_constexpr(a.shape[1]):
                tmp_val[(m, n)] += a[(m, k)] * b[(k, n)]
    return tmp_val.load()
```
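As a sanity check on the loop version's semantics, a NumPy analogue (small hypothetical shapes) confirms the triple loop is just a plain matmul:

```python
import numpy as np

def loop_mm_np(a, b):
    # Triple-nested accumulation, mirroring the loop_mm kernel above.
    M, K = a.shape
    _, N = b.shape
    out = np.zeros((M, N), dtype=np.float32)
    for m in range(M):
        for n in range(N):
            for k in range(K):
                out[m, n] += a[m, k] * b[k, n]
    return out

a = np.ones((1, 4), dtype=np.float32)
b = np.ones((4, 1), dtype=np.float32)
print(loop_mm_np(a, b))  # [[4.]]
```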
**Steps/Code to reproduce bug**

Code can be found here: https://pastebin.com/Z7b7vnDq

**Expected behavior**

**Environment details (please complete the following information):**

nvidia-cutlass-dsl 4.1.0.dev0