-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[HLSL][DXIL] Implement refract
intrinsic
#147342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
@llvm/pr-subscribers-hlsl @llvm/pr-subscribers-backend-spir-v Author: None (raoanag) Changes
Resolves #99153 Patch is 57.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147342.diff 19 Files Affected:
diff --git a/clang/include/clang/Basic/BuiltinsSPIRVVK.td b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
index 61cc0343c415e..5dc3c7588cd2a 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRVVK.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
@@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"
def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
+def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 3fe26f950ad51..105ab804fffd0 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2791,6 +2791,30 @@ class Sema final : public SemaBase {
void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
+ /// CheckVectorArgs - Check that the arguments of a vector function call
+ bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
+
+ bool CheckVectorArgs(CallExpr *TheCall);
+
+ bool CheckAllArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall,
+ llvm::ArrayRef<
+ llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+ Checks);
+ bool CheckAllArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall,
+ llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
+
+ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType);
+ static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType);
+
+ static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType);
/// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
/// TheCall is a constant expression.
bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 0687485cd3f80..1c63e04f757c7 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
+ case SPIRV::BI__builtin_spirv_refract: {
+ Value *I = EmitScalarExpr(E->getArg(0));
+ Value *N = EmitScalarExpr(E->getArg(1));
+ Value *eta = EmitScalarExpr(E->getArg(2));
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasFloatingRepresentation() &&
+ E->getArg(2)->getType()->isFloatingType() &&
+ "refract operands must have a float representation");
+ assert(E->getArg(0)->getType()->isVectorType() &&
+ E->getArg(1)->getType()->isVectorType() &&
+ "refract I and N operands must be a vector");
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+ ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+ }
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 80c4900121dfb..96e101a1e3aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
static const bool Value = __is_arithmetic(T);
};
+template <typename T> struct is_vector {
+ static const bool value = false;
+};
+
+template <typename T, int N> struct is_vector<vector<T, N>> {
+ static const bool value = true;
+};
+
template <typename T, int N>
using HLSL_FIXED_VECTOR =
vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 4eb7b8f45c85a..f6acb1cea2594 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,42 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}
+template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
+ T Mul = N * I;
+ T K = 1 - Eta * Eta * (1 - (Mul * Mul));
+ T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+ return select<T>(K < 0, static_cast<T>(0), Result);
+}
+
+template <typename T, typename U>
+constexpr T refract_vec_impl(T I, T N, U Eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+ if (is_vector<T>::value) {
+ return __builtin_spirv_refract(I, N, Eta);
+ }
+#else
+ T Mul = dot(N, I);
+ T K = 1 - Eta * Eta * (1 - Mul * Mul);
+ T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+ return select<T>(K < 0, static_cast<T>(0), Result);
+#endif
+}
+
+/*
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
+#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
+ return __builtin_spirv_refract(I, N, Eta);
+#else
+ T Mul = dot(N, I);
+ vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
+ vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+ return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
+#endif
+}
+
+*/
+
template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ea880105fac3b..8c262ffce25f1 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+ __detail::is_same<half, T>::value,
+ T> refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+ __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+refract(T I, T N, T eta) {
+ return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+ __detail::HLSL_FIXED_VECTOR<half, L> I,
+ __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+ __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+ return __detail::refract_vec_impl(I, N, eta);
+}
+
//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index dd5b710d7e1d4..98bca59f14ecd 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -16151,3 +16151,108 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
}
}
}
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
+ for (unsigned i = 0; i < NumArgsToCheck; ++i) {
+ ExprResult Arg = TheCall->getArg(i);
+ QualType ArgTy = Arg.get()->getType();
+ auto *VTy = ArgTy->getAs<VectorType>();
+ if (VTy == nullptr) {
+ SemaRef.Diag(Arg.get()->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTy
+ << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
+ << 0 << 0;
+ return true;
+ }
+ }
+ return false;
+}
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall) {
+ return CheckVectorArgs(TheCall, TheCall->getNumArgs());
+}
+
+
+bool Sema::CheckAllArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall,
+ llvm::ArrayRef<
+ llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+ Checks) {
+ unsigned NumArgs = TheCall->getNumArgs();
+ if (Checks.size() == 1) {
+ // Apply the single check to all arguments
+ for (unsigned I = 0; I < NumArgs; ++I) {
+ Expr *Arg = TheCall->getArg(I);
+ if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+ return true;
+ }
+ return false;
+ } else if (Checks.size() == NumArgs) {
+ // Apply each check to the corresponding argument
+ for (unsigned I = 0; I < NumArgs; ++I) {
+ Expr *Arg = TheCall->getArg(I);
+ if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+ return true;
+ }
+ return false;
+ } else {
+ // Mismatch: error or fallback
+ S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+ << NumArgs << Checks.size();
+ return true;
+ }
+}
+
+bool Sema::CheckAllArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall,
+ llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+ return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
+}
+
+bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType) {
+ clang::QualType BaseType =
+ PassedType->isVectorType()
+ ? PassedType->castAs<clang::VectorType>()->getElementType()
+ : PassedType;
+ if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+ return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+ << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+ << /* half or float */ 2 << PassedType;
+ return false;
+}
+
+bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType) {
+ const auto *VecTy = PassedType->getAs<VectorType>();
+
+ clang::QualType BaseType =
+ PassedType->isVectorType()
+ ? PassedType->castAs<clang::VectorType>()->getElementType()
+ : PassedType;
+ if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+ return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+ << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+ << /* half or float */ 2 << PassedType;
+ return false;
+}
+
+bool Sema::CheckFloatOrHalfScalarRepresentation(
+ Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType) {
+ const auto *VecTy = PassedType->getAs<VectorType>();
+
+ clang::QualType BaseType =
+ PassedType->isVectorType()
+ ? PassedType->castAs<clang::VectorType>()->getElementType()
+ : PassedType;
+ if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+ return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+ << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+ << /* half or float */ 2 << PassedType;
+ return false;
+}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bad357b50929b..991d330edfb6f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2401,17 +2401,40 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
return false;
}
-static bool CheckAllArgTypesAreCorrect(
+bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
- llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
- clang::QualType PassedType)>
- Check) {
- for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
- Expr *Arg = TheCall->getArg(I);
- if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
- return true;
+ llvm::ArrayRef<
+ llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+ Checks) {
+ unsigned NumArgs = TheCall->getNumArgs();
+ if (Checks.size() == 1) {
+ // Apply the single check to all arguments
+ for (unsigned I = 0; I < NumArgs; ++I) {
+ Expr *Arg = TheCall->getArg(I);
+ if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+ return true;
+ }
+ return false;
+ } else if (Checks.size() == NumArgs) {
+ // Apply each check to the corresponding argument
+ for (unsigned I = 0; I < NumArgs; ++I) {
+ Expr *Arg = TheCall->getArg(I);
+ if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+ return true;
+ }
+ return false;
+ } else {
+ // Mismatch: error or fallback
+ S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+ << NumArgs << Checks.size();
+ return true;
}
- return false;
+}
+
+bool CheckAllArgTypesAreCorrect(
+ Sema *S, CallExpr *TheCall,
+ llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+ return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
}
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2428,6 +2451,38 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
return false;
}
+static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType) {
+ const auto *VecTy = PassedType->getAs<VectorType>();
+
+ clang::QualType BaseType =
+ PassedType->isVectorType()
+ ? PassedType->castAs<clang::VectorType>()->getElementType()
+ : PassedType;
+ if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+ return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+ << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+ << /* half or float */ 2 << PassedType;
+ return false;
+}
+
+static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+ int ArgOrdinal,
+ clang::QualType PassedType) {
+ const auto *VecTy = PassedType->getAs<VectorType>();
+
+ clang::QualType BaseType =
+ PassedType->isVectorType()
+ ? PassedType->castAs<clang::VectorType>()->getElementType()
+ : PassedType;
+ if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+ return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+ << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+ << /* half or float */ 2 << PassedType;
+ return false;
+}
+
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
unsigned ArgIndex) {
auto *Arg = TheCall->getArg(ArgIndex);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index c27d3fed2b990..1b4093065a63a 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -157,81 +157,61 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
if (SemaRef.checkArgCount(TheCall, 2))
return true;
- ExprResult A = TheCall->getArg(0);
- QualType ArgTyA = A.get()->getType();
- auto *VTyA = ArgTyA->getAs<VectorType>();
- if (VTyA == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyA
- << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
- << 0 << 0;
+ // Use the helper function to check both arguments
+ if (SemaRef.CheckVectorArgs(TheCall))
return true;
- }
- ExprResult B = TheCall->getArg(1);
- QualType ArgTyB = B.get()->getType();
- auto *VTyB = ArgTyB->getAs<VectorType>();
- if (VTyB == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyB
- << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
- << 0 << 0;
- return true;
- }
-
- QualType RetTy = VTyA->getElementType();
+ QualType RetTy =
+ TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_length: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
- ExprResult A = TheCall->getArg(0);
- QualType ArgTyA = A.get()->getType();
- auto *VTy = ArgTyA->getAs<VectorType>();
- if (VTy == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyA
- << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
- << 0 << 0;
+
+ // Use the helper function to check the argument
+ if (SemaRef.CheckVectorArgs(TheCall))
return true;
- }
- QualType RetTy = VTy->getElementType();
+
+ QualType RetTy =
+ TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
TheCall->setType(RetTy);
break;
}
- case SPIRV::BI__builtin_spirv_reflect: {
- if (SemaRef.checkArgCount(TheCall, 2))
+ case SPIRV::BI__builtin_spirv_refract: {
+ if (SemaRef.checkArgCount(TheCall, 3))
return true;
- ExprResult A = TheCall->getArg(0);
- QualType ArgTyA = A.get()->getType();
- auto *VTyA = ArgTyA->getAs<VectorType>();
- if (VTyA == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyA
- << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
- << 0 << 0;
+ llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
+ ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
+ Sema::CheckFloatOrHalfVectorsRepresentation,
+ Sema::CheckFloatOrHalfScalarRepresentation};
+ if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+ llvm::ArrayRef(ChecksArr)))
return true;
- }
- ExprResult B = TheCall->getArg(1);
- QualType ArgTyB = B.get()->getType();
- auto *VTyB = ArgTyB->getAs<VectorType>();
- if (VTyB == nullptr) {
- SemaRef.Diag(A.get()->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << ArgTyB
- << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
- << 0 << 0;
+ ExprResult C = TheCall->getArg(2);
+ QualType ArgTyC = C.get()->getType();
+ if (!ArgTyC->isFloatingType...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR isn't ready yet.
515ecda
to
729fbf3
Compare
@@ -3995,4 +3995,4 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) { | |||
} | |||
Init = C; | |||
return true; | |||
} | |||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add the newline back.
return false; | ||
} | ||
|
||
static bool CheckAllArgTypesAreCorrect( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we discussed offline, yes you should delete this function since it is unused.
|
||
ExprResult C = TheCall->getArg(2); | ||
QualType ArgTyC = C.get()->getType(); | ||
if (!ArgTyC->isFloatingType()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the purpose of this check? I thought the call to 'CheckAllArgTypesAreCorrect' handled this?
clang::QualType PassedType) { | ||
const auto *VecTy = PassedType->getAs<VectorType>(); | ||
|
||
if (VecTy || (!PassedType->isHalfType() && !PassedType->isFloat32Type())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need the VecTy check here, isHalfType() and isFloat32Type() should both return false if PassedType is a vector.
return true; | ||
|
||
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> | ||
ChecksArr[] = {CheckFloatOrHalfRepresentation, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit confused on if this is meant to handle scalar values as well as vectors? Looking at the code gen, we are asserting that the first two arguments are vectors, but here we allow them to be scalars. @farzonl Does this handle only the case where the first two arguments are vectors?
If that is the case 'CheckFloatOrHalfRepresentation' should be updated to only check for vectors of half or float and should probably be renamed to 'CheckFloatOrHalfVecRepresentation'.
Resolves #99153