
[DOC] Can't compile a bf16xfp4 gemm #2519

@cy2018202170

Description

Report needed documentation
I am trying to write a BF16xFP4 GEMM using CUTLASS, and I want to know whether CUTLASS supports this operation on the RTX 5090.

Describe the documentation you'd like
A document that records which mixed-precision combinations are supported and unsupported on each platform.

Steps taken to search for needed documentation
I am trying to write a BF16xFP4 GEMM using CUTLASS. Here is some of my code:

template <typename Config, typename OutType> struct Bf16_Fp4_GemmSm120 {
    // using ElementA = cutlass::bfloat16_t;
    using ElementA = __nv_bfloat16;
    using LayoutATag = cutlass::layout::RowMajor;
    static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

    using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
    using LayoutBTag = cutlass::layout::ColumnMajor;
    static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

    using ElementD = OutType;
    using ElementC = OutType;
    using LayoutCTag = cutlass::layout::RowMajor;
    using LayoutDTag = cutlass::layout::RowMajor;
    static constexpr int AlignmentD =
        128 / cutlass::sizeof_bits<ElementD>::value;
    static constexpr int AlignmentC =
        128 / cutlass::sizeof_bits<ElementC>::value;

    using ElementAccumulator = float;
    using ArchTag = cutlass::arch::Sm120;
    using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

    using MmaTileShape = typename Config::MmaTileShape;
    using ClusterShape = typename Config::ClusterShape;
    using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;

    using CollectiveEpilogue =
        typename cutlass::epilogue::collective::CollectiveBuilder<
            ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
            cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
            ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
            LayoutDTag, AlignmentD,
            cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

    using CollectiveMainloop =
        typename cutlass::gemm::collective::CollectiveBuilder<
            ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
            LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
            ClusterShape,
            cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
                sizeof(typename CollectiveEpilogue::SharedStorage))>,
            cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

    using GemmKernel =
        cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,
                                             CollectiveMainloop,
                                             CollectiveEpilogue, void>;

    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};
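The Config type referenced above only supplies the tile shapes. A representative definition looks like the sketch below; the shape values are placeholders for illustration, not necessarily the exact ones from my build.

// Hypothetical Config supplying the tile shapes consumed by Bf16_Fp4_GemmSm120.
// The shape values are placeholders for illustration.
struct GemmConfigSm120 {
    using MmaTileShape       = cute::Shape<cute::_128, cute::_128, cute::_128>;
    using ClusterShape       = cute::Shape<cute::_1, cute::_1, cute::_1>;
    using PerSmTileShape_MNK = cute::Shape<cute::_128, cute::_128, cute::_128>;
};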

When I compile this code, I get the following error:
.../cutlass/gemm/collective/collective_builder_decl.hpp(94): error: static assertion failed with "Could not build a collective for given parameters." static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");

here is my compile command:
/home/cheny/miniconda3/envs/venv_chitu_2/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/cheny/cinfer/build/temp.linux-x86_64-cpython-312/home/cheny/cinfer/csrc/cuda/hard_fp4/nvfp4_scaled_mm_kernels.o.d -I/home/cheny/cinfer/csrc/../third_party/spdlog/include -I/home/cheny/cinfer/csrc/cuda/common -I/home/cheny/miniconda3/envs/venv_chitu_2/lib/python3.12/site-packages/torch/include -I/home/cheny/miniconda3/envs/venv_chitu_2/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/cheny/miniconda3/envs/venv_chitu_2/include -I/home/cheny/miniconda3/envs/venv_chitu_2/include/python3.12 -c -c /home/cheny/cinfer/csrc/cuda/hard_fp4/nvfp4_scaled_mm_kernels.cu -o /home/cheny/cinfer/build/temp.linux-x86_64-cpython-312/home/cheny/cinfer/csrc/cuda/hard_fp4/nvfp4_scaled_mm_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DENABLE_NVFP4 -gencode=arch=compute_120a,code=sm_120a -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=chitu_backend -D_GLIBCXX_USE_CXX11_ABI=1 -ccbin /home/cheny/miniconda3/envs/venv_chitu_2/bin/x86_64-conda-linux-gnu-cc -std=c++17

I would like to know whether this is a programming error on my part or whether CUTLASS currently does not support this mixed-precision combination.
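From what I can tell, this static_assert is the fall-through in the primary CollectiveBuilder template, i.e. it fires when no partial specialization matches the given combination of parameters. In my searching, the SM120 block-scaled GEMM examples shipped with CUTLASS appear to declare both operands as block-scaled wrapper types (rather than a plain BF16 type for A), roughly like this, but I could not find this requirement documented anywhere:

// Operand declarations along the lines of the SM120 block-scaled examples in
// the CUTLASS repository (both A and B are block-scaled wrapper types).
// Shown for comparison only; not asserting this is the only supported form.
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutATag = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 32;   // 128 bits / 4-bit elements

using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutBTag = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 32;   // 128 bits / 4-bit elements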

This is my environment information:
GPU: NVIDIA GeForce RTX 5090
nvcc: release 12.8, V12.8.93
cmake: 4.0.3
