diff --git a/include/cutlass/epilogue/thread/conversion_op.h b/include/cutlass/epilogue/thread/conversion_op.h index 432906acf6..19bbc03a91 100644 --- a/include/cutlass/epilogue/thread/conversion_op.h +++ b/include/cutlass/epilogue/thread/conversion_op.h @@ -62,6 +62,7 @@ class Convert { using ElementOutput = ElementOutput_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementAccumulator_; + using ElementD = ElementOutput; // for use with cute::collective::DefaultEpilogue static int const kCount = Count; @@ -123,6 +124,21 @@ class Convert { return destination_converter(accumulator); } + + // + // Specializations for scalar (for use with cute::collective::DefaultEpilogue) + // + CUTLASS_HOST_DEVICE + ElementD operator()(ElementAccumulator const accumulator, ElementAccumulator const source) const { + NumericConverter destination_converter; + return destination_converter(source); + } + + CUTLASS_HOST_DEVICE + ElementD operator()(ElementAccumulator const accumulator) const { + NumericConverter destination_converter; + return destination_converter(accumulator); + } }; /////////////////////////////////////////////////////////////////////////////////////////////////