Skip to content

Commit 5fa2416

Browse files
committed
[AArch64][SME] Add SME read/write intrinsics that map to the mova instruction
This patch adds implementations for the read/write SME ACLE intrinsics: @llvm.aarch64.sme.read.horiz, @llvm.aarch64.sme.read.vert, @llvm.aarch64.sme.write.horiz and @llvm.aarch64.sme.write.vert. These all map to the SME mova instruction. Differential Revision: https://reviews.llvm.org/D127414
1 parent 4c2bccf commit 5fa2416

File tree

6 files changed

+1175
-0
lines changed

6 files changed

+1175
-0
lines changed

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2631,4 +2631,22 @@ let TargetPrefix = "aarch64" in {
26312631
[], [llvm_i32_ty, llvm_ptr_ty]>;
26322632
def int_aarch64_sme_str : DefaultAttrsIntrinsic<
26332633
[], [llvm_i32_ty, llvm_ptr_ty]>;
2634+
2635+
// Signature shared by the SME "read" intrinsics defined below.
// Result:   an overloaded vector type.
// Operands: a vector of the same type as the result, a predicate vector with
//           i1 elements and the same element count, an i64 and an i32.
// The selection patterns in SMEInstrFormats.td bind these operands, in order,
// as $passthru, $pg, $tile and the slice index.
class SME_TileToVector_Intrinsic
2636+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
2637+
[LLVMMatchType<0>, LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, llvm_i64_ty, llvm_i32_ty]>;
2638+
// Signature shared by the SME "write" intrinsics defined below.
// Result:   none.
// Operands: an i64, an i32, a predicate vector with i1 elements and the same
//           element count as the data vector, and the overloaded data vector.
// The selection patterns in SMEInstrFormats.td bind these operands, in order,
// as $tile, the slice index, $pg and $zn.
class SME_VectorToTile_Intrinsic
2639+
: DefaultAttrsIntrinsic<[],
2640+
[llvm_i64_ty, llvm_i32_ty, LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
2641+
llvm_anyvector_ty]>;
2642+
2643+
def int_aarch64_sme_read_horiz : SME_TileToVector_Intrinsic;
2644+
def int_aarch64_sme_read_vert : SME_TileToVector_Intrinsic;
2645+
def int_aarch64_sme_write_horiz : SME_VectorToTile_Intrinsic;
2646+
def int_aarch64_sme_write_vert : SME_VectorToTile_Intrinsic;
2647+
2648+
def int_aarch64_sme_readq_horiz : SME_TileToVector_Intrinsic;
2649+
def int_aarch64_sme_readq_vert : SME_TileToVector_Intrinsic;
2650+
def int_aarch64_sme_writeq_horiz : SME_VectorToTile_Intrinsic;
2651+
def int_aarch64_sme_writeq_vert : SME_VectorToTile_Intrinsic;
26342652
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,6 +2357,24 @@ AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const {
23572357
return BB;
23582358
}
23592359

2360+
/// Replaces the insert pseudo-instruction \p MI with the real instruction
/// \p Opc, built immediately before \p MI in \p BB.
///
/// \param Opc     Opcode of the instruction to build.
/// \param BaseReg First register of the tile register group; the immediate in
///                operand 0 of \p MI is added to it to pick the tile register.
/// \param MI      The pseudo being expanded. Its operands are read as
///                (tile immediate, slice index register, slice index offset,
///                pg, zn). It is erased before returning.
/// \param BB      The block containing \p MI.
/// \returns \p BB; no new blocks are created.
MachineBasicBlock *
2361+
AArch64TargetLowering::EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg,
2362+
MachineInstr &MI,
2363+
MachineBasicBlock *BB) const {
2364+
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2365+
MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
2366+
2367+
// The selected tile register is added twice: first as a definition, then as
// a use. NOTE(review): presumably because the instruction only updates one
// slice and so also reads the existing tile contents -- confirm.
MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
2368+
MIB.addReg(BaseReg + MI.getOperand(0).getImm());
2369+
MIB.add(MI.getOperand(1)); // Slice index register
2370+
MIB.add(MI.getOperand(2)); // Slice index offset
2371+
MIB.add(MI.getOperand(3)); // pg
2372+
MIB.add(MI.getOperand(4)); // zn
2373+
2374+
MI.eraseFromParent(); // The pseudo is gone now.
2375+
return BB;
2376+
}
2377+
23602378
MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
23612379
MachineInstr &MI, MachineBasicBlock *BB) const {
23622380
switch (MI.getOpcode()) {
@@ -2409,6 +2427,36 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
24092427
return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
24102428
case AArch64::LDR_ZA_PSEUDO:
24112429
return EmitFill(MI, BB);
2430+
case AArch64::INSERT_MXIPZ_H_PSEUDO_B:
2431+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_B, AArch64::ZAB0, MI,
2432+
BB);
2433+
case AArch64::INSERT_MXIPZ_H_PSEUDO_H:
2434+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_H, AArch64::ZAH0, MI,
2435+
BB);
2436+
case AArch64::INSERT_MXIPZ_H_PSEUDO_S:
2437+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_S, AArch64::ZAS0, MI,
2438+
BB);
2439+
case AArch64::INSERT_MXIPZ_H_PSEUDO_D:
2440+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_D, AArch64::ZAD0, MI,
2441+
BB);
2442+
case AArch64::INSERT_MXIPZ_H_PSEUDO_Q:
2443+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_H_Q, AArch64::ZAQ0, MI,
2444+
BB);
2445+
case AArch64::INSERT_MXIPZ_V_PSEUDO_B:
2446+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_B, AArch64::ZAB0, MI,
2447+
BB);
2448+
case AArch64::INSERT_MXIPZ_V_PSEUDO_H:
2449+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_H, AArch64::ZAH0, MI,
2450+
BB);
2451+
case AArch64::INSERT_MXIPZ_V_PSEUDO_S:
2452+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_S, AArch64::ZAS0, MI,
2453+
BB);
2454+
case AArch64::INSERT_MXIPZ_V_PSEUDO_D:
2455+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_D, AArch64::ZAD0, MI,
2456+
BB);
2457+
case AArch64::INSERT_MXIPZ_V_PSEUDO_Q:
2458+
return EmitInsertVectorToTile(AArch64::INSERT_MXIPZ_V_Q, AArch64::ZAQ0, MI,
2459+
BB);
24122460
}
24132461
}
24142462

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,10 @@ class AArch64TargetLowering : public TargetLowering {
561561
MachineBasicBlock *BB) const;
562562
MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
563563

564+
/// Expands the insert pseudo \p MI into the instruction \p Opc. The tile
/// register is \p BaseReg plus the immediate held in operand 0 of \p MI.
/// Erases \p MI and returns \p BB.
MachineBasicBlock *EmitInsertVectorToTile(unsigned Opc, unsigned BaseReg,
565+
MachineInstr &MI,
566+
MachineBasicBlock *BB) const;
567+
564568
MachineBasicBlock *
565569
EmitInstrWithCustomInserter(MachineInstr &MI,
566570
MachineBasicBlock *MBB) const override;

llvm/lib/Target/AArch64/SMEInstrFormats.td

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,30 @@ multiclass sme_vector_to_tile_aliases<Instruction inst,
622622
(inst tile_ty:$ZAd, MatrixIndexGPR32Op12_15:$Rv, imm_ty:$imm, PPR3bAny:$Pg, zpr_ty:$Zn), 1>;
623623
}
624624

625+
// Selection patterns that map the vector-to-tile intrinsic `op` onto `inst`.
//   inst      - instruction to emit; the instantiations in this file pass the
//               _PSEUDO_* insert pseudos.
//   zpr_vt    - value type of the data vector ($zn).
//   ppr_vt    - value type of the governing predicate ($pg).
//   imm_ty    - operand type of the tile immediate ($tile).
//   offset_ty - operand type of the slice index offset ($imm).
//   op        - intrinsic being matched.
//   tileslice - complex pattern used to match the slice index.
//               NOTE(review): presumably it splits the index into a base
//               register plus an immediate offset -- confirm at its definition.
multiclass sme_vector_to_tile_patterns<Instruction inst, ValueType zpr_vt,
626+
ValueType ppr_vt, Operand imm_ty,
627+
Operand offset_ty,
628+
SDPatternOperator op,
629+
ComplexPattern tileslice> {
630+
// Base form: the slice index is a plain register, so the offset is 0.
def : Pat<(op imm_ty:$tile, MatrixIndexGPR32Op12_15:$idx,
631+
(ppr_vt PPR3bAny:$pg), (zpr_vt ZPRAny:$zn)),
632+
(inst imm_ty:$tile, $idx, 0, $pg, $zn)>;
633+
// Register + offset form; the added complexity gives it priority over the
// base form when both match.
let AddedComplexity = 1 in {
634+
def : Pat<(op imm_ty:$tile, (i32 (tileslice MatrixIndexGPR32Op12_15:$idx,
635+
offset_ty:$imm)),
636+
(ppr_vt PPR3bAny:$pg), (zpr_vt ZPRAny:$zn)),
637+
(inst imm_ty:$tile, $idx, $imm, $pg, $zn)>;
638+
}
639+
}
640+
641+
// Pseudo instruction for the vector-to-tile insert. It has no outputs and
// takes the tile as an i64 immediate ($tile) rather than a tile register,
// followed by the slice index register ($idx), the slice index offset ($imm),
// the predicate ($pg) and the source vector ($zn).
class sme_mova_insert_pseudo
642+
: Pseudo<(outs), (ins i64imm:$tile, MatrixIndexGPR32Op12_15:$idx,
643+
i64imm:$imm, PPR3bAny:$pg, ZPRAny:$zn), []>,
644+
Sched<[]> {
645+
// Translated to the actual instructions in AArch64ISelLowering.cpp
646+
let usesCustomInserter = 1;
647+
}
648+
625649
multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
626650
def _B : sme_vector_to_tile_inst<0b0, 0b00, !if(is_col, TileVectorOpV8,
627651
TileVectorOpH8),
@@ -661,6 +685,14 @@ multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
661685
let Inst{3-0} = ZAd;
662686
}
663687

688+
// Pseudo instructions for lowering intrinsics, using immediates instead of
689+
// tile registers.
690+
def _PSEUDO_B : sme_mova_insert_pseudo;
691+
def _PSEUDO_H : sme_mova_insert_pseudo;
692+
def _PSEUDO_S : sme_mova_insert_pseudo;
693+
def _PSEUDO_D : sme_mova_insert_pseudo;
694+
def _PSEUDO_Q : sme_mova_insert_pseudo;
695+
664696
defm : sme_vector_to_tile_aliases<!cast<Instruction>(NAME # _B),
665697
!if(is_col, TileVectorOpV8,
666698
TileVectorOpH8),
@@ -681,6 +713,62 @@ multiclass sme_vector_v_to_tile<string mnemonic, bit is_col> {
681713
!if(is_col, TileVectorOpV128,
682714
TileVectorOpH128),
683715
ZPR128, sme_elm_idx0_0>;
716+
717+
defvar op = !if(is_col, int_aarch64_sme_write_vert,
718+
int_aarch64_sme_write_horiz);
719+
720+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_B),
721+
nxv16i8, nxv16i1, sme_elm_idx0_0, imm0_15,
722+
op, tileslice8>;
723+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_H),
724+
nxv8i16, nxv8i1, sme_elm_idx0_1, imm0_7,
725+
op, tileslice16>;
726+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_H),
727+
nxv8f16, nxv8i1, sme_elm_idx0_1, imm0_7,
728+
op, tileslice16>;
729+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_H),
730+
nxv8bf16, nxv8i1, sme_elm_idx0_1, imm0_7,
731+
op, tileslice16>;
732+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_S),
733+
nxv4i32, nxv4i1, sme_elm_idx0_3, imm0_3,
734+
op, tileslice32>;
735+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_S),
736+
nxv4f32, nxv4i1, sme_elm_idx0_3, imm0_3,
737+
op, tileslice32>;
738+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_D),
739+
nxv2i64, nxv2i1, sme_elm_idx0_7, imm0_1,
740+
op, tileslice64>;
741+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_D),
742+
nxv2f64, nxv2i1, sme_elm_idx0_7, imm0_1,
743+
op, tileslice64>;
744+
745+
defvar opq = !if(is_col, int_aarch64_sme_writeq_vert,
746+
int_aarch64_sme_writeq_horiz);
747+
748+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
749+
nxv16i8, nxv16i1, sme_elm_idx0_15,
750+
sme_elm_idx0_0, opq, tileslice128>;
751+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
752+
nxv8i16, nxv8i1, sme_elm_idx0_15,
753+
sme_elm_idx0_0, opq, tileslice128>;
754+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
755+
nxv8f16, nxv8i1, sme_elm_idx0_15,
756+
sme_elm_idx0_0, opq, tileslice128>;
757+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
758+
nxv8bf16, nxv8i1, sme_elm_idx0_15,
759+
sme_elm_idx0_0, opq, tileslice128>;
760+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
761+
nxv4i32, nxv4i1, sme_elm_idx0_15,
762+
sme_elm_idx0_0, opq, tileslice128>;
763+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
764+
nxv4f32, nxv4i1, sme_elm_idx0_15,
765+
sme_elm_idx0_0, opq, tileslice128>;
766+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
767+
nxv2i64, nxv2i1, sme_elm_idx0_15,
768+
sme_elm_idx0_0, opq, tileslice128>;
769+
defm : sme_vector_to_tile_patterns<!cast<Instruction>(NAME # _PSEUDO_Q),
770+
nxv2f64, nxv2i1, sme_elm_idx0_15,
771+
sme_elm_idx0_0, opq, tileslice128>;
684772
}
685773

686774
multiclass sme_vector_to_tile<string mnemonic> {
@@ -722,6 +810,23 @@ multiclass sme_tile_to_vector_aliases<Instruction inst, ZPRRegOp zpr_ty,
722810
(inst zpr_ty:$Zd, PPR3bAny:$Pg, tile_ty:$ZAn, MatrixIndexGPR32Op12_15:$Rv, imm_ty:$imm), 1>;
723811
}
724812

813+
// Selection patterns that map the tile-to-vector intrinsic `op` onto `inst`.
//   inst      - instruction to emit.
//   zpr_vt    - value type of the result and of the $passthru vector.
//   ppr_vt    - value type of the governing predicate ($pg).
//   offset_ty - operand type of the slice index offset ($imm).
//   imm2tile  - complex pattern applied to the tile operand ($tile).
//               NOTE(review): presumably converts the tile immediate into a
//               tile register -- confirm at its definition.
//   tileslice - complex pattern used to match the slice index.
//               NOTE(review): presumably it splits the index into a base
//               register plus an immediate offset -- confirm at its definition.
//   op        - intrinsic being matched.
multiclass sme_tile_to_vector_patterns<Instruction inst, ValueType zpr_vt,
814+
ValueType ppr_vt, Operand offset_ty,
815+
ComplexPattern imm2tile,
816+
ComplexPattern tileslice,
817+
SDPatternOperator op> {
818+
// Base form: the slice index is a plain register, so the offset is 0.
def : Pat<(zpr_vt (op (zpr_vt ZPRAny:$passthru), (ppr_vt PPR3bAny:$pg),
819+
(imm2tile untyped:$tile), MatrixIndexGPR32Op12_15:$idx)),
820+
(inst $passthru, $pg, $tile, $idx, 0)>;
821+
// Register + offset form; the added complexity gives it priority over the
// base form when both match.
let AddedComplexity = 1 in {
822+
def : Pat<(zpr_vt (op (zpr_vt ZPRAny:$passthru), (ppr_vt PPR3bAny:$pg),
823+
(imm2tile untyped:$tile),
824+
(i32 (tileslice MatrixIndexGPR32Op12_15:$idx,
825+
offset_ty:$imm)))),
826+
(inst $passthru, $pg, $tile, $idx, $imm)>;
827+
}
828+
}
829+
725830
multiclass sme_tile_to_vector_v<string mnemonic, bit is_col> {
726831
def _B : sme_tile_to_vector_inst<0b0, 0b00, ZPR8, !if(is_col, TileVectorOpV8,
727832
TileVectorOpH8),
@@ -775,6 +880,62 @@ multiclass sme_tile_to_vector_v<string mnemonic, bit is_col> {
775880
defm : sme_tile_to_vector_aliases<!cast<Instruction>(NAME # _Q), ZPR128,
776881
!if(is_col, TileVectorOpV128,
777882
TileVectorOpH128), sme_elm_idx0_0>;
883+
884+
defvar op = !if(is_col, int_aarch64_sme_read_vert,
885+
int_aarch64_sme_read_horiz);
886+
887+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _B),
888+
nxv16i8, nxv16i1, imm0_15,
889+
imm_to_tile8, tileslice8, op>;
890+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _H),
891+
nxv8i16, nxv8i1, imm0_7,
892+
imm_to_tile16, tileslice16, op>;
893+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _H),
894+
nxv8f16, nxv8i1, imm0_7,
895+
imm_to_tile16, tileslice16, op>;
896+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _H),
897+
nxv8bf16, nxv8i1, imm0_7,
898+
imm_to_tile16, tileslice16, op>;
899+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _S),
900+
nxv4i32, nxv4i1, imm0_3,
901+
imm_to_tile32, tileslice32, op>;
902+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _S),
903+
nxv4f32, nxv4i1, imm0_3,
904+
imm_to_tile32, tileslice32, op>;
905+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _D),
906+
nxv2i64, nxv2i1, imm0_1,
907+
imm_to_tile64, tileslice64, op>;
908+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _D),
909+
nxv2f64, nxv2i1, imm0_1,
910+
imm_to_tile64, tileslice64, op>;
911+
912+
defvar opq = !if(is_col, int_aarch64_sme_readq_vert,
913+
int_aarch64_sme_readq_horiz);
914+
915+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
916+
nxv16i8, nxv16i1, sme_elm_idx0_0,
917+
imm_to_tile128, tileslice128, opq>;
918+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
919+
nxv8i16, nxv8i1, sme_elm_idx0_0,
920+
imm_to_tile128, tileslice128, opq>;
921+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
922+
nxv8f16, nxv8i1, sme_elm_idx0_0,
923+
imm_to_tile128, tileslice128, opq>;
924+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
925+
nxv8bf16, nxv8i1, sme_elm_idx0_0,
926+
imm_to_tile128, tileslice128, opq>;
927+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
928+
nxv4i32, nxv4i1, sme_elm_idx0_0,
929+
imm_to_tile128, tileslice128, opq>;
930+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
931+
nxv4f32, nxv4i1, sme_elm_idx0_0,
932+
imm_to_tile128, tileslice128, opq>;
933+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
934+
nxv2i64, nxv2i1, sme_elm_idx0_0,
935+
imm_to_tile128, tileslice128, opq>;
936+
defm : sme_tile_to_vector_patterns<!cast<Instruction>(NAME # _Q),
937+
nxv2f64, nxv2i1, sme_elm_idx0_0,
938+
imm_to_tile128, tileslice128, opq>;
778939
}
779940

780941
multiclass sme_tile_to_vector<string mnemonic> {

0 commit comments

Comments
 (0)