Hi,
I am trying to do a very simple TMA load on an H100, and my code produces an illegal instruction exception.
Small reproducer:
```python
import torch

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.runtime import from_dlpack

M, K = 128, 64


@cute.kernel
def copy_kernel(tma_atom_a: cute.CopyAtom,
                mA_tma: cute.Tensor,
                sA_layout: cute.ComposedLayout):
    tidx, _, _ = cute.arch.thread_idx()

    # Shared memory: the (swizzled) A tile plus one 64-bit mbarrier
    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(cutlass.BFloat16, sA_layout.outer, 16, sA_layout.inner)
    mbar = smem.allocate_array(cutlass.Uint64, 1)

    if tidx == 0:
        cute.arch.mbarrier_init_arrive_cnt(mbar, 1)
        cute.arch.mbarrier_init_fence()
        cute.arch.mbarrier_init_tx_bytes(mbar, M * K * 2)
    cute.arch.sync_threads()

    tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
        atom=tma_atom_a,
        cta_coord=(0, 0),
        cta_layout=cute.make_layout((1, 1)),
        smem_tensor=sA,
        gmem_tensor=mA_tma,
    )

    if tidx == 0:
        # Causes the illegal instruction exception; goes away if commented out
        cute.copy(
            tma_atom_a,
            tAgA,
            tAsA,
            tma_bar_ptr=mbar,
        )

    # Hangs indefinitely if this is uncommented
    # cute.arch.mbarrier_wait(mbar, 0)


@cute.jit
def launch_copy(mA: cute.Tensor):
    # 128B-swizzled, K-major smem layout atom, tiled up to the full (M, K) tile
    sw128_k_atom = cute.nvgpu.warpgroup.make_smem_layout_atom(
        kind=cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128,
        element_type=cutlass.BFloat16,
    )
    sA_layout = cute.tile_to_shape(sw128_k_atom, (M, K), (0, 1))

    basic_tma_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
    tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tma_tile_atom(
        op=basic_tma_op,
        gmem_tensor=mA,
        smem_layout=sA_layout,
        cta_tiler=(M, K),
    )

    # Tile bytes plus 8 bytes for the mbarrier
    smem_size = cute.size_in_bytes(cutlass.BFloat16, sA_layout) + 8

    copy_kernel(
        tma_atom_a,
        tma_tensor_a,
        sA_layout,
    ).launch(
        grid=(1, 1, 1),
        block=(128, 1, 1),
        smem=smem_size,
    )


A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
A_tensor = from_dlpack(A, assumed_align=16)

cutlass.cuda.initialize_cuda_context()
copy = cute.compile(launch_copy, A_tensor)
copy(A)
print("launched")
torch.cuda.synchronize()
```
It's pretty likely that I'm doing something wrong here, but an illegal instruction exception is tricky to debug. I've tried to study https://github.com/NVIDIA/cutlass/blob/main/examples/cute/tutorial/hopper/wgmma_tma_sm90.cu and replicate the same behavior in the Python DSL (the pattern I'm trying to mirror is sketched below), but something still seems to be going wrong. Any ideas?
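For reference, here is the core mbarrier + TMA sequence from that C++ tutorial as I understand it. This is a paraphrased sketch, not a verbatim excerpt, so the helper names (`initialize_barrier`, `set_barrier_transaction_bytes`, `wait_barrier`) and the rank-2 `group_modes` calls are my reconstruction of the CuTe API rather than the exact tutorial code:

```cpp
// Paraphrased sketch of the tutorial's mbarrier + TMA load (my reconstruction,
// not a verbatim excerpt from wgmma_tma_sm90.cu).
#include <cute/tensor.hpp>
using namespace cute;

template <class TmaAtom, class GTensor, class STensor>
__device__ void tma_load_tile(TmaAtom const& tma_a, GTensor gA, STensor sA,
                              uint32_t transaction_bytes)
{
  __shared__ uint64_t tma_mbar;

  // Single CTA: partition gmem/smem for the TMA. Note the tensors are
  // flattened with group_modes before partitioning (rank-2 assumed here).
  auto [tAgA, tAsA] = tma_partition(tma_a, Int<0>{}, Layout<_1>{},
                                    group_modes<0,2>(sA), group_modes<0,2>(gA));

  if (threadIdx.x == 0) {
    initialize_barrier(tma_mbar, /*arrive_count=*/1);
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    // Announce how many bytes the TMA will deposit, then issue the copy
    // with the barrier attached via .with().
    set_barrier_transaction_bytes(tma_mbar, transaction_bytes);
    copy(tma_a.with(tma_mbar), tAgA, tAsA);
  }

  // All threads wait on phase 0 until the transaction completes.
  wait_barrier(tma_mbar, /*phase=*/0);
}
```

The parts I suspect I'm translating incorrectly are the `tma_partition` arguments (the C++ version partitions mode-grouped tensors) and the transaction-byte setup, but I haven't been able to confirm either.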
Thanks!