Skip to content

Commit 17d1fbb

Browse files
committed
revert arg checks for other intrinsics
1 parent 86d2f84 commit 17d1fbb

File tree

8 files changed

+103
-87
lines changed

8 files changed

+103
-87
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2791,11 +2791,6 @@ class Sema final : public SemaBase {
27912791

27922792
void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
27932793

2794-
/// CheckVectorArgs - Check that the arguments of a vector function call
2795-
bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
2796-
2797-
bool CheckVectorArgs(CallExpr *TheCall);
2798-
27992794
bool CheckAllArgTypesAreCorrect(
28002795
Sema *S, CallExpr *TheCall,
28012796
llvm::ArrayRef<
@@ -2806,15 +2801,15 @@ class Sema final : public SemaBase {
28062801
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
28072802

28082803
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
2809-
int ArgOrdinal,
2810-
clang::QualType PassedType);
2811-
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
28122804
int ArgOrdinal,
28132805
clang::QualType PassedType);
2806+
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
2807+
int ArgOrdinal,
2808+
clang::QualType PassedType);
28142809

28152810
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
2816-
int ArgOrdinal,
2817-
clang::QualType PassedType);
2811+
int ArgOrdinal,
2812+
clang::QualType PassedType);
28182813
/// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
28192814
/// TheCall is a constant expression.
28202815
bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);

clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,6 @@ constexpr T refract_vec_impl(T I, T N, U Eta) {
9292
#endif
9393
}
9494

95-
/*
96-
template <typename T, int L>
97-
constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
98-
#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
99-
return __builtin_spirv_refract(I, N, Eta);
100-
#else
101-
T Mul = dot(N, I);
102-
vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
103-
vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
104-
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
105-
#endif
106-
}
107-
108-
*/
109-
11095
template <typename T> constexpr T fmod_impl(T X, T Y) {
11196
#if !defined(__DIRECTX__)
11297
return __builtin_elementwise_fmod(X, Y);

clang/lib/Sema/SemaChecking.cpp

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16152,28 +16152,6 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
1615216152
}
1615316153
}
1615416154

16155-
bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
16156-
for (unsigned i = 0; i < NumArgsToCheck; ++i) {
16157-
ExprResult Arg = TheCall->getArg(i);
16158-
QualType ArgTy = Arg.get()->getType();
16159-
auto *VTy = ArgTy->getAs<VectorType>();
16160-
if (VTy == nullptr) {
16161-
SemaRef.Diag(Arg.get()->getBeginLoc(),
16162-
diag::err_typecheck_convert_incompatible)
16163-
<< ArgTy
16164-
<< SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
16165-
<< 0 << 0;
16166-
return true;
16167-
}
16168-
}
16169-
return false;
16170-
}
16171-
16172-
bool Sema::CheckVectorArgs(CallExpr *TheCall) {
16173-
return CheckVectorArgs(TheCall, TheCall->getNumArgs());
16174-
}
16175-
16176-
1617716155
bool Sema::CheckAllArgTypesAreCorrect(
1617816156
Sema *S, CallExpr *TheCall,
1617916157
llvm::ArrayRef<
@@ -16211,8 +16189,8 @@ bool Sema::CheckAllArgTypesAreCorrect(
1621116189
}
1621216190

1621316191
bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
16214-
int ArgOrdinal,
16215-
clang::QualType PassedType) {
16192+
int ArgOrdinal,
16193+
clang::QualType PassedType) {
1621616194
clang::QualType BaseType =
1621716195
PassedType->isVectorType()
1621816196
? PassedType->castAs<clang::VectorType>()->getElementType()
@@ -16225,8 +16203,8 @@ bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
1622516203
}
1622616204

1622716205
bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
16228-
int ArgOrdinal,
16229-
clang::QualType PassedType) {
16206+
int ArgOrdinal,
16207+
clang::QualType PassedType) {
1623016208
const auto *VecTy = PassedType->getAs<VectorType>();
1623116209

1623216210
clang::QualType BaseType =
@@ -16240,19 +16218,14 @@ bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
1624016218
return false;
1624116219
}
1624216220

16243-
bool Sema::CheckFloatOrHalfScalarRepresentation(
16244-
Sema *S, SourceLocation Loc,
16245-
int ArgOrdinal,
16246-
clang::QualType PassedType) {
16221+
bool Sema::CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
16222+
int ArgOrdinal,
16223+
clang::QualType PassedType) {
1624716224
const auto *VecTy = PassedType->getAs<VectorType>();
1624816225

16249-
clang::QualType BaseType =
16250-
PassedType->isVectorType()
16251-
? PassedType->castAs<clang::VectorType>()->getElementType()
16252-
: PassedType;
16253-
if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
16226+
if (VecTy || !PassedType->isHalfType() && !PassedType->isFloat32Type())
1625416227
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
16255-
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
16228+
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0
1625616229
<< /* half or float */ 2 << PassedType;
1625716230
return false;
1625816231
}

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,13 +2452,13 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
24522452
}
24532453

24542454
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
2455-
int ArgOrdinal,
2456-
clang::QualType PassedType) {
2455+
int ArgOrdinal,
2456+
clang::QualType PassedType) {
24572457
const auto *VecTy = PassedType->getAs<VectorType>();
24582458

2459-
clang::QualType BaseType =
2459+
clang::QualType BaseType =
24602460
PassedType->isVectorType()
2461-
? PassedType->castAs<clang::VectorType>()->getElementType()
2461+
? PassedType->castAs<clang::VectorType>()->getElementType()
24622462
: PassedType;
24632463
if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
24642464
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)

clang/lib/Sema/SemaSPIRV.cpp

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -157,25 +157,81 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
157157
if (SemaRef.checkArgCount(TheCall, 2))
158158
return true;
159159

160-
// Use the helper function to check both arguments
161-
if (SemaRef.CheckVectorArgs(TheCall))
160+
ExprResult A = TheCall->getArg(0);
161+
QualType ArgTyA = A.get()->getType();
162+
auto *VTyA = ArgTyA->getAs<VectorType>();
163+
if (VTyA == nullptr) {
164+
SemaRef.Diag(A.get()->getBeginLoc(),
165+
diag::err_typecheck_convert_incompatible)
166+
<< ArgTyA
167+
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
168+
<< 0 << 0;
162169
return true;
170+
}
163171

164-
QualType RetTy =
165-
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
172+
ExprResult B = TheCall->getArg(1);
173+
QualType ArgTyB = B.get()->getType();
174+
auto *VTyB = ArgTyB->getAs<VectorType>();
175+
if (VTyB == nullptr) {
176+
SemaRef.Diag(A.get()->getBeginLoc(),
177+
diag::err_typecheck_convert_incompatible)
178+
<< ArgTyB
179+
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
180+
<< 0 << 0;
181+
return true;
182+
}
183+
184+
QualType RetTy = VTyA->getElementType();
166185
TheCall->setType(RetTy);
167186
break;
168187
}
169188
case SPIRV::BI__builtin_spirv_length: {
170189
if (SemaRef.checkArgCount(TheCall, 1))
171190
return true;
191+
ExprResult A = TheCall->getArg(0);
192+
QualType ArgTyA = A.get()->getType();
193+
auto *VTy = ArgTyA->getAs<VectorType>();
194+
if (VTy == nullptr) {
195+
SemaRef.Diag(A.get()->getBeginLoc(),
196+
diag::err_typecheck_convert_incompatible)
197+
<< ArgTyA
198+
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
199+
<< 0 << 0;
200+
return true;
201+
}
202+
QualType RetTy = VTy->getElementType();
203+
TheCall->setType(RetTy);
204+
break;
205+
}
206+
case SPIRV::BI__builtin_spirv_reflect: {
207+
if (SemaRef.checkArgCount(TheCall, 2))
208+
return true;
172209

173-
// Use the helper function to check the argument
174-
if (SemaRef.CheckVectorArgs(TheCall))
210+
ExprResult A = TheCall->getArg(0);
211+
QualType ArgTyA = A.get()->getType();
212+
auto *VTyA = ArgTyA->getAs<VectorType>();
213+
if (VTyA == nullptr) {
214+
SemaRef.Diag(A.get()->getBeginLoc(),
215+
diag::err_typecheck_convert_incompatible)
216+
<< ArgTyA
217+
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
218+
<< 0 << 0;
219+
return true;
220+
}
221+
222+
ExprResult B = TheCall->getArg(1);
223+
QualType ArgTyB = B.get()->getType();
224+
auto *VTyB = ArgTyB->getAs<VectorType>();
225+
if (VTyB == nullptr) {
226+
SemaRef.Diag(A.get()->getBeginLoc(),
227+
diag::err_typecheck_convert_incompatible)
228+
<< ArgTyB
229+
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
230+
<< 0 << 0;
175231
return true;
232+
}
176233

177-
QualType RetTy =
178-
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
234+
QualType RetTy = ArgTyA;
179235
TheCall->setType(RetTy);
180236
break;
181237
}
@@ -203,18 +259,6 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
203259
TheCall->setType(RetTy);
204260
break;
205261
}
206-
case SPIRV::BI__builtin_spirv_reflect: {
207-
if (SemaRef.checkArgCount(TheCall, 2))
208-
return true;
209-
210-
// Use the helper function to check both arguments
211-
if (SemaRef.CheckVectorArgs(TheCall))
212-
return true;
213-
214-
QualType RetTy = TheCall->getArg(0)->getType();
215-
TheCall->setType(RetTy);
216-
break;
217-
}
218262
case SPIRV::BI__builtin_spirv_smoothstep: {
219263
if (SemaRef.checkArgCount(TheCall, 3))
220264
return true;

clang/test/CodeGenHLSL/builtins/reflect.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,4 @@ float3 test_reflect_float3(float3 I, float3 N) {
174174
//
175175
float4 test_reflect_float4(float4 I, float4 N) {
176176
return reflect(I, N);
177-
}
177+
}

clang/test/CodeGenHLSL/intrinsic.ll

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
2+
define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(half noundef nofpclass(nan inf) %I, half noundef nofpclass(nan inf) %N, half noundef nofpclass(nan inf) %ETA) local_unnamed_addr #0 {
3+
entry:
4+
%mul.i = fmul reassoc nnan ninf nsz arcp afn half %N, %I
5+
%mul1.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %ETA
6+
%mul2.i = fmul reassoc nnan ninf nsz arcp afn half %mul.i, %mul.i
7+
%sub.i = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, %mul2.i
8+
%mul3.i = fmul reassoc nnan ninf nsz arcp afn half %mul1.i, %sub.i
9+
%sub4.i = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, %mul3.i
10+
%mul5.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %I
11+
%mul6.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %mul.i
12+
%0 = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %sub4.i)
13+
%add.i = fadd reassoc nnan ninf nsz arcp afn half %0, %mul6.i
14+
%mul7.i = fmul reassoc nnan ninf nsz arcp afn half %add.i, %N
15+
%sub8.i = fsub reassoc nnan ninf nsz arcp afn half %mul5.i, %mul7.i
16+
%cmp.i = fcmp reassoc nnan ninf nsz arcp afn olt half %sub4.i, 0xH0000
17+
%hlsl.select.i = select reassoc nnan ninf nsz arcp afn i1 %cmp.i, half 0xH0000, half %sub8.i
18+
ret half %hlsl.select.i
19+
}

llvm/lib/IR/IRBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1262,4 +1262,4 @@ IRBuilderDefaultInserter::~IRBuilderDefaultInserter() = default;
12621262
IRBuilderCallbackInserter::~IRBuilderCallbackInserter() = default;
12631263
IRBuilderFolder::~IRBuilderFolder() = default;
12641264
void ConstantFolder::anchor() {}
1265-
void NoFolder::anchor() {}
1265+
void NoFolder::anchor() {}

0 commit comments

Comments
 (0)