Skip to content

Commit 6ee3751

Browse files
authored
[RISCV] Correct type lowering of struct of fixed-vector array in VLS (#147173)
Currently, a struct containing an array of fixed-length vectors is always flattened and lowered to a scalable vector. However, only a struct whose array holds exactly one fixed-length vector should be lowered that way; a struct whose array holds more than one fixed-length vector should be lowered to a vector tuple type. See the corresponding psABI change: https://github.com/riscv-non-isa/riscv-elf-psabi-doc/pull/418/files#diff-3a934f00cffdb3e509722753126a2cf6082a7648ab3b9ca8cbb0e84f8a6a12edR555-R558
1 parent d14062e commit 6ee3751

File tree

3 files changed

+67
-91
lines changed

3 files changed

+67
-91
lines changed

clang/lib/CodeGen/Targets/RISCV.cpp

Lines changed: 59 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -441,98 +441,74 @@ bool RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty, unsigned ABIVLen,
441441
// __attribute__((vector_size(64))) int d;
442442
// }
443443
//
444-
// Struct of 1 fixed-length vector is passed as a scalable vector.
445-
// Struct of >1 fixed-length vectors are passed as vector tuple.
446-
// Struct of 1 array of fixed-length vectors is passed as a scalable vector.
447-
// Otherwise, pass the struct indirectly.
448-
449-
if (llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty))) {
450-
unsigned NumElts = STy->getStructNumElements();
451-
if (NumElts > 8)
452-
return false;
444+
// 1. Struct of 1 fixed-length vector is passed as a scalable vector.
445+
// 2. Struct of >1 fixed-length vectors are passed as vector tuple.
446+
// 3. Struct of an array with 1 element of fixed-length vectors is passed as a
447+
// scalable vector.
448+
// 4. Struct of an array with >1 elements of fixed-length vectors is passed as
449+
// vector tuple.
450+
// 5. Otherwise, pass the struct indirectly.
451+
452+
llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty));
453+
if (!STy)
454+
return false;
453455

454-
auto *FirstEltTy = STy->getElementType(0);
455-
if (!STy->containsHomogeneousTypes())
456-
return false;
456+
unsigned NumElts = STy->getStructNumElements();
457+
if (NumElts > 8)
458+
return false;
457459

458-
// Check structure of fixed-length vectors and turn them into vector tuple
459-
// type if legal.
460-
if (auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy)) {
461-
if (NumElts == 1) {
462-
// Handle single fixed-length vector.
463-
VLSType = llvm::ScalableVectorType::get(
464-
FixedVecTy->getElementType(),
465-
llvm::divideCeil(FixedVecTy->getNumElements() *
466-
llvm::RISCV::RVVBitsPerBlock,
467-
ABIVLen));
468-
// Check registers needed <= 8.
469-
return llvm::divideCeil(
470-
FixedVecTy->getNumElements() *
471-
FixedVecTy->getElementType()->getScalarSizeInBits(),
472-
ABIVLen) <= 8;
473-
}
474-
// LMUL
475-
// = fixed-length vector size / ABIVLen
476-
// = 8 * I8EltCount / RVVBitsPerBlock
477-
// =>
478-
// I8EltCount
479-
// = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
480-
unsigned I8EltCount = llvm::divideCeil(
481-
FixedVecTy->getNumElements() *
482-
FixedVecTy->getElementType()->getScalarSizeInBits() *
483-
llvm::RISCV::RVVBitsPerBlock,
484-
ABIVLen * 8);
485-
VLSType = llvm::TargetExtType::get(
486-
getVMContext(), "riscv.vector.tuple",
487-
llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
488-
I8EltCount),
489-
NumElts);
490-
// Check registers needed <= 8.
491-
return NumElts *
492-
llvm::divideCeil(
493-
FixedVecTy->getNumElements() *
494-
FixedVecTy->getElementType()->getScalarSizeInBits(),
495-
ABIVLen) <=
496-
8;
497-
}
460+
auto *FirstEltTy = STy->getElementType(0);
461+
if (!STy->containsHomogeneousTypes())
462+
return false;
498463

499-
// If elements are not fixed-length vectors, it should be an array.
464+
if (auto *ArrayTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
465+
// Only struct of single array is accepted
500466
if (NumElts != 1)
501467
return false;
468+
FirstEltTy = ArrayTy->getArrayElementType();
469+
NumElts = ArrayTy->getNumElements();
470+
}
502471

503-
// Check array of fixed-length vector and turn it into scalable vector type
504-
// if legal.
505-
if (auto *ArrTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
506-
unsigned NumArrElt = ArrTy->getNumElements();
507-
if (NumArrElt > 8)
508-
return false;
472+
auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy);
473+
if (!FixedVecTy)
474+
return false;
509475

510-
auto *ArrEltTy = dyn_cast<llvm::FixedVectorType>(ArrTy->getElementType());
511-
if (!ArrEltTy)
512-
return false;
476+
// Check registers needed <= 8.
477+
if (NumElts * llvm::divideCeil(
478+
FixedVecTy->getNumElements() *
479+
FixedVecTy->getElementType()->getScalarSizeInBits(),
480+
ABIVLen) >
481+
8)
482+
return false;
513483

514-
// LMUL
515-
// = NumArrElt * fixed-length vector size / ABIVLen
516-
// = fixed-length vector elt size * ScalVecNumElts / RVVBitsPerBlock
517-
// =>
518-
// ScalVecNumElts
519-
// = (NumArrElt * fixed-length vector size * RVVBitsPerBlock) /
520-
// (ABIVLen * fixed-length vector elt size)
521-
// = NumArrElt * num fixed-length vector elt * RVVBitsPerBlock /
522-
// ABIVLen
523-
unsigned ScalVecNumElts = llvm::divideCeil(
524-
NumArrElt * ArrEltTy->getNumElements() * llvm::RISCV::RVVBitsPerBlock,
525-
ABIVLen);
526-
VLSType = llvm::ScalableVectorType::get(ArrEltTy->getElementType(),
527-
ScalVecNumElts);
528-
// Check registers needed <= 8.
529-
return llvm::divideCeil(
530-
ScalVecNumElts *
531-
ArrEltTy->getElementType()->getScalarSizeInBits(),
532-
llvm::RISCV::RVVBitsPerBlock) <= 8;
533-
}
484+
// Turn them into scalable vector type or vector tuple type if legal.
485+
if (NumElts == 1) {
486+
// Handle single fixed-length vector.
487+
VLSType = llvm::ScalableVectorType::get(
488+
FixedVecTy->getElementType(),
489+
llvm::divideCeil(FixedVecTy->getNumElements() *
490+
llvm::RISCV::RVVBitsPerBlock,
491+
ABIVLen));
492+
return true;
534493
}
535-
return false;
494+
495+
// LMUL
496+
// = fixed-length vector size / ABIVLen
497+
// = 8 * I8EltCount / RVVBitsPerBlock
498+
// =>
499+
// I8EltCount
500+
// = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
501+
unsigned I8EltCount =
502+
llvm::divideCeil(FixedVecTy->getNumElements() *
503+
FixedVecTy->getElementType()->getScalarSizeInBits() *
504+
llvm::RISCV::RVVBitsPerBlock,
505+
ABIVLen * 8);
506+
VLSType = llvm::TargetExtType::get(
507+
getVMContext(), "riscv.vector.tuple",
508+
llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
509+
I8EltCount),
510+
NumElts);
511+
return true;
536512
}
537513

538514
// Fixed-length RVV vectors are represented as scalable vectors in function

clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,14 @@ void __attribute__((riscv_vls_cc)) test_st_i32x4_arr1(struct st_i32x4_arr1 arg)
153153
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr1_256(<vscale x 1 x i32> %arg)
154154
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr1_256(struct st_i32x4_arr1 arg) {}
155155

156-
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr4(<vscale x 8 x i32> %arg)
156+
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 8 x i8>, 4) %arg)
157157
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr4(struct st_i32x4_arr4 arg) {}
158-
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr4_256(<vscale x 4 x i32> %arg)
158+
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr4_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 4) %arg)
159159
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr4_256(struct st_i32x4_arr4 arg) {}
160160

161-
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr8(<vscale x 16 x i32> %arg)
161+
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 8 x i8>, 8) %arg)
162162
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr8(struct st_i32x4_arr8 arg) {}
163-
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr8_256(<vscale x 8 x i32> %arg)
163+
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @test_st_i32x4_arr8_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 8) %arg)
164164
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr8_256(struct st_i32x4_arr8 arg) {}
165165

166166
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @test_st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)

clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,14 @@ typedef int __attribute__((vector_size(256))) int32x64_t;
133133
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr1_25613st_i32x4_arr1(<vscale x 1 x i32> %arg)
134134
[[riscv::vls_cc(256)]] void test_st_i32x4_arr1_256(struct st_i32x4_arr1 arg) {}
135135

136-
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr413st_i32x4_arr4(<vscale x 8 x i32> %arg)
136+
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr413st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 8 x i8>, 4) %arg)
137137
[[riscv::vls_cc]] void test_st_i32x4_arr4(struct st_i32x4_arr4 arg) {}
138-
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr4_25613st_i32x4_arr4(<vscale x 4 x i32> %arg)
138+
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr4_25613st_i32x4_arr4(target("riscv.vector.tuple", <vscale x 4 x i8>, 4) %arg)
139139
[[riscv::vls_cc(256)]] void test_st_i32x4_arr4_256(struct st_i32x4_arr4 arg) {}
140140

141-
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr813st_i32x4_arr8(<vscale x 16 x i32> %arg)
141+
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z18test_st_i32x4_arr813st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 8 x i8>, 8) %arg)
142142
[[riscv::vls_cc]] void test_st_i32x4_arr8(struct st_i32x4_arr8 arg) {}
143-
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr8_25613st_i32x4_arr8(<vscale x 8 x i32> %arg)
143+
// CHECK-LLVM: define dso_local riscv_vls_cc(256) void @_Z22test_st_i32x4_arr8_25613st_i32x4_arr8(target("riscv.vector.tuple", <vscale x 4 x i8>, 8) %arg)
144144
[[riscv::vls_cc(256)]] void test_st_i32x4_arr8_256(struct st_i32x4_arr8 arg) {}
145145

146146
// CHECK-LLVM: define dso_local riscv_vls_cc(128) void @_Z15test_st_i32x4x210st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)

0 commit comments

Comments
 (0)