Description
Steps/Code to reproduce bug
import torch
import cutlass.epilogue
def epilogue(accum, bias):
    D = accum + bias
    return D
examples_tensors = dict(
    accum=torch.randn(1024, 1024),
    bias=torch.randn(1024, 1).bfloat16(),
    D=torch.randn(1024, 1024).bfloat16(),
)
cutlass.epilogue.trace(epilogue, examples_tensors)
Error
File ~/miniconda3/envs/dev_nightly/lib/python3.10/site-packages/cutlass/backend/evt/ir/load_nodes.py:187, in ColumnBroadcastImpl.argument_type.<locals>._Argument()
    184 class _Argument(ctypes.Structure):
    185     _fields_ = [
    186         ("ptr_col", ctypes.c_void_p),
--> 187         ("null_default", dtype2ctype[element_type]),
    188         ("dCol", tuple_type)
    189     ]
    190     def __init__(self, kwargs) -> None:
    191         ptr = kwargs[name]
KeyError: <DataType.bf16: 16>
Upon inspecting the code, I found that dtype2ctype has no entry for DataType.bf16, which is what raises the KeyError above. Simply adding DataType.bf16: ctypes.c_uint16 to dtype2ctype makes the trace work, and I verified that the outputs are correct.
cutlass/python/cutlass/backend/epilogue.py
Lines 45 to 51 in 1ebda1c
dtype2ctype = {
    DataType.f16: ctypes.c_uint16,
    DataType.f32: ctypes.c_float,
    DataType.f64: ctypes.c_double,
    DataType.s8: ctypes.c_int8,
    DataType.s32: ctypes.c_int32
}
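For illustration, here is a minimal, self-contained sketch of the proposed change. It does not import CUTLASS: the DataType enum below is a hypothetical stand-in for the real one, make_argument_type is an illustrative helper that is not part of the library, and the struct omits the dCol field because tuple_type is built elsewhere. The only point it tries to show is that bf16, like f16, is 16 bits wide, so ctypes.c_uint16 should give the argument struct the right storage size.

```python
import ctypes
import enum


# Hypothetical stand-in for CUTLASS's DataType enum, so the sketch runs on its own.
class DataType(enum.Enum):
    f16 = enum.auto()
    bf16 = enum.auto()
    f32 = enum.auto()
    f64 = enum.auto()
    s8 = enum.auto()
    s32 = enum.auto()


dtype2ctype = {
    DataType.f16: ctypes.c_uint16,
    DataType.bf16: ctypes.c_uint16,  # proposed addition: bf16 occupies 16 bits, like f16
    DataType.f32: ctypes.c_float,
    DataType.f64: ctypes.c_double,
    DataType.s8: ctypes.c_int8,
    DataType.s32: ctypes.c_int32,
}


def make_argument_type(element_type):
    """Illustrative helper: build a simplified version of the _Argument struct."""

    class _Argument(ctypes.Structure):
        _fields_ = [
            ("ptr_col", ctypes.c_void_p),
            # This lookup is the one that raised KeyError for bf16 before the addition.
            ("null_default", dtype2ctype[element_type]),
        ]

    return _Argument


bf16_argument = make_argument_type(DataType.bf16)
arg = bf16_argument()
arg.null_default = 0x3F80  # raw bit pattern of 1.0 in bfloat16
```

With the extra entry in place, the lookup for DataType.bf16 succeeds and the field holds a raw 16-bit pattern, the same way the existing f16 entry does.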