Skip to content

[HLSL][DXIL] Implement refract intrinsic #147342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/BuiltinsSPIRVVK.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"

def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
19 changes: 19 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -2791,6 +2791,25 @@ class Sema final : public SemaBase {

void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);

bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::ArrayRef<
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
Checks);
bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);

static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType);
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType);

static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType);
/// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
/// TheCall is a constant expression.
bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
Expand Down
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
case SPIRV::BI__builtin_spirv_refract: {
Value *I = EmitScalarExpr(E->getArg(0));
Value *N = EmitScalarExpr(E->getArg(1));
Value *eta = EmitScalarExpr(E->getArg(2));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
E->getArg(2)->getType()->isFloatingType() &&
"refract operands must have a float representation");
assert(E->getArg(0)->getType()->isVectorType() &&
E->getArg(1)->getType()->isVectorType() &&
"refract I and N operands must be a vector");
return Builder.CreateIntrinsic(
/*ReturnType=*/I->getType(), Intrinsic::spv_refract,
ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
}
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_detail.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
static const bool Value = __is_arithmetic(T);
};

template <typename T> struct is_vector {
static const bool value = false;
};

template <typename T, int N> struct is_vector<vector<T, N>> {
static const bool value = true;
};

template <typename T, int N>
using HLSL_FIXED_VECTOR =
vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
Expand Down
21 changes: 21 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}

template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
T Mul = N * I;
T K = 1 - Eta * Eta * (1 - (Mul * Mul));
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
return select<T>(K < 0, static_cast<T>(0), Result);
}

template <typename T, typename U>
constexpr T refract_vec_impl(T I, T N, U Eta) {
#if (__has_builtin(__builtin_spirv_refract))
if (is_vector<T>::value) {
return __builtin_spirv_refract(I, N, Eta);
}
#else
T Mul = dot(N, I);
T K = 1 - Eta * Eta * (1 - Mul * Mul);
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
return select<T>(K < 0, static_cast<T>(0), Result);
#endif
}

template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
Expand Down
59 changes: 59 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}

//===----------------------------------------------------------------------===//
// refract builtin
//===----------------------------------------------------------------------===//

/// \fn T refract(T I, T N, T eta)
/// \brief Returns a refraction using an entering ray, \a I, a surface
/// normal, \a N and refraction index \a eta
/// \param I The entering ray.
/// \param N The surface normal.
/// \param eta The refraction index.
///
/// The return value is a floating-point vector that represents the refraction
/// using the refraction index, \a eta, for the direction of the entering ray,
/// \a I, off a surface with the normal \a N.
///
/// This function calculates the refraction vector using the following formulas:
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
/// if k < 0.0 the result is 0.0
/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
///
/// I and N must already be normalized in order to achieve the desired result.
///
/// I and N must be a scalar or vector whose component type is
/// floating-point.
///
/// eta must be a 16-bit or 32-bit floating-point scalar.
///
/// Result type, the type of I, and the type of N must all be the same type.

template <typename T>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
__detail::is_same<half, T>::value,
T> refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <int L>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
__detail::HLSL_FIXED_VECTOR<half, L> I,
__detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
return __detail::refract_vec_impl(I, N, eta);
}

template <int L>
const inline __detail::HLSL_FIXED_VECTOR<float, L>
refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
__detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
return __detail::refract_vec_impl(I, N, eta);
}

//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
Expand Down
78 changes: 78 additions & 0 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16151,3 +16151,81 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
}
}
}

bool Sema::CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::ArrayRef<
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
Checks) {
unsigned NumArgs = TheCall->getNumArgs();
if (Checks.size() == 1) {
// Apply the single check to all arguments
for (unsigned I = 0; I < NumArgs; ++I) {
Expr *Arg = TheCall->getArg(I);
if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
return true;
}
return false;
} else if (Checks.size() == NumArgs) {
// Apply each check to the corresponding argument
for (unsigned I = 0; I < NumArgs; ++I) {
Expr *Arg = TheCall->getArg(I);
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
return true;
}
return false;
} else {
// Mismatch: error or fallback
S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< NumArgs << Checks.size();
return true;
}
}

bool Sema::CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
}

bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->castAs<clang::VectorType>()->getElementType()
: PassedType;
if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
<< /* half or float */ 2 << PassedType;
return false;
}

bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
const auto *VecTy = PassedType->getAs<VectorType>();

clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->castAs<clang::VectorType>()->getElementType()
: PassedType;
if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* vector of */ 5 << /* no int */ 0
<< /* half or float */ 2 << PassedType;
return false;
}

bool Sema::CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
const auto *VecTy = PassedType->getAs<VectorType>();

if (VecTy || !PassedType->isHalfType() && !PassedType->isFloat32Type())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0
<< /* half or float */ 2 << PassedType;
return false;
}
73 changes: 64 additions & 9 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2401,17 +2401,40 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
return false;
}

static bool CheckAllArgTypesAreCorrect(
bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
clang::QualType PassedType)>
Check) {
for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
Expr *Arg = TheCall->getArg(I);
if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
return true;
llvm::ArrayRef<
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
Checks) {
unsigned NumArgs = TheCall->getNumArgs();
if (Checks.size() == 1) {
// Apply the single check to all arguments
for (unsigned I = 0; I < NumArgs; ++I) {
Expr *Arg = TheCall->getArg(I);
if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
return true;
}
return false;
} else if (Checks.size() == NumArgs) {
// Apply each check to the corresponding argument
for (unsigned I = 0; I < NumArgs; ++I) {
Expr *Arg = TheCall->getArg(I);
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
return true;
}
return false;
} else {
// Mismatch: error or fallback
S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< NumArgs << Checks.size();
return true;
}
return false;
}

bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall,
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
}

static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
Expand All @@ -2428,6 +2451,38 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
return false;
}

static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
const auto *VecTy = PassedType->getAs<VectorType>();

clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->castAs<clang::VectorType>()->getElementType()
: PassedType;
if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* vector of */ 5 << /* no int */ 0
<< /* half or float */ 2 << PassedType;
return false;
}

static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
int ArgOrdinal,
clang::QualType PassedType) {
const auto *VecTy = PassedType->getAs<VectorType>();

clang::QualType BaseType =
PassedType->isVectorType()
? PassedType->castAs<clang::VectorType>()->getElementType()
: PassedType;
if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
<< /* half or float */ 2 << PassedType;
return false;
}

static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
unsigned ArgIndex) {
auto *Arg = TheCall->getArg(ArgIndex);
Expand Down
24 changes: 24 additions & 0 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,30 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_refract: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;

llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
Sema::CheckFloatOrHalfVectorsRepresentation,
Sema::CheckFloatOrHalfScalarRepresentation};
if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
llvm::ArrayRef(ChecksArr)))
return true;

ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->isFloatingType()) {
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC;
return true;
}

QualType RetTy = TheCall->getArg(0)->getType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_smoothstep: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
Expand Down
Loading