[BUG] silent correctness bug(?) with TensorSSA + broadcasting #2518

@Chillee

Description


Which component has the problem?

CuTe DSL

Bug Report

Describe the bug
I'll preface this bug report with a note that it's totally possible I'm missing something here and have a race condition of some kind.

But basically, I'm doing three things:

1. I'm allocating smem:

    sA_layout = cute.make_layout((blk, blk), stride=(blk, 1))
    sB_layout = cute.make_layout((blk, blk), stride=(1, blk))

    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(gA.element_type, sA_layout)
    sB = smem.allocate_tensor(gB.element_type, sB_layout)

2. I'm assigning each thread to compute one of the elements in the output tile:

    local_a = sA[(tidx, None)]
    local_b = sB[(None, tidy)]
    local_a = cute.local_tile(sA, (1, blk), coord=(tidx, 0))
    local_b = cute.local_tile(sB, (blk, 1), coord=(0, tidy))

3. I'm computing a gemm on local_a and local_b with TensorSSA and "broadcasting" (a rough NumPy analogue of the broadcast step follows below):

    def vector_mm(a, b):
        broadcast_a = cute.make_tensor(a.iterator, cute.append(a.layout, cute.make_layout(b.shape[1], stride=0)))
        broadcast_b = cute.make_tensor(b.iterator, cute.prepend(b.layout, cute.make_layout(a.shape[0], stride=0)))
        return (broadcast_a.load().to(cute.Float32) * broadcast_b.load().to(cute.Float32)).reduce(
            cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1, None)
        )
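
To spell out what I expect the broadcast to do (this is just my mental model, with made-up sizes rather than the real blk): appending/prepending a stride-0 mode plays the same role as NumPy broadcasting, so the element-wise product reduced over the middle (K) mode should be an ordinary matmul.

    import numpy as np

    M, K, N = 4, 8, 4                          # illustrative sizes, not the real blk
    a = np.ones((M, K), dtype=np.float32)
    b = np.ones((K, N), dtype=np.float32)

    # cute.append(a.layout, make_layout(N, stride=0))  ~ add a trailing broadcast mode
    broadcast_a = np.broadcast_to(a[:, :, None], (M, K, N))
    # cute.prepend(b.layout, make_layout(M, stride=0)) ~ add a leading broadcast mode
    broadcast_b = np.broadcast_to(b[None, :, :], (M, K, N))

    print(broadcast_a.strides[2], broadcast_b.strides[0])  # both 0: nothing is copied
    # reducing (broadcast_a * broadcast_b) over the K mode is an ordinary matmul,
    # so for all-ones inputs every output element should equal K, never 0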

This gives me the wrong result (all zeros for an input tensor with all ones).

However, if I allocate smem instead as

    sA_layout = cute.make_layout((blk, blk), stride=(1, blk))  
    sB_layout = cute.make_layout((blk, blk), stride=(blk, 1))

it passes.
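
For concreteness, the only thing changing between the failing and passing versions is the stride order, i.e. whether each smem tile is row-major or column-major. A tiny pure-Python sketch of the coordinate-to-offset math (blk here is a made-up value; this is just the usual inner-product-with-strides mapping, not code from the repro):

    blk = 4                                 # made-up value for illustration

    def offset(coord, stride):
        # rank-2 layout: linear offset = sum(coord[i] * stride[i])
        return sum(c * d for c, d in zip(coord, stride))

    # failing sA layout (blk, blk):(blk, 1) -> row-major (K fastest varying)
    # passing sA layout (blk, blk):(1, blk) -> column-major (M fastest varying)
    print(offset((1, 2), (blk, 1)))   # 6: element (1, 2) in the row-major tile
    print(offset((1, 2), (1, blk)))   # 9: same coordinate, column-major tile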

With some additional intermediate logging, broadcast_a and broadcast_b are all 1s (as expected), but somehow after they're loaded and multiplied, the result is all zeros 🤔

[screenshot of the logged intermediate values]

Of note, if I replace the TensorSSA implementation with a naive for-loop version, it also passes:

    @cute.jit
    def loop_mm(a: cute.Tensor, b: cute.Tensor, tile_K_idx):
        # naive reference: accumulate the (M, K) x (K, N) product into a register fragment
        tmp_val = cute.make_fragment((a.shape[0], b.shape[1]), dtype=cute.Float32)
        for m in cutlass.range_constexpr(a.shape[0]):
            for n in cutlass.range_constexpr(b.shape[1]):
                tmp_val[(m, n)] = 0.0
                for k in cutlass.range_constexpr(a.shape[1]):
                    tmp_val[(m, n)] += a[(m, k)] * b[(k, n)]
        return tmp_val.load()
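
As a host-side sanity check that the loop and broadcast formulations are mathematically equivalent (so the discrepancy should come from how the TensorSSA version is lowered, not from the math), here is the same comparison in plain NumPy with illustrative sizes:

    import numpy as np

    M, K, N = 4, 8, 4                                   # illustrative sizes
    rng = np.random.default_rng(0)
    a = rng.standard_normal((M, K)).astype(np.float32)
    b = rng.standard_normal((K, N)).astype(np.float32)

    # loop_mm-style triple loop
    loop_out = np.zeros((M, N), dtype=np.float32)
    for m in range(M):
        for n in range(N):
            for k in range(K):
                loop_out[m, n] += a[m, k] * b[k, n]

    # vector_mm-style broadcast + reduce over the K mode
    bcast_out = (a[:, :, None] * b[None, :, :]).sum(axis=1)

    assert np.allclose(loop_out, bcast_out, atol=1e-5)  # same matmul either way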

Steps/Code to reproduce bug
Code can be found here: https://pastebin.com/Z7b7vnDq

Expected behavior
The TensorSSA/broadcast implementation should produce the same (correct) gemm result as the naive loop implementation, regardless of which stride order is used for the smem layouts.

Environment details (please complete the following information):
nvidia-cutlass-dsl 4.1.0.dev0
