[QST] Epilogue Broadcast: Adapter vs GemmUniversal #1459

What is your question?
Trying to understand the behavior of GEMM with a column-broadcast bias-vector epilogue.

When defining a device GemmUniversalWithBroadcast with the following config (…), I get a core dump whenever I try to run it with M != K. Running with M == N, I get the correct GEMM, but the epilogue is broadcast incorrectly (row-wise instead of column-wise). When I run the same problem using GemmUniversalAdapter as the device handle, the op runs for all M and N and the epilogue is performed correctly; however, the A and B inputs are transposed because of an internal transpose that the adapter does.

Questions
- What explains the core dump and the incorrect broadcast with GemmUniversalWithBroadcast?
- Why does GemmUniversalAdapter transpose layouts internally?

Repro
Here is a simple script for reproducing the above:
- GemmUniversalWithBroadcast will fail to run with M != N.
- GemmUniversalWithBroadcast runs with M == N, but the epilogue is incorrect.
- GemmUniversalAdapter runs, but with operands A and B transposed.
Comments
As a follow-up, I'm trying to implement the above using epilogue visitor trees and am encountering two problems. I tried to tweak the stream-K with broadcast example with the above changes (simpler EVT and …). Below is the full script:

#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cute/tensor.hpp"
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) \
{ \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}
using DType = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
// EVT
constexpr int Alignment = 128 / cutlass::sizeof_bits_v<DType>;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1; // Number of epilogue stages in EVT
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
DType,
Alignment,
EVTEpilogueStages>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, cutlass::half_t,
cute::Stride<cute::Int<1>, cute::Int<0>, int32_t>>; // (M, 1) column vector: unit stride along M, broadcast along N
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, DType, DType,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
Compute0,
Scale,
Accum>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, cute::Int<1>, int64_t>>; // row-major (M, N) output
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
D,
EVTCompute0>;
using EVTKernel =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
DType, LayoutA, cutlass::ComplexTransform::kNone, Alignment,
DType, LayoutB, cutlass::ComplexTransform::kNone, Alignment,
DType, LayoutC, Alignment,
DType,
DType,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTD,
ThreadBlockSwizzle,
NumStages,
cutlass::arch::OpMultiplyAdd,
EVTEpilogueStages>::GemmKernel;
using DeviceGemmEVT = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
template <typename Gemm>
void test(int M = 8, int N = 4, int K = 8, bool verbose = true, int batch_count = 1,
cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
DType alpha = DType(1.0), DType beta = DType(0.0))
{
cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord(M, N, K);
cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Z;
cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Broadcast;
tensor_A.resize({problem_size.m(), problem_size.k()});
tensor_B.resize({problem_size.k(), problem_size.n()});
tensor_Z.resize({problem_size.m(), problem_size.n()});
tensor_Broadcast.resize({problem_size.m(), 1});
cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity());
cutlass::reference::host::BlockFill(tensor_B.host_data(), tensor_B.capacity(), DType(1.0));
cutlass::reference::host::BlockFill(tensor_Z.host_data(), tensor_Z.capacity(), DType(0.0));
cutlass::reference::host::BlockFillSequential(tensor_Broadcast.host_data(), tensor_Broadcast.capacity());
tensor_A.sync_device();
tensor_B.sync_device();
tensor_Z.sync_device();
tensor_Broadcast.sync_device();
if (verbose)
{
std::cout << "tensor_A:\n"
<< tensor_A.host_view() << std::endl;
std::cout << "tensor_B:\n"
<< tensor_B.host_view() << std::endl;
std::cout << "tensor_Broadcast:\n"
<< tensor_Broadcast.host_view() << std::endl;
}
// Sm80EVT arguments nest child arguments first, in declaration order, with the node op's arguments last
typename EVTD::Arguments callback_args{
{
{tensor_Broadcast.device_data(), DType(0), {_1{}, _0{}, int32_t(problem_size.m())}}, // Scale (column-broadcast bias)
{}, // Accum
{} // Compute0 (node op)
}, // EVTCompute0
{tensor_Z.device_data(), {problem_size.n(), _1{}, problem_size.mn().product()}}, // D
};
typename EVTKernel::Arguments evtArgs{
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
problem_size, // problem_size
1, // batch count / splitk slices
callback_args, // argument of EVT callbacks
tensor_A.device_data(), // ptr_A
tensor_B.device_data(), // ptr_B
nullptr, // ptr_C (unused)
nullptr, // ptr_D (unused)
problem_size.mk().product(), // batch_stride_A
problem_size.nk().product(), // batch_stride_B
0, // batch_stride_C (unused)
0, // batch_stride_D (unused)
tensor_A.layout().stride(0), // stride_a
tensor_B.layout().stride(0), // stride_b
0, // stride_c (unused)
0 // stride_d (unused)
};
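// Not part of the original snippet: a minimal sketch of launching the kernel
// through the device handle, so the script runs end to end. It assumes Gemm is
// the GemmUniversalAdapter-based handle (DeviceGemmEVT), whose Arguments type
// matches EVTKernel::Arguments.
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(evtArgs));
CUTLASS_CHECK(gemm_op.initialize(evtArgs));
CUTLASS_CHECK(gemm_op());
tensor_Z.sync_host();
if (verbose)
{
std::cout << "tensor_Z:\n"
<< tensor_Z.host_view() << std::endl;
}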
}
int main()
{
int M = 8;
int N = 8;
int K = 8;
std::cout << "GemmEVT" << std::endl;
test<DeviceGemmEVT>(M, N, K); // pass the device-level adapter, which exposes ElementA/LayoutA etc.
}
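For reference, a standalone CUTLASS script like this one typically builds with something along these lines (the checkout path, file name, and exact flags are assumptions, not from the original post):

nvcc -std=c++17 -arch=sm_80 --expt-relaxed-constexpr -I${CUTLASS_DIR}/include -I${CUTLASS_DIR}/tools/util/include evt_broadcast.cu -o evt_broadcast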
Not an expert, but I recently ran into exactly the same problem when crafting my custom epilogue visitor tree. Here is what I think:
To …
Do the same with type …
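A minimal sketch of the Sm80EVT argument nesting, reusing the aliases from the script above (scale_ptr, d_ptr, M, and N are placeholders, not names from the original posts): for Sm80EVT<NodeOp, Child0, Child1>, the Arguments aggregate lists each child's arguments first, in declaration order, with the node op's arguments last.

// EVTCompute0 = Sm80EVT<Compute0, Scale, Accum> => { Scale args, Accum args, Compute0 args }
// EVTD        = Sm80EVT<D, EVTCompute0>         => { EVTCompute0 args, D args }
typename EVTD::Arguments example_args{
{
{scale_ptr, DType(0), {_1{}, _0{}, int32_t(M)}},     // Scale: pointer, null-default, StrideMNL
{},                                                  // Accum: no arguments
{}                                                   // Compute0: node op comes last
},                                                     // EVTCompute0
{d_ptr, {int64_t(N), _1{}, int64_t(M) * int64_t(N)}}   // D: pointer, StrideMNL
};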
For your second problem, the reason why the example works well with … To be more specific, example 47 uses …
Any update on this? I'm trying to do a very similar thing and am facing a similar issue.
The funny thing is, I first tried this with CUTLASS Python, and it works! I used the code emitted by CUTLASS Python as a reference to build the kernel, but I cannot see how the kernel is called, so I suspect something is wrong with how I call the kernel.

Code reproduction:

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include <torch/extension.h>

#include <cstdlib>
#include <iostream>

// Same CUTLASS_CHECK macro as in the script earlier in this thread; this
// translation unit needs its own definition since the macro is used below.
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) \
{ \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}

torch::Tensor int4_mm_dequant(torch::Tensor A, torch::Tensor B, torch::Tensor row_scale, torch::Tensor col_scale) {
int M = A.size(0);
int K = A.size(1) * 2;
int N = B.size(1);
torch::Tensor C = torch::empty({M, N}, row_scale.options());
// follow this https://github.com/NVIDIA/cutlass/issues/1565
// and with the help of emitted code from cutlass Python
using ElementA = cutlass::int4b_t;
using ElementB = cutlass::int4b_t;
using ElementC = cutlass::bfloat16_t;
using ElementAccumulator = int32_t;
using ElementEpilogue = float;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // 32
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // 32
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // 8
// some configs
// https://github.com/NVIDIA/cutlass/blob/v3.5.1/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu
using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
constexpr int numEpilogueStages = 1;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape, WarpShape, ElementC, AlignmentC, numEpilogueStages
>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
constexpr auto RoundMode = cutlass::FloatRoundStyle::round_to_nearest;
// (1, N)
using ColScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<cute::_0, cute::_1, int32_t> // MNL
>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementEpilogue, ElementEpilogue, RoundMode
>;
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<Compute0, Accum, ColScale>;
// (M, 1)
using RowScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, ElementC,
cute::Stride<cute::_1, cute::_0, int32_t> // MNL
>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementC, ElementEpilogue, RoundMode
>;
using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT<Compute1, EVTCompute0, RowScale>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementC, RoundMode,
cute::Stride<int64_t, cute::_1, int64_t> // MNL
>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute1>;
using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementA, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, AlignmentA,
ElementB, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, AlignmentB,
ElementC, cutlass::layout::RowMajor, AlignmentC,
ElementAccumulator, ElementEpilogue,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3, // numStages
cutlass::arch::OpMultiplyAddSaturate,
numEpilogueStages
>::GemmKernel;
using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
const ElementC *col_scale_ptr = reinterpret_cast<ElementC *>(col_scale.data_ptr<torch::BFloat16>());
const ElementC *row_scale_ptr = reinterpret_cast<ElementC *>(row_scale.data_ptr<torch::BFloat16>());
ElementC *C_ptr = reinterpret_cast<ElementC *>(C.data_ptr<torch::BFloat16>());
typename EVTD::Arguments callback_args{
{
{
{}, // Accum
{col_scale_ptr, ElementC(0), {cute::_0{}, cute::_1{}, int32_t(N)}}, // ColScale
{} // Multiply
}, // EVTCompute0
{row_scale_ptr, ElementC(0), {cute::_1{}, cute::_0{}, int32_t(M)}}, // RowScale
{} // Multiply
}, // EVTCompute1
{C_ptr, {int64_t{N}, cute::_1{}, int64_t{M*N}}} // D
};
// NOTE: argument list is based on EVTKernel
typename DeviceGemm::Arguments args(
cutlass::gemm::GemmUniversalMode::kGemm,
cutlass::gemm::GemmCoord{M, N, K},
1,
callback_args,
reinterpret_cast<ElementA *>(A.data_ptr<int8_t>()),
reinterpret_cast<ElementB *>(B.data_ptr<int8_t>()),
nullptr,
nullptr,
M * K,
N * K,
0,
0,
A.stride(0),
B.stride(0),
0,
0
);
DeviceGemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(args));
CUTLASS_CHECK(gemm_op.initialize(args));
CUTLASS_CHECK(gemm_op());
return C;
}

CUTLASS Python that works (I have verified the outputs are correct):

import cutlass
import cutlass.op
import torch
from cuda import cuda
from cutlass.backend.evt import EpilogueFunctorVisitor
from cutlass.backend.gemm_operation import GemmArguments
from cutlass.shape import GemmCoord
from cutlass_library import DataType, GemmUniversalMode
# override Gemm.run to set problem_size correctly for int4
# override Gemm.run to set problem_size correctly for int4
class MyGemm(cutlass.op.Gemm):
    def run(self, A=None, B=None, C=None, D=None,
            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
            stream: cuda.CUstream = cuda.CUstream(0), _problem_size=None) -> GemmArguments:
        super().run_setup()

        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
        D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        is_void_c = self._element_c == DataType.void

        self._verify_rank(A)
        self._verify_rank(B)
        if not is_void_c:
            self._verify_rank(C)
        self._verify_rank(D)

        alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
        alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")

        # Set C alignment based on D.shape so as to correctly get an alignment with void-C
        # kernels, for which `C` is None.
        alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
        self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)

        # added these 2 lines
        if _problem_size is not None:
            problem_size = GemmCoord(*_problem_size)

        if mode == GemmUniversalMode.Gemm or batch_count == 1:
            kwargs = {'split_k_slices': 1}
        else:
            kwargs = {
                'batch': batch_count,
                'batch_strides': {
                    'A': self._get_batch_stride(A),
                    'B': self._get_batch_stride(B),
                    'C': self._get_batch_stride(C),
                    'D': self._get_batch_stride(D)
                }
            }
        kwargs['stream'] = stream

        if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
            output_op = self.operation.epilogue_type(visitor_args)
        else:
            output_op = self.operation.epilogue_type(alpha, beta)

        arguments = GemmArguments(
            operation=self.operation, problem_size=problem_size,
            A=A, B=B, C=C, D=D,
            output_op=output_op,
            gemm_mode=mode,
            **kwargs
        )

        self.operation.run(arguments)

        if sync:
            arguments.sync()
        return arguments
m = 8192
n = m
k = 8192
scope_min = -4
scope_max = 4
A = torch.empty(size=(m, k), dtype=torch.int8, device="cuda").random_(scope_min, scope_max)
B = torch.empty(size=(k, n), dtype=torch.int8, device="cuda").random_(scope_min, scope_max).T.contiguous().T
plan2 = MyGemm(
    element_A=cutlass.DataType.s4,
    layout_A=cutlass.LayoutType.RowMajor,
    element_B=cutlass.DataType.s4,
    layout_B=cutlass.LayoutType.ColumnMajor,
    element_C=cutlass.DataType.bf16,
    layout_C=cutlass.LayoutType.RowMajor,
    element_D=cutlass.DataType.bf16,
    element_accumulator=cutlass.DataType.s32,
)

def epilogue_scale(accum: cutlass.Tensor, row_scale, col_scale):
    D = accum * col_scale * row_scale
    return D

examples_tensors = {
    "accum": cutlass.Tensor(element=cutlass.DataType.s32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
    "row_scale": cutlass.Tensor(element=cutlass.DataType.bf16, shape=(m, 1), layout_tag=cutlass.LayoutType.RowMajor),
    "col_scale": cutlass.Tensor(element=cutlass.DataType.bf16, shape=(1, n), layout_tag=cutlass.LayoutType.RowMajor),
    "D": cutlass.Tensor(element=cutlass.DataType.bf16, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),
}
plan2.epilogue_visitor = cutlass.epilogue.trace(epilogue_scale, examples_tensors)
row_scale = torch.empty(size=(m, 1), dtype=torch.bfloat16, device="cuda").uniform_(scope_min, scope_max)
col_scale = torch.empty(size=(1, n), dtype=torch.bfloat16, device="cuda").uniform_(scope_min, scope_max)
C = torch.empty(A.shape[0], B.shape[1], dtype=torch.bfloat16, device="cuda")
plan2.run(
    A, B, C, C,
    visitor_args=dict(row_scale=row_scale, col_scale=col_scale, D=C),
    _problem_size=(m, n, k * 2),
)

Also, I have to comment out these lines for things to work: cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h, lines 285 to 289 (at 2991ce1).

Update: switching to …
Update 2: #1753 seems to solve the original issue.