- 
                Notifications
    You must be signed in to change notification settings 
- Fork 1.5k
Description
Describe the bug
Some headers do not include the headers they depend on, so they cannot be included on their own. This is "hidden" in most of the examples: they include many headers, and one of the others often pulls in what's needed. However, this is fragile and has caused build failures as we rearrange includes. I assume every header in cutlass/include/ and cutlass/tools/util/include/ is supposed to be includable on its own? If so, the following headers are broken (list generated with a simple shell script):
Broken headers
cute/algorithm/cooperative_copy.hpp
cute/algorithm/copy.hpp
cute/algorithm/functional.hpp
cute/algorithm/gemm.hpp
cute/algorithm/prefetch.hpp
cute/arch/mma_sm90.hpp
cute/atom/copy_atom.hpp
cute/atom/copy_traits_sm90_tma.hpp
cute/atom/mma_traits_sm90.hpp
cute/atom/mma_traits_sm90_gmma.hpp
cute/layout_composed.hpp
cute/numeric/int.hpp
cute/numeric/numeric_types.hpp
cute/pointer_flagged.hpp
cutlass/arch/mma_sm61.h
cutlass/arch/mma_sm89.h
cutlass/arch/reg_reconfig.h
cutlass/arch/simd.h
cutlass/arch/simd_sm60.h
cutlass/arch/simd_sm61.h
cutlass/arch/wmma_sm70.h
cutlass/arch/wmma_sm72.h
cutlass/arch/wmma_sm75.h
cutlass/conv/collective/builders/sm90_gmma_builder.inl
cutlass/conv/collective/collective_builder.hpp
cutlass/conv/collective/collective_conv.hpp
cutlass/conv/collective/detail.hpp
cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp
cutlass/conv/convnd_problem_shape.hpp
cutlass/conv/device/implicit_gemm_convolution.h
cutlass/conv/device/implicit_gemm_convolution_fusion.h
cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h
cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp
cutlass/conv/threadblock/conv2d_tile_iterator.h
cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h
cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h
cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h
cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h
cutlass/conv/threadblock/depthwise_mma_base.h
cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h
cutlass/conv/threadblock/implicit_gemm_multistage.h
cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h
cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h
cutlass/device_kernel.h
cutlass/epilogue/collective/builders/sm90_builder.inl
cutlass/epilogue/collective/builders/sm90_common.inl
cutlass/epilogue/collective/collective_builder.hpp
cutlass/epilogue/collective/collective_epilogue.hpp
cutlass/epilogue/collective/detail.hpp
cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp
cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp
cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp
cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp
cutlass/epilogue/fusion/callbacks.hpp
cutlass/epilogue/fusion/operations.hpp
cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp
cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp
cutlass/epilogue/thread/linear_combination_params.h
cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp
cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h
cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op_blas3.h
cutlass/epilogue/threadblock/default_epilogue_direct_store.h
cutlass/epilogue/threadblock/default_epilogue_tensor_op_blas3.h
cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h
cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h
cutlass/epilogue/threadblock/epilogue_base_streamk.h
cutlass/epilogue/threadblock/epilogue_depthwise.h
cutlass/epilogue/threadblock/epilogue_direct_store.h
cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h
cutlass/epilogue/threadblock/epilogue_smem_accumulator.h
cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h
cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h
cutlass/epilogue/threadblock/epilogue_workspace.h
cutlass/epilogue/threadblock/fusion/visitor_2x.hpp
cutlass/epilogue/threadblock/fusion/visitor_compute.hpp
cutlass/epilogue/threadblock/fusion/visitor_load.hpp
cutlass/epilogue/threadblock/fusion/visitor_store.hpp
cutlass/epilogue/threadblock/fusion/visitors.hpp
cutlass/epilogue/threadblock/output_tile_thread_map.h
cutlass/epilogue/threadblock/predicated_tile_iterator_blas3.h
cutlass/epilogue/threadblock/shared_load_iterator.h
cutlass/epilogue/threadblock/shared_load_iterator_mixed.h
cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h
cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h
cutlass/epilogue/warp/fragment_iterator_gaussian_complex_tensor_op.h
cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h
cutlass/epilogue/warp/tile_iterator_simt.h
cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h
cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h
cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h
cutlass/floating_point_nvrtc.h
cutlass/gemm/collective/fp8_accumulation.hpp
cutlass/gemm/collective/sm80_mma_multistage.hpp
cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp
cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp
cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp
cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp
cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp
cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp
cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp
cutlass/gemm/device/rank_2k_grouped.h
cutlass/gemm/kernel/default_gemm_streamk_with_broadcast.h
cutlass/gemm/kernel/default_rank_2k_grouped.h
cutlass/gemm/kernel/gemm_params.h
cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h
cutlass/gemm/kernel/params_sparse_base.h
cutlass/gemm/kernel/sm70_gemm.hpp
cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp
cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp
cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp
cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp
cutlass/gemm/thread/mma_sm50.h
cutlass/gemm/threadblock/default_ell_mma.h
cutlass/gemm/threadblock/default_mma_core.h
cutlass/gemm/threadblock/default_mma_core_sm70.h
cutlass/gemm/threadblock/default_mma_core_with_access_size.h
cutlass/gemm/threadblock/default_mma_core_wmma.h
cutlass/gemm/threadblock/ell_mma_multistage.h
cutlass/gemm/threadblock/ell_mma_pipelined.h
cutlass/gemm/threadblock/mma_blas3_multistage.h
cutlass/gemm/threadblock/mma_multistage.h
cutlass/gemm/threadblock/mma_planar_complex_base.h
cutlass/gemm/threadblock/mma_planar_complex_multistage.h
cutlass/gemm/threadblock/mma_planar_complex_pipelined.h
cutlass/gemm/threadblock/mma_singlestage.h
cutlass/gemm/threadblock/mma_sparse_base.h
cutlass/gemm/threadblock/mma_sparse_multistage.h
cutlass/gemm/threadblock/mma_with_reduction_multistage.h
cutlass/layout/tensor_op_multiplicand_sm70.h
cutlass/pipeline/pipeline.hpp
cutlass/pipeline/sm90_pipeline.hpp
cutlass/real.h
cutlass/reduction/device/reduce_split_k.h
cutlass/reduction/kernel/reduce_softmax_final.h
cutlass/reduction/threadblock_swizzle.h
cutlass/tensor_view_planar_complex.h
cutlass/tfloat32.h
cutlass/thread/matrix.h
cutlass/transform/collective/sm90_wgmma_transpose.hpp
cutlass/transform/thread/transpose.h
cutlass/transform/threadblock/ell_iterator.h
cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h
cutlass/transform/threadblock/predicated_vector_access_iterator.h
cutlass/transform/threadblock/regular_tile_access_iterator.h
cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h
cutlass/transform/threadblock/vector_iterator.h
cutlass/util/cublas_wrappers.hpp
cutlass/util/device_rmsnorm.h
cutlass/util/host_tensor_planar_complex.h
cutlass/util/packed_stride.hpp
cutlass/util/reference/device/gemm.h
cutlass/util/reference/device/gett.hpp
cutlass/util/reference/device/kernel/gemm.h
cutlass/util/reference/device/kernel/tensor_elementwise.h
cutlass/util/reference/device/rank_2k_complex.h
cutlass/util/reference/device/thread/gemm.h
cutlass/util/reference/host/conv.hpp
cutlass/util/reference/host/error_metrics.h
cutlass/util/reference/host/symm_complex.h
cutlass/util/reference/host/tensor_copy.h
cutlass/util/reference/host/tensor_elementwise.h
cutlass/util/reference/host/tensor_foreach.h
cutlass/util/reference/host/tensor_norm.h
cutlass/util/reference/host/tensor_reduce.h
Steps/Code to reproduce bug
Here is an example:
#include <cutlass/epilogue/collective/collective_builder.hpp>

Save this as test.cu and build it with -c -std=c++20 -I cutlass/include/ -I cutlass/tools/util/include/ (the exact nvcc command is shown under "Example output" below).
Expected behavior
This should compile cleanly; more generally, every header listed above should compile when it is the only header included.
Environment details (please complete the following information):
- Environment location: Bare-metal
- CUTLASS main branch
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0

Additional context
Example output
$ nvcc -c -std=c++20 test.cu -I cutlass/include/ -I cutlass/tools/util/include/
cutlass/include/cutlass/epilogue/fusion/operations.hpp(134): error: identifier "sizeof_bits_v" is undefined
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                               ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(152): error: identifier "sizeof_bits_v" is undefined
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                               ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(169): error: identifier "sizeof_bits_v" is undefined
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                               ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(190): error: identifier "sizeof_bits_v" is undefined
    int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
                              ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(191): error: type name is not allowed
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                             ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(191): error: expected a "," or ">"
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                                         ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(211): error: identifier "sizeof_bits_v" is undefined
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                               ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(212): error: type name is not allowed
    int AlignmentScalar_ = 128 / sizeof_bits_v<ElementScalar_>,
                                               ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(212): error: expected a "," or ">"
    int AlignmentScalar_ = 128 / sizeof_bits_v<ElementScalar_>,
                                                             ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(234): error: identifier "sizeof_bits_v" is undefined
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                               ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(264): error: identifier "sizeof_bits_v" is undefined
    int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
                              ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(265): error: type name is not allowed
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                             ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(265): error: expected a "," or ">"
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                                         ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(291): error: identifier "sizeof_bits_v" is undefined
    int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
                              ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(318): error: identifier "sizeof_bits_v" is undefined
    int AlignmentAux_ = 128 / sizeof_bits_v<ElementAux_>,
                              ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(319): error: type name is not allowed
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                             ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(319): error: expected a "," or ">"
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                                         ^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: identifier "cute" is undefined
    cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
    ^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: enable_if_t is not a template
    cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
          ^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: identifier "is_base_of_v" is undefined
    cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
                          ^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: expected a ">"
    cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
        ^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(103): error: type name is not allowed
    cute::enable_if_t<not is_base_of_v<fusion::FusionOperation, FusionCallbacks>>
                                       ^
cutlass/include/cutlass/epilogue/collective/collective_builder.hpp(104): error: expected a ";"
  > {
  ^
cutlass/include/cute/arch/mma_sm90.hpp(154): error: complex is not a template
    using DRegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(155): error: complex is not a template
    using ARegisters = complex<double>[2];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(156): error: complex is not a template
    using BRegisters = complex<double>[1];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(157): error: complex is not a template
    using CRegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(160): error: complex is not a template
    fma(complex<double> & d0, complex<double> & d1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(160): error: complex is not a template
    fma(complex<double> & d0, complex<double> & d1,
                              ^
cutlass/include/cute/arch/mma_sm90.hpp(161): error: complex is not a template
        complex<double> & d2, complex<double> & d3,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(161): error: complex is not a template
        complex<double> & d2, complex<double> & d3,
                              ^
cutlass/include/cute/arch/mma_sm90.hpp(162): error: complex is not a template
        complex<double> const& a0, complex<double> const& a1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(162): error: complex is not a template
        complex<double> const& a0, complex<double> const& a1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(163): error: complex is not a template
        complex<double> const& b0,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(164): error: complex is not a template
        complex<double> const& c0, complex<double> const& c1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(164): error: complex is not a template
        complex<double> const& c0, complex<double> const& c1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(165): error: complex is not a template
        complex<double> const& c2, complex<double> const& c3)
        ^
cutlass/include/cute/arch/mma_sm90.hpp(165): error: complex is not a template
        complex<double> const& c2, complex<double> const& c3)
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(212): error: complex is not a template
    using DRegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(213): error: complex is not a template
    using ARegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(214): error: complex is not a template
    using BRegisters = complex<double>[2];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(215): error: complex is not a template
    using CRegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(218): error: complex is not a template
    fma(complex<double> & d0, complex<double> & d1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(218): error: complex is not a template
    fma(complex<double> & d0, complex<double> & d1,
                              ^
cutlass/include/cute/arch/mma_sm90.hpp(219): error: complex is not a template
        complex<double> & d2, complex<double> & d3,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(219): error: complex is not a template
        complex<double> & d2, complex<double> & d3,
                              ^
cutlass/include/cute/arch/mma_sm90.hpp(220): error: complex is not a template
        complex<double> const& a0, complex<double> const& a1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(220): error: complex is not a template
        complex<double> const& a0, complex<double> const& a1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(221): error: complex is not a template
        complex<double> const& a2, complex<double> const& a3,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(221): error: complex is not a template
        complex<double> const& a2, complex<double> const& a3,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(222): error: complex is not a template
        complex<double> const& b0, complex<double> const& b1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(222): error: complex is not a template
        complex<double> const& b0, complex<double> const& b1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(223): error: complex is not a template
        complex<double> const& c0, complex<double> const& c1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(223): error: complex is not a template
        complex<double> const& c0, complex<double> const& c1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(224): error: complex is not a template
        complex<double> const& c2, complex<double> const& c3)
        ^
cutlass/include/cute/arch/mma_sm90.hpp(224): error: complex is not a template
        complex<double> const& c2, complex<double> const& c3)
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(271): error: complex is not a template
    using DRegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(272): error: complex is not a template
    using ARegisters = complex<double>[8];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(273): error: complex is not a template
    using BRegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(274): error: complex is not a template
    using CRegisters = complex<double>[4];
                       ^
cutlass/include/cute/arch/mma_sm90.hpp(277): error: complex is not a template
    fma(complex<double> & d0, complex<double> & d1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(277): error: complex is not a template
    fma(complex<double> & d0, complex<double> & d1,
                              ^
cutlass/include/cute/arch/mma_sm90.hpp(278): error: complex is not a template
        complex<double> & d2, complex<double> & d3,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(278): error: complex is not a template
        complex<double> & d2, complex<double> & d3,
                              ^
cutlass/include/cute/arch/mma_sm90.hpp(279): error: complex is not a template
        complex<double> const& a0, complex<double> const& a1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(279): error: complex is not a template
        complex<double> const& a0, complex<double> const& a1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(280): error: complex is not a template
        complex<double> const& a2, complex<double> const& a3,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(280): error: complex is not a template
        complex<double> const& a2, complex<double> const& a3,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(281): error: complex is not a template
        complex<double> const& a4, complex<double> const& a5,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(281): error: complex is not a template
        complex<double> const& a4, complex<double> const& a5,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(282): error: complex is not a template
        complex<double> const& a6, complex<double> const& a7,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(282): error: complex is not a template
        complex<double> const& a6, complex<double> const& a7,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(283): error: complex is not a template
        complex<double> const& b0, complex<double> const& b1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(283): error: complex is not a template
        complex<double> const& b0, complex<double> const& b1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(284): error: complex is not a template
        complex<double> const& b2, complex<double> const& b3,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(284): error: complex is not a template
        complex<double> const& b2, complex<double> const& b3,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(285): error: complex is not a template
        complex<double> const& c0, complex<double> const& c1,
        ^
cutlass/include/cute/arch/mma_sm90.hpp(285): error: complex is not a template
        complex<double> const& c0, complex<double> const& c1,
                                   ^
cutlass/include/cute/arch/mma_sm90.hpp(286): error: complex is not a template
        complex<double> const& c2, complex<double> const& c3)
        ^
cutlass/include/cute/arch/mma_sm90.hpp(286): error: complex is not a template
        complex<double> const& c2, complex<double> const& c3)
                                   ^
cutlass/include/cutlass/epilogue/fusion/operations.hpp(191): error: expected a ","
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                                         ^
          detected during processing of template argument list for "cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux" based on template arguments <GmemLayoutTagD, Schedule::ActivationFunctor, ElementD, ElementCompute, Schedule::ElementT, Schedule::ElementBias, ElementC_, ElementCompute, <error-constant>> at line 696 of cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
cutlass/include/cutlass/epilogue/fusion/operations.hpp(152): error: type name is not allowed
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                             ^
          detected during processing of template argument list for "cutlass::epilogue::fusion::LinCombPerRowBiasEltAct" based on template arguments <Schedule::ActivationFunctor, ElementD, ElementCompute, Schedule::ElementBias, ElementC_, ElementCompute> at line 704 of cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
cutlass/include/cutlass/epilogue/fusion/operations.hpp(152): error: expected a ","
    int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
                                                         ^
          detected during processing of template argument list for "cutlass::epilogue::fusion::LinCombPerRowBiasEltAct" based on template arguments <Schedule::ActivationFunctor, ElementD, ElementCompute, Schedule::ElementBias, ElementC_, ElementCompute> at line 704 of cutlass/include/cutlass/epilogue/collective/builders/sm90_builder.inl
83 errors detected in the compilation of "test.cu".