Skip to content

[PTX] Added support for stmatrix migration #2801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: SYCLomatic
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions clang/lib/DPCT/RulesAsm/AsmMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,10 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {

bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
// Address expression only support ld/st/red & atom instructions.
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
asmtok::op_prefetch, asmtok::op_red,
asmtok::op_cp, asmtok::op_ldmatrix)) {
if (!CurrInst ||
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp,
asmtok::op_ldmatrix, asmtok::op_stmatrix)) {
return SYCLGenError();
}
std::string Type;
Expand Down Expand Up @@ -635,7 +636,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
CanSuppressCast(Dst->getSymbol()))
OS() << llvm::formatv("{0}", Reg);
else if (CurrInst->is(asmtok::op_ldmatrix))
else if (CurrInst->is(asmtok::op_ldmatrix) ||
CurrInst->is(asmtok::op_stmatrix))
OS() << llvm::formatv("(uintptr_t){0}", Reg);
else
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
Expand Down Expand Up @@ -1376,6 +1378,64 @@ class SYCLGen : public SYCLGenBase {
return SYCLGenSuccess();
}

/// Migrates a PTX `stmatrix` instruction to a call of
/// `<dpct namespace>experimental::matrix::stmatrix(addr, regs...[, true]);`.
///
/// The instruction is accepted only when:
///   - it has exactly one input operand and that operand is a vector
///     expression,
///   - its first type is the builtin `b16` type,
///   - the vector element count matches the `.x1` / `.x2` / `.x4` attribute
///     (when one of them is present),
///   - its output operand is an address expression.
///
/// On success the call is written to OS() and, when the enclosing function
/// declaration can be found, a sub-group size request of 32 is registered for
/// it.
///
/// \param Inst The parsed `stmatrix` instruction.
/// \returns SYCLGenSuccess() when the call was emitted, SYCLGenError()
/// otherwise.
bool handle_stmatrix(const InlineAsmInstruction *Inst) override {
  if (Inst->getNumInputOperands() != 1)
    return SYCLGenError();

  const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
  if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
    return SYCLGenError();

  // The single input operand is the register vector holding the fragments
  // to be stored.
  const auto *VE = dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(0));
  if (!VE)
    return SYCLGenError();

  // The number of source registers has to agree with the matrix count
  // attribute.
  const auto NumElements = VE->getNumElements();
  if (Inst->hasAttr(InstAttr::x1)) {
    if (NumElements != 1)
      return SYCLGenError();
  } else if (Inst->hasAttr(InstAttr::x2)) {
    if (NumElements != 2)
      return SYCLGenError();
  } else if (Inst->hasAttr(InstAttr::x4)) {
    if (NumElements != 4)
      return SYCLGenError();
  }

  // emitAddressExpr() consults CurrInst, so expose this instruction for the
  // duration of the emission and restore the previous value on exit.
  llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
  CurrInst = Inst;

  const auto *Dst =
      dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getOutputOperand());
  // A missing or non-address destination is a failure. It must be reported
  // like every other failure path here instead of returning a bare 'false'.
  if (!Dst)
    return SYCLGenError();

  OS() << MapNames::getDpctNamespace() << "experimental::matrix::stmatrix(";
  if (emitStmt(Dst))
    return SYCLGenError();

  for (unsigned Idx = 0; Idx != NumElements; ++Idx) {
    // Discarded elements are not passed on to the helper.
    if (isa<InlineAsmDiscardExpr>(VE->getElement(Idx)))
      continue;
    OS() << ", ";
    if (emitStmt(VE->getElement(Idx)))
      return SYCLGenError();
  }

  if (Inst->hasAttr(InstAttr::trans))
    OS() << ", true";
  OS() << ");";

  // The emitted helper is called with a sub-group size request of 32 on the
  // enclosing function.
  const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
  if (KernelDecl) {
    auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
    if (FuncInfo)
      FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
                                       DpctGlobalInfo::getSubGroup(GAS));
  }

  return SYCLGenSuccess();
}

bool handle_mma(const InlineAsmInstruction *Inst) override {
if (Inst->getNumInputOperands() != 3)
return SYCLGenError();
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ class InlineAsmInstruction : public InlineAsmStmt {
/// the rest are input operands.
SmallVector<InlineAsmExpr *, 4> InputOps;

SmallVector<InlineAsmExpr *, 4> OutputOps;

public:
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
SmallVector<AsmStateSpace, 4> AsmStateSpaces,
Expand Down
5 changes: 3 additions & 2 deletions clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,10 +736,11 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
InlineAsmExprResult
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {

// Vector size for ldmatrix are 1, 2, 4
// Vector size for ldmatrix/stmatrix are 1, 2, 4
// size(x) = 2 * sizeof(v).
InlineAsmVectorType::VecKind Kind;
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
if (Opcode->getTokenID() == asmtok::op_ldmatrix ||
Opcode->getTokenID() == asmtok::op_stmatrix) {
switch (Vec.size()) {
case 1:
Kind = InlineAsmVectorType::x1;
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/SrcAPI/APINames_ASM.inc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ ENTRY("sqrt", "sqrt", true, NO_FLAG, P1, "Successful")
ENTRY("st", "st", true, NO_FLAG, P1, "Partial")
ENTRY("stackrestore", "stackrestore", false, NO_FLAG, P1, "Comment")
ENTRY("stacksave", "stacksave", false, NO_FLAG, P1, "Comment")
ENTRY("stmatrix", "stmatrix", false, NO_FLAG, P1, "Comment")
ENTRY("stmatrix", "stmatrix", true, NO_FLAG, P1, "Successful")
ENTRY("sub", "sub", true, NO_FLAG, P1, "Partial")
ENTRY("subc", "subc", false, NO_FLAG, P1, "Comment")
ENTRY("suld", "suld", false, NO_FLAG, P1, "Comment")
Expand Down
162 changes: 161 additions & 1 deletion clang/runtime/dpct-rt/include/dpct/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ max(T1 a, T2 b) {
return sycl::fmax(static_cast<common_t>(a), static_cast<common_t>(b));
}

// pow functions overload.
// pow functions overload.
// Integer exponent: both overloads forward to sycl::pown.
inline float pow(const float a, const int b) { return sycl::pown(a, b); }
inline double pow(const double a, const int b) { return sycl::pown(a, b); }
// float exponent: forwards to sycl::pow.
inline float pow(const float a, const float b) { return sycl::pow(a, b); }
Expand Down Expand Up @@ -2218,6 +2218,166 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
ldmatrix(addr, m4, trans, 3);
}

/// Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
/// local memory per sub-group.
/// Requires the sub-group size of kernel calling this function to be 32.
/// 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
/// work items of sub-group contain the starting address of their respective
/// matrix row in 'addr'.
/// After distributing addresses to other work items, each of the 32 work items
/// store 32-bits (2 packed 16-bit data) into 'm' for a total of 128 bytes.
/// 'trans' specifies to perform a transposed/non-transposed store by each work
/// item like below
/// Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
/// ...
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
/// Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
/// row-0: wi0 wi4 wi8 ... wi28
/// row-1: wi0 wi4 wi8 ... wi28
/// ...
/// row-6: wi3 wi7 wi11 ... wi31
/// row-7: wi3 wi7 wi11 ... wi31
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
/// \param [in] addr The starting address of corresponding matrix row for a work
/// item in local memory
/// \param [in] m The private memory to store the matrix. It points to 2 b16
/// type elements.
/// \param [in] trans Indicates whether the matrix to be stored transposed
/// \param [in] mat The matrix index to be stored
template <typename T>
void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
int lane = sg.get_local_linear_id();

int lane_group8_row = lane / 8;
int lane_group8_col = lane % 8;

if (!trans) {
// calculate the source lane
int src_lane = 2 * lane_group8_row;
if (lane_group8_col >= 4)
src_lane += 1;

// Broadcast the address from the source lane
auto recv_addr_uintp =
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);

// Cast the received address from uintptr_t to the type of 'm'
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);

// Non-transposed store
recv_addr[lane_group8_col % 4] = m;
} else {
// calculate the source lane
int src_lane = (lane % 4) * 2;

// Broadcast the address from the source lane
auto recv_addr_uintp_1 =
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
auto recv_addr_uintp_2 =
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);

// Cast the received address from uintptr_t to 'half *'
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);

// Split the 32-bit value of 'm' into two 16-bits
sycl::half *val = reinterpret_cast<sycl::half *>(&m);

// Transposed store
int index = lane / 4;
recv_addr_1[index] = val[0];
recv_addr_2[index] = val[1];
}
}

/// Collectively stores two 8x8 b16 matrices (256 bytes) held in the private
/// values of a sub-group to local memory.
/// Requires the sub-group size of the kernel calling this function to be 32.
/// The first 16 work items of the sub-group contain the starting address of
/// their respective matrix row in 'addr'.
/// After the row addresses have been distributed, each of the 32 work items
/// writes 64 bits (32 bits per matrix) taken from 'm1' & 'm2', for a total of
/// 256 bytes.
/// 'trans' selects a transposed/non-transposed store by each work item as
/// shown below
/// Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
/// ...
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
/// Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
/// row-0: wi0 wi4 wi8 ... wi28
/// row-1: wi0 wi4 wi8 ... wi28
/// ...
/// row-6: wi3 wi7 wi11 ... wi31
/// row-7: wi3 wi7 wi11 ... wi31
/// \tparam [in] T Type of the value stored by each work item per matrix (a
/// value packing 2 b16 elements)
/// \param [in] addr The starting address of the corresponding matrix row for a
/// work item in local memory
/// \param [in] m1 The private value holding the data of the 1st matrix, passed
/// by value. It packs 2 b16 type elements.
/// \param [in] m2 The private value holding the data of the 2nd matrix, passed
/// by value. It packs 2 b16 type elements.
/// \param [in] trans Indicates whether the matrices are to be stored transposed
template <typename T>
void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
  // Delegate to the single-matrix overload, one matrix index at a time.
  auto store_one = [&](T fragment, unsigned mat) {
    stmatrix(addr, fragment, trans, mat);
  };
  store_one(m1, 0);
  store_one(m2, 1);
}

/// Collectively stores four 8x8 b16 matrices (512 bytes) held in the private
/// values of a sub-group to local memory.
/// Requires the sub-group size of the kernel calling this function to be 32.
/// Each work item of the sub-group contains the starting address of its
/// respective matrix row in 'addr'.
/// After the row addresses have been distributed, each of the 32 work items
/// writes 128 bits (32 bits per matrix) taken from 'm1', 'm2', 'm3' & 'm4',
/// for a total of 512 bytes.
/// 'trans' selects a transposed/non-transposed store by each work item as
/// shown below
/// Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
/// ...
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
/// Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
/// row-0: wi0 wi4 wi8 ... wi28
/// row-1: wi0 wi4 wi8 ... wi28
/// ...
/// row-6: wi3 wi7 wi11 ... wi31
/// row-7: wi3 wi7 wi11 ... wi31
/// \tparam [in] T Type of the value stored by each work item per matrix (a
/// value packing 2 b16 elements)
/// \param [in] addr The starting address of the corresponding matrix row for a
/// work item in local memory
/// \param [in] m1 The private value holding the data of the 1st matrix, passed
/// by value. It packs 2 b16 type elements.
/// \param [in] m2 The private value holding the data of the 2nd matrix, passed
/// by value. It packs 2 b16 type elements.
/// \param [in] m3 The private value holding the data of the 3rd matrix, passed
/// by value. It packs 2 b16 type elements.
/// \param [in] m4 The private value holding the data of the 4th matrix, passed
/// by value. It packs 2 b16 type elements.
/// \param [in] trans Indicates whether the matrices are to be stored transposed
template <typename T>
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false) {
  // Delegate to the single-matrix overload, one matrix index at a time.
  auto store_one = [&](T fragment, unsigned mat) {
    stmatrix(addr, fragment, trans, mat);
  };
  store_one(m1, 0);
  store_one(m2, 1);
  store_one(m3, 2);
  store_one(m4, 3);
}

/// A helper struct that defines the pack type for the input matrix fragments
/// of mma() function based on the type of input matrix fragments.
/// The MMAType struct is specialized for different types of input matrices.
Expand Down
Loading