Skip to content

Commit 7bdad40

Browse files
committed
Update Refract intrinsic
1 parent 17d1fbb commit 7bdad40

File tree

8 files changed

+80
-185
lines changed

8 files changed

+80
-185
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2791,25 +2791,6 @@ class Sema final : public SemaBase {
27912791

27922792
void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
27932793

2794-
bool CheckAllArgTypesAreCorrect(
2795-
Sema *S, CallExpr *TheCall,
2796-
llvm::ArrayRef<
2797-
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
2798-
Checks);
2799-
bool CheckAllArgTypesAreCorrect(
2800-
Sema *S, CallExpr *TheCall,
2801-
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
2802-
2803-
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
2804-
int ArgOrdinal,
2805-
clang::QualType PassedType);
2806-
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
2807-
int ArgOrdinal,
2808-
clang::QualType PassedType);
2809-
2810-
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
2811-
int ArgOrdinal,
2812-
clang::QualType PassedType);
28132794
/// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
28142795
/// TheCall is a constant expression.
28152796
bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
@@ -15506,4 +15487,4 @@ void Sema::PragmaStack<Sema::AlignPackInfo>::Act(SourceLocation PragmaLocation,
1550615487

1550715488
} // end namespace clang
1550815489

15509-
#endif
15490+
#endif

clang/lib/Headers/hlsl/hlsl_detail.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ template <typename T, int N> struct is_vector<vector<T, N>> {
5555

5656
template <typename T, int N>
5757
using HLSL_FIXED_VECTOR =
58-
vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
58+
vector<__detail::enable_if_t<(N >= 1 && N <= 4), T>, N>;
5959

6060
} // namespace __detail
6161
} // namespace hlsl

clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,15 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
7171
#endif
7272
}
7373

74-
template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
75-
T Mul = N * I;
76-
T K = 1 - Eta * Eta * (1 - (Mul * Mul));
77-
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
78-
return select<T>(K < 0, static_cast<T>(0), Result);
79-
}
80-
81-
template <typename T, typename U>
82-
constexpr T refract_vec_impl(T I, T N, U Eta) {
74+
template <typename T, typename U> constexpr T refract_impl(T I, T N, U Eta) {
8375
#if (__has_builtin(__builtin_spirv_refract))
84-
if (is_vector<T>::value) {
76+
if (is_vector<T>::value)
8577
return __builtin_spirv_refract(I, N, Eta);
86-
}
87-
#else
78+
#endif
8879
T Mul = dot(N, I);
8980
T K = 1 - Eta * Eta * (1 - Mul * Mul);
9081
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
9182
return select<T>(K < 0, static_cast<T>(0), Result);
92-
#endif
9383
}
9484

9585
template <typename T> constexpr T fmod_impl(T X, T Y) {

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,14 +524,14 @@ _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
524524
const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
525525
__detail::HLSL_FIXED_VECTOR<half, L> I,
526526
__detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
527-
return __detail::refract_vec_impl(I, N, eta);
527+
return __detail::refract_impl(I, N, eta);
528528
}
529529

530530
template <int L>
531531
const inline __detail::HLSL_FIXED_VECTOR<float, L>
532532
refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
533533
__detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
534-
return __detail::refract_vec_impl(I, N, eta);
534+
return __detail::refract_impl(I, N, eta);
535535
}
536536

537537
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaChecking.cpp

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -16151,81 +16151,3 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
1615116151
}
1615216152
}
1615316153
}
16154-
16155-
bool Sema::CheckAllArgTypesAreCorrect(
16156-
Sema *S, CallExpr *TheCall,
16157-
llvm::ArrayRef<
16158-
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
16159-
Checks) {
16160-
unsigned NumArgs = TheCall->getNumArgs();
16161-
if (Checks.size() == 1) {
16162-
// Apply the single check to all arguments
16163-
for (unsigned I = 0; I < NumArgs; ++I) {
16164-
Expr *Arg = TheCall->getArg(I);
16165-
if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
16166-
return true;
16167-
}
16168-
return false;
16169-
} else if (Checks.size() == NumArgs) {
16170-
// Apply each check to the corresponding argument
16171-
for (unsigned I = 0; I < NumArgs; ++I) {
16172-
Expr *Arg = TheCall->getArg(I);
16173-
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
16174-
return true;
16175-
}
16176-
return false;
16177-
} else {
16178-
// Mismatch: error or fallback
16179-
S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
16180-
<< NumArgs << Checks.size();
16181-
return true;
16182-
}
16183-
}
16184-
16185-
bool Sema::CheckAllArgTypesAreCorrect(
16186-
Sema *S, CallExpr *TheCall,
16187-
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
16188-
return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
16189-
}
16190-
16191-
bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
16192-
int ArgOrdinal,
16193-
clang::QualType PassedType) {
16194-
clang::QualType BaseType =
16195-
PassedType->isVectorType()
16196-
? PassedType->castAs<clang::VectorType>()->getElementType()
16197-
: PassedType;
16198-
if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
16199-
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
16200-
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
16201-
<< /* half or float */ 2 << PassedType;
16202-
return false;
16203-
}
16204-
16205-
bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
16206-
int ArgOrdinal,
16207-
clang::QualType PassedType) {
16208-
const auto *VecTy = PassedType->getAs<VectorType>();
16209-
16210-
clang::QualType BaseType =
16211-
PassedType->isVectorType()
16212-
? PassedType->castAs<clang::VectorType>()->getElementType()
16213-
: PassedType;
16214-
if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
16215-
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
16216-
<< ArgOrdinal << /* vector of */ 5 << /* no int */ 0
16217-
<< /* half or float */ 2 << PassedType;
16218-
return false;
16219-
}
16220-
16221-
bool Sema::CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
16222-
int ArgOrdinal,
16223-
clang::QualType PassedType) {
16224-
const auto *VecTy = PassedType->getAs<VectorType>();
16225-
16226-
if (VecTy || !PassedType->isHalfType() && !PassedType->isFloat32Type())
16227-
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
16228-
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0
16229-
<< /* half or float */ 2 << PassedType;
16230-
return false;
16231-
}

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 10 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,40 +2401,17 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
24012401
return false;
24022402
}
24032403

2404-
bool CheckAllArgTypesAreCorrect(
2404+
static bool CheckAllArgTypesAreCorrect(
24052405
Sema *S, CallExpr *TheCall,
2406-
llvm::ArrayRef<
2407-
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
2408-
Checks) {
2409-
unsigned NumArgs = TheCall->getNumArgs();
2410-
if (Checks.size() == 1) {
2411-
// Apply the single check to all arguments
2412-
for (unsigned I = 0; I < NumArgs; ++I) {
2413-
Expr *Arg = TheCall->getArg(I);
2414-
if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2415-
return true;
2416-
}
2417-
return false;
2418-
} else if (Checks.size() == NumArgs) {
2419-
// Apply each check to the corresponding argument
2420-
for (unsigned I = 0; I < NumArgs; ++I) {
2421-
Expr *Arg = TheCall->getArg(I);
2422-
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2423-
return true;
2424-
}
2425-
return false;
2426-
} else {
2427-
// Mismatch: error or fallback
2428-
S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2429-
<< NumArgs << Checks.size();
2430-
return true;
2406+
llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
2407+
clang::QualType PassedType)>
2408+
Check) {
2409+
for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
2410+
Expr *Arg = TheCall->getArg(I);
2411+
if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2412+
return true;
24312413
}
2432-
}
2433-
2434-
bool CheckAllArgTypesAreCorrect(
2435-
Sema *S, CallExpr *TheCall,
2436-
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
2437-
return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
2414+
return false;
24382415
}
24392416

24402417
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2451,38 +2428,6 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
24512428
return false;
24522429
}
24532430

2454-
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
2455-
int ArgOrdinal,
2456-
clang::QualType PassedType) {
2457-
const auto *VecTy = PassedType->getAs<VectorType>();
2458-
2459-
clang::QualType BaseType =
2460-
PassedType->isVectorType()
2461-
? PassedType->castAs<clang::VectorType>()->getElementType()
2462-
: PassedType;
2463-
if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
2464-
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2465-
<< ArgOrdinal << /* vector of */ 5 << /* no int */ 0
2466-
<< /* half or float */ 2 << PassedType;
2467-
return false;
2468-
}
2469-
2470-
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
2471-
int ArgOrdinal,
2472-
clang::QualType PassedType) {
2473-
const auto *VecTy = PassedType->getAs<VectorType>();
2474-
2475-
clang::QualType BaseType =
2476-
PassedType->isVectorType()
2477-
? PassedType->castAs<clang::VectorType>()->getElementType()
2478-
: PassedType;
2479-
if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
2480-
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2481-
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
2482-
<< /* half or float */ 2 << PassedType;
2483-
return false;
2484-
}
2485-
24862431
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
24872432
unsigned ArgIndex) {
24882433
auto *Arg = TheCall->getArg(ArgIndex);
@@ -4050,4 +3995,4 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
40503995
}
40513996
Init = C;
40523997
return true;
4053-
}
3998+
}

clang/lib/Sema/SemaSPIRV.cpp

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,59 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
4646
return false;
4747
}
4848

49+
static bool CheckAllArgTypesAreCorrect(
50+
Sema *S, CallExpr *TheCall,
51+
llvm::ArrayRef<
52+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
53+
Checks) {
54+
unsigned NumArgs = TheCall->getNumArgs();
55+
assert(Checks.size() == NumArgs &&
56+
"Wrong number of checks for Number of args.");
57+
// Apply each check to the corresponding argument
58+
for (unsigned I = 0; I < NumArgs; ++I) {
59+
Expr *Arg = TheCall->getArg(I);
60+
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
61+
return true;
62+
}
63+
return false;
64+
}
65+
66+
static bool CheckAllArgTypesAreCorrect(
67+
Sema *S, CallExpr *TheCall,
68+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
69+
return CheckAllArgTypesAreCorrect(
70+
S, TheCall,
71+
SmallVector<
72+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>, 4>(
73+
TheCall->getNumArgs(), Check));
74+
}
75+
76+
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
77+
int ArgOrdinal,
78+
clang::QualType PassedType) {
79+
clang::QualType BaseType =
80+
PassedType->isVectorType()
81+
? PassedType->castAs<clang::VectorType>()->getElementType()
82+
: PassedType;
83+
if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
84+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
85+
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
86+
<< /* half or float */ 2 << PassedType;
87+
return false;
88+
}
89+
90+
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
91+
int ArgOrdinal,
92+
clang::QualType PassedType) {
93+
const auto *VecTy = PassedType->getAs<VectorType>();
94+
95+
if (VecTy || (!PassedType->isHalfType() && !PassedType->isFloat32Type()))
96+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
97+
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0
98+
<< /* half or float */ 2 << PassedType;
99+
return false;
100+
}
101+
49102
static std::optional<int>
50103
processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
51104
ExprResult Arg =
@@ -240,11 +293,11 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
240293
return true;
241294

242295
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
243-
ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
244-
Sema::CheckFloatOrHalfVectorsRepresentation,
245-
Sema::CheckFloatOrHalfScalarRepresentation};
246-
if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
247-
llvm::ArrayRef(ChecksArr)))
296+
ChecksArr[] = {CheckFloatOrHalfRepresentation,
297+
CheckFloatOrHalfRepresentation,
298+
CheckFloatOrHalfScalarRepresentation};
299+
if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
300+
llvm::ArrayRef(ChecksArr)))
248301
return true;
249302

250303
ExprResult C = TheCall->getArg(2);

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ let TargetPrefix = "spv" in {
7575
[IntrNoMem] >;
7676
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
7777
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
78-
def int_spv_refract : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_anyfloat_ty], [IntrNoMem]>;
78+
def int_spv_refract
79+
: DefaultAttrsIntrinsic<[LLVMMatchType<0>],
80+
[llvm_anyfloat_ty, LLVMMatchType<0>,
81+
llvm_anyfloat_ty],
82+
[IntrNoMem]>;
7983
def int_spv_reflect : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
8084
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8185
def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;

0 commit comments

Comments
 (0)