Skip to content

Commit 556c846

Browse files
authored
[msan] Add handler for llvm.x86.avx512.mask.cvtps2dq.512 (#147377)
Propagate the shadow according to the writemask, instead of using the default strict handler. Updates the test added in #123980
1 parent cb7b069 commit 556c846

File tree

2 files changed

+70
-25
lines changed

2 files changed

+70
-25
lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4342,6 +4342,61 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
43424342
setOriginForNaryOp(I);
43434343
}
43444344

4345+
// e.g., call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512
4346+
// (<16 x float> a, <16 x i32> writethru, i16 mask,
4347+
// i32 rounding)
4348+
//
4349+
// dst[i] = mask[i] ? convert(a[i]) : writethru[i]
4350+
// dst_shadow[i] = mask[i] ? all_or_nothing(a_shadow[i]) : writethru_shadow[i]
4351+
// where all_or_nothing(x) is fully uninitialized if x has any
4352+
// uninitialized bits
4353+
void handleAVX512VectorConvertFPToInt(IntrinsicInst &I) {
4354+
IRBuilder<> IRB(&I);
4355+
4356+
assert(I.arg_size() == 4);
4357+
Value *A = I.getOperand(0);
4358+
Value *WriteThrough = I.getOperand(1);
4359+
Value *Mask = I.getOperand(2);
4360+
[[maybe_unused]] Value *RoundingMode = I.getOperand(3);
4361+
4362+
assert(isa<FixedVectorType>(A->getType()));
4363+
assert(A->getType()->isFPOrFPVectorTy());
4364+
4365+
assert(isa<FixedVectorType>(WriteThrough->getType()));
4366+
assert(WriteThrough->getType()->isIntOrIntVectorTy());
4367+
4368+
unsigned ANumElements =
4369+
cast<FixedVectorType>(A->getType())->getNumElements();
4370+
assert(ANumElements ==
4371+
cast<FixedVectorType>(WriteThrough->getType())->getNumElements());
4372+
4373+
assert(Mask->getType()->isIntegerTy());
4374+
assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
4375+
4376+
assert(RoundingMode->getType()->isIntegerTy());
4377+
4378+
assert(I.getType() == WriteThrough->getType());
4379+
4380+
// Convert i16 mask to <16 x i1>
4381+
Mask = IRB.CreateBitCast(
4382+
Mask, FixedVectorType::get(IRB.getInt1Ty(), ANumElements));
4383+
4384+
Value *AShadow = getShadow(A);
4385+
/// For scalars:
4386+
/// Since they are converting from floating-point, the output is:
4387+
/// - fully uninitialized if *any* bit of the input is uninitialized
4388+
/// - fully initialized if all bits of the input are initialized
4389+
/// We apply the same principle on a per-element basis for vectors.
4390+
AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(A)),
4391+
getShadowTy(A));
4392+
4393+
Value *WriteThroughShadow = getShadow(WriteThrough);
4394+
Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
4395+
4396+
setShadow(&I, Shadow);
4397+
setOriginForNaryOp(I);
4398+
}
4399+
43454400
// Instrument BMI / BMI2 intrinsics.
43464401
// All of these intrinsics are Z = I(X, Y)
43474402
// where the types of all operands and the result match, and are either i32 or
@@ -5318,6 +5373,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
53185373
handleAVXVpermi2var(I);
53195374
break;
53205375

5376+
case Intrinsic::x86_avx512_mask_cvtps2dq_512: {
5377+
handleAVX512VectorConvertFPToInt(I);
5378+
break;
5379+
}
5380+
53215381
case Intrinsic::x86_avx512fp16_mask_add_sh_round:
53225382
case Intrinsic::x86_avx512fp16_mask_sub_sh_round:
53235383
case Intrinsic::x86_avx512fp16_mask_mul_sh_round:

llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8152,34 +8152,19 @@ define <16 x i32>@test_int_x86_avx512_mask_cvt_ps2dq_512(<16 x float> %x0, <16 x
81528152
; CHECK-LABEL: @test_int_x86_avx512_mask_cvt_ps2dq_512(
81538153
; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
81548154
; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
8155-
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
81568155
; CHECK-NEXT: call void @llvm.donothing()
8157-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
8158-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
8159-
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <16 x i32> [[TMP2]] to i512
8160-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
8161-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
8162-
; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i16 [[TMP3]], 0
8163-
; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
8164-
; CHECK-NEXT: br i1 [[_MSOR3]], label [[TMP6:%.*]], label [[TMP7:%.*]], !prof [[PROF1]]
8165-
; CHECK: 6:
8166-
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
8167-
; CHECK-NEXT: unreachable
8168-
; CHECK: 7:
8169-
; CHECK-NEXT: [[RES:%.*]] = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> [[X0:%.*]], <16 x i32> [[X1:%.*]], i16 [[X2:%.*]], i32 10)
8170-
; CHECK-NEXT: [[TMP8:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
8171-
; CHECK-NEXT: [[_MSCMP4:%.*]] = icmp ne i512 [[TMP8]], 0
8172-
; CHECK-NEXT: [[TMP9:%.*]] = bitcast <16 x i32> [[TMP2]] to i512
8173-
; CHECK-NEXT: [[_MSCMP5:%.*]] = icmp ne i512 [[TMP9]], 0
8174-
; CHECK-NEXT: [[_MSOR6:%.*]] = or i1 [[_MSCMP4]], [[_MSCMP5]]
8175-
; CHECK-NEXT: br i1 [[_MSOR6]], label [[TMP10:%.*]], label [[TMP11:%.*]], !prof [[PROF1]]
8176-
; CHECK: 10:
8177-
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
8178-
; CHECK-NEXT: unreachable
8179-
; CHECK: 11:
8156+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast i16 [[X2:%.*]] to <16 x i1>
8157+
; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
8158+
; CHECK-NEXT: [[TMP5:%.*]] = sext <16 x i1> [[TMP4]] to <16 x i32>
8159+
; CHECK-NEXT: [[TMP6:%.*]] = select <16 x i1> [[TMP3]], <16 x i32> [[TMP5]], <16 x i32> [[TMP2]]
8160+
; CHECK-NEXT: [[RES:%.*]] = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> [[X0:%.*]], <16 x i32> [[X1:%.*]], i16 [[X2]], i32 10)
8161+
; CHECK-NEXT: [[TMP7:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
8162+
; CHECK-NEXT: [[TMP8:%.*]] = sext <16 x i1> [[TMP7]] to <16 x i32>
8163+
; CHECK-NEXT: [[TMP9:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP8]], <16 x i32> [[TMP2]]
81808164
; CHECK-NEXT: [[RES1:%.*]] = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> [[X0]], <16 x i32> [[X1]], i16 -1, i32 8)
8165+
; CHECK-NEXT: [[_MSPROP:%.*]] = or <16 x i32> [[TMP6]], [[TMP9]]
81818166
; CHECK-NEXT: [[RES2:%.*]] = add <16 x i32> [[RES]], [[RES1]]
8182-
; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
8167+
; CHECK-NEXT: store <16 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8
81838168
; CHECK-NEXT: ret <16 x i32> [[RES2]]
81848169
;
81858170
%res = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> %x0, <16 x i32> %x1, i16 %x2, i32 10)

0 commit comments

Comments
 (0)