Skip to content

[HLSL][DXIL] Implement refract intrinsic #147342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from

Conversation

raoanag
Copy link

@raoanag raoanag commented Jul 7, 2025

  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153

Copy link

github-actions bot commented Jul 7, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@raoanag raoanag marked this pull request as ready for review July 7, 2025 16:27
Copy link

github-actions bot commented Jul 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-backend-spir-v

Author: None (raoanag)

Changes
  • Implement refract using HLSL source in hlsl_intrinsics.h
  • Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td
  • Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp
  • Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp
  • Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl
  • Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c
  • Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
  • Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c
  • Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td
  • In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic.
  • Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
  • Check for what OpenCL support is needed.

Resolves #99153


Patch is 57.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147342.diff

19 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsSPIRVVK.td (+1)
  • (modified) clang/include/clang/Sema/Sema.h (+24)
  • (modified) clang/lib/CodeGen/TargetBuiltins/SPIR.cpp (+15)
  • (modified) clang/lib/Headers/hlsl/hlsl_detail.h (+8)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+36)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+59)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+105)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+64-9)
  • (modified) clang/lib/Sema/SemaSPIRV.cpp (+36-56)
  • (modified) clang/test/CodeGenHLSL/builtins/reflect.hlsl (+1-1)
  • (added) clang/test/CodeGenHLSL/builtins/refract.hlsl (+271)
  • (added) clang/test/CodeGenSPIRV/Builtins/refract.c (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl (+74)
  • (added) clang/test/SemaSPIRV/BuiltIns/refract-errors.c (+23)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+2)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll (+36)
  • (added) llvm/test/CodeGen/SPIRV/opencl/refract-error.ll (+12)
diff --git a/clang/include/clang/Basic/BuiltinsSPIRVVK.td b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
index 61cc0343c415e..5dc3c7588cd2a 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRVVK.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
@@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"
 
 def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
 def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
+def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 3fe26f950ad51..105ab804fffd0 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2791,6 +2791,30 @@ class Sema final : public SemaBase {
 
   void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
 
+  /// CheckVectorArgs - Check that the arguments of a vector function call
+  bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
+
+  bool CheckVectorArgs(CallExpr *TheCall);
+
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::ArrayRef<
+          llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+          Checks);
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
+
+  static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                            int ArgOrdinal,
+                                            clang::QualType PassedType);
+  static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType);
+
+  static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                int ArgOrdinal,
+                                                clang::QualType PassedType);
   /// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
   /// TheCall is a constant expression.
   bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 0687485cd3f80..1c63e04f757c7 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->isFloatingType() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 80c4900121dfb..96e101a1e3aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
   static const bool Value = __is_arithmetic(T);
 };
 
+template <typename T> struct is_vector {
+  static const bool value = false;
+};
+
+template <typename T, int N> struct is_vector<vector<T, N>> {
+  static const bool value = true;
+};
+
 template <typename T, int N>
 using HLSL_FIXED_VECTOR =
     vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 4eb7b8f45c85a..f6acb1cea2594 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,42 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
+  T Mul = N * I;
+  T K = 1 - Eta * Eta * (1 - (Mul * Mul));
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
+}
+
+template <typename T, typename U>
+constexpr T refract_vec_impl(T I, T N, U Eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  if (is_vector<T>::value) {
+    return __builtin_spirv_refract(I, N, Eta);
+  }
+#else
+  T Mul = dot(N, I);
+  T K = 1 - Eta * Eta * (1 - Mul * Mul);
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
+#endif
+}
+
+/*
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
+#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
+  return __builtin_spirv_refract(I, N, Eta);
+#else
+  T Mul = dot(N, I);
+  vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
+  vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
+#endif
+}
+
+*/
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ea880105fac3b..8c262ffce25f1 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N and refraction index \a eta
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point vector that represents the refraction
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index dd5b710d7e1d4..98bca59f14ecd 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -16151,3 +16151,108 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
     }
   }
 }
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
+  for (unsigned i = 0; i < NumArgsToCheck; ++i) {
+    ExprResult Arg = TheCall->getArg(i);
+    QualType ArgTy = Arg.get()->getType();
+    auto *VTy = ArgTy->getAs<VectorType>();
+    if (VTy == nullptr) {
+      SemaRef.Diag(Arg.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTy
+          << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall) {
+  return CheckVectorArgs(TheCall, TheCall->getNumArgs());
+}
+
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // Mismatch: error or fallback
+    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+        << NumArgs << Checks.size();
+    return true;
+  }
+}
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
+}
+
+bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                                  int ArgOrdinal,
+                                                  clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfScalarRepresentation(
+    Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bad357b50929b..991d330edfb6f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2401,17 +2401,40 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckAllArgTypesAreCorrect(
+bool CheckAllArgTypesAreCorrect(
     Sema *S, CallExpr *TheCall,
-    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
-                            clang::QualType PassedType)>
-        Check) {
-  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
-    Expr *Arg = TheCall->getArg(I);
-    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-      return true;
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // Mismatch: error or fallback
+    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
+        << NumArgs << Checks.size();
+    return true;
   }
-  return false;
+}
+
+bool CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
 }
 
 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2428,6 +2451,38 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
   return false;
 }
 
+static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType = 
+      PassedType->isVectorType()
+        ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
                                   unsigned ArgIndex) {
   auto *Arg = TheCall->getArg(ArgIndex);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index c27d3fed2b990..1b4093065a63a 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -157,81 +157,61 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    // Use the helper function to check both arguments
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
-    }
 
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    QualType RetTy = VTyA->getElementType();
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
   case SPIRV::BI__builtin_spirv_length: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTy = ArgTyA->getAs<VectorType>();
-    if (VTy == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+
+    // Use the helper function to check the argument
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
-    }
-    QualType RetTy = VTy->getElementType();
+
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
-  case SPIRV::BI__builtin_spirv_reflect: {
-    if (SemaRef.checkArgCount(TheCall, 2))
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
+        ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfScalarRepresentation};
+    if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                           llvm::ArrayRef(ChecksArr)))
       return true;
-    }
 
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->isFloatingType...
[truncated]

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR isn't ready yet.

@raoanag raoanag force-pushed the user/raoanag/refract branch from 515ecda to 729fbf3 Compare July 8, 2025 23:03
@@ -3995,4 +3995,4 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
}
Init = C;
return true;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the newline back.

return false;
}

static bool CheckAllArgTypesAreCorrect(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed offline, yes you should delete this function since it is unused.


ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->isFloatingType()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the purpose of this check? I thought the call to 'CheckAllArgTypesAreCorrect' handled this?

clang::QualType PassedType) {
const auto *VecTy = PassedType->getAs<VectorType>();

if (VecTy || (!PassedType->isHalfType() && !PassedType->isFloat32Type()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need the VecTy check here, isHalfType() and isFloat32Type() should both return false if PassedType is a vector.

return true;

llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
ChecksArr[] = {CheckFloatOrHalfRepresentation,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused on if this is meant to handle scalar values as well as vectors? Looking at the code gen, we are asserting that the first two arguments are vectors, but here we allow them to be scalars. @farzonl Does this handle only the case where the first two arguments are vectors?

If that is the case 'CheckFloatOrHalfRepresentation' should be updated to only check for vectors of half or float and should probably be renamed to 'CheckFloatOrHalfVecRepresentation'.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement the refract HLSL Function
5 participants