Skip to content

[HLSL][DXIL] Implement refract intrinsic #147342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/BuiltinsSPIRVVK.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"

def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
case SPIRV::BI__builtin_spirv_refract: {
Value *I = EmitScalarExpr(E->getArg(0));
Value *N = EmitScalarExpr(E->getArg(1));
Value *eta = EmitScalarExpr(E->getArg(2));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
E->getArg(2)->getType()->isFloatingType() &&
"refract operands must have a float representation");
assert(E->getArg(0)->getType()->isVectorType() &&
E->getArg(1)->getType()->isVectorType() &&
"refract I and N operands must be a vector");
return Builder.CreateIntrinsic(
/*ReturnType=*/I->getType(), Intrinsic::spv_refract,
ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
}
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_detail.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
static const bool Value = __is_arithmetic(T);
};

template <typename T> struct is_vector {
static const bool value = false;
};

template <typename T, int N> struct is_vector<vector<T, N>> {
static const bool value = true;
};

template <typename T, int N>
using HLSL_FIXED_VECTOR =
vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}

template <typename T, typename U> constexpr T refract_impl(T I, T N, U Eta) {
#if (__has_builtin(__builtin_spirv_refract))
if (is_vector<T>::value)
return __builtin_spirv_refract(I, N, Eta);
#endif
T Mul = dot(N, I);
T K = 1 - Eta * Eta * (1 - Mul * Mul);
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
return select<T>(K < 0, static_cast<T>(0), Result);
}

template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
Expand Down
59 changes: 59 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}

//===----------------------------------------------------------------------===//
// refract builtin
//===----------------------------------------------------------------------===//

/// \fn T refract(T I, T N, T eta)
/// \brief Returns a refraction using an entering ray, \a I, a surface
/// normal, \a N and refraction index \a eta
/// \param I The entering ray.
/// \param N The surface normal.
/// \param eta The refraction index.
///
/// The return value is a floating-point vector that represents the refraction
/// using the refraction index, \a eta, for the direction of the entering ray,
/// \a I, off a surface with the normal \a N.
///
/// This function calculates the refraction vector using the following formulas:
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
/// if k < 0.0 the result is 0.0
/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
///
/// I and N must already be normalized in order to achieve the desired result.
///
/// I and N must be a scalar or vector whose component type is
/// floating-point.
///
/// eta must be a 16-bit or 32-bit floating-point scalar.
///
/// Result type, the type of I, and the type of N must all be the same type.

template <typename T>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
__detail::is_same<half, T>::value,
T> refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <int L>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
__detail::HLSL_FIXED_VECTOR<half, L> I,
__detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
return __detail::refract_impl(I, N, eta);
}

template <int L>
const inline __detail::HLSL_FIXED_VECTOR<float, L>
refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
__detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
return __detail::refract_impl(I, N, eta);
}

//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3995,4 +3995,4 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
}
Init = C;
return true;
}
}
77 changes: 77 additions & 0 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,59 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
return false;
}

static bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::ArrayRef<
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
Checks) {
unsigned NumArgs = TheCall->getNumArgs();
assert(Checks.size() == NumArgs &&
"Wrong number of checks for Number of args.");
// Apply each check to the corresponding argument
for (unsigned I = 0; I < NumArgs; ++I) {
Expr *Arg = TheCall->getArg(I);
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
return true;
}
return false;
}

static bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
return CheckAllArgTypesAreCorrect(
S, TheCall,
SmallVector<
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>, 4>(
TheCall->getNumArgs(), Check));
}

static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->castAs<clang::VectorType>()->getElementType()
: PassedType;
if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
<< /* half or float */ 2 << PassedType;
return false;
}

static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
const auto *VecTy = PassedType->getAs<VectorType>();

if (VecTy || (!PassedType->isHalfType() && !PassedType->isFloat32Type()))
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0
<< /* half or float */ 2 << PassedType;
return false;
}

static std::optional<int>
processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
ExprResult Arg =
Expand Down Expand Up @@ -235,6 +288,30 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_refract: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;

llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
ChecksArr[] = {CheckFloatOrHalfRepresentation,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused on if this is meant to handle scalar values as well as vectors? Looking at the code gen, we are asserting that the first two arguments are vectors, but here we allow them to be scalars. @farzonl Does this handle only the case where the first two arguments are vectors?

If that is the case 'CheckFloatOrHalfRepresentation' should be updated to only check for vectors of half or float and should probably be renamed to 'CheckFloatOrHalfVecRepresentation'.

Copy link
Author

@raoanag raoanag Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to support vector of size 1, which is implicitly converted to scalar.
Also HLSL_FIXED_VECTOR only supports Vector of N > 1.

Hence, even though first 2 args are described as vector or N = 1 they are seen as scalar

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your explanation is slightly incorrect, but it does seem the __builtin_spirv_refract is reachable with a scalar value. In this case the codegen assertions are wrong and I will leave a comment there about updating them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ll review this tomorrow, but supporting scalars here seems wrong. I’m almost 100% sure that spirv only supports vectors and that our semantics should match that.

Copy link
Contributor

@spall spall Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Refract spirv op says both scalar and vector are supported. https://registry.khronos.org/SPIR-V/specs/unified1/GLSL.std.450.pdf (search for Refract).
But it is up to us what we want to allow.

CheckFloatOrHalfRepresentation,
CheckFloatOrHalfScalarRepresentation};
if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
llvm::ArrayRef(ChecksArr)))
return true;

ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->isFloatingType()) {
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC;
return true;
}

QualType RetTy = TheCall->getArg(0)->getType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_smoothstep: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
Expand Down
Loading
Loading