Skip to content

Commit da80234

Browse files
author
devsh
committed
Merge remote-tracking branch 'remotes/origin/more_fft_utils'
2 parents d7a9e13 + e8f46dd commit da80234

File tree

16 files changed

+956
-173
lines changed

16 files changed

+956
-173
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#ifndef _NBL_BUILTIN_HLSL_BITREVERSE_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_BITREVERSE_INCLUDED_
3+
4+
5+
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
6+
7+
namespace nbl
8+
{
9+
namespace hlsl
10+
{
11+
12+
template<typename T, uint16_t Bits NBL_FUNC_REQUIRES(is_unsigned_v<T>&& Bits <= sizeof(T) * 8)
13+
/**
14+
* @brief Takes the binary representation of `value` as a string of `Bits` bits and returns a value of the same type resulting from reversing the string
15+
*
16+
* @tparam T Type of the value to operate on.
17+
* @tparam Bits The length of the string of bits used to represent `value`.
18+
*
19+
* @param [in] value The value to bitreverse.
20+
*/
21+
T bitReverseAs(T value)
22+
{
23+
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - Bits));
24+
}
25+
26+
template<typename T NBL_FUNC_REQUIRES(is_unsigned_v<T>)
27+
/**
28+
* @brief Takes the binary representation of `value` and returns a value of the same type resulting from reversing the string of bits as if it was `bits` long.
29+
* Keep in mind `bits` cannot exceed `8 * sizeof(T)`.
30+
*
31+
* @tparam T type of the value to operate on.
32+
*
33+
* @param [in] value The value to bitreverse.
34+
* @param [in] bits The length of the string of bits used to represent `value`.
35+
*/
36+
T bitReverseAs(T value, uint16_t bits)
37+
{
38+
return bitReverse<T>(value) >> promote<T, scalar_type_t<T> >(scalar_type_t <T>(sizeof(T) * 8 - bits));
39+
}
40+
41+
42+
}
43+
}
44+
45+
46+
47+
#endif

include/nbl/builtin/hlsl/complex.hlsl

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,47 @@
55
#ifndef _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_
66
#define _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_
77

8-
#include "nbl/builtin/hlsl/functional.hlsl"
9-
#include "nbl/builtin/hlsl/cpp_compat/promote.hlsl"
8+
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
9+
#include <nbl/builtin/hlsl/functional.hlsl>
10+
11+
using namespace nbl::hlsl;
12+
13+
// -------------------------------------- CPP VERSION ------------------------------------
14+
#ifndef __HLSL_VERSION
15+
16+
#include <complex>
17+
18+
namespace nbl
19+
{
20+
namespace hlsl
21+
{
22+
23+
template<typename Scalar>
24+
using complex_t = std::complex<Scalar>;
25+
26+
// Fast mul by i
27+
template<typename Scalar>
28+
complex_t<Scalar> rotateLeft(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
29+
{
30+
complex_t<Scalar> retVal = { -value.imag(), value.real() };
31+
return retVal;
32+
}
33+
34+
// Fast mul by -i
35+
template<typename Scalar>
36+
complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
37+
{
38+
complex_t<Scalar> retVal = { value.imag(), -value.real() };
39+
return retVal;
40+
}
41+
42+
}
43+
}
44+
45+
// -------------------------------------- END CPP VERSION ------------------------------------
46+
47+
// -------------------------------------- HLSL VERSION ---------------------------------------
48+
#else
1049

1150
namespace nbl
1251
{
@@ -126,6 +165,8 @@ struct complex_t
126165
template<typename Scalar>
127166
struct plus< complex_t<Scalar> >
128167
{
168+
using type_t = complex_t<Scalar>;
169+
129170
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
130171
{
131172
return lhs + rhs;
@@ -137,6 +178,8 @@ struct plus< complex_t<Scalar> >
137178
template<typename Scalar>
138179
struct minus< complex_t<Scalar> >
139180
{
181+
using type_t = complex_t<Scalar>;
182+
140183
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
141184
{
142185
return lhs - rhs;
@@ -148,6 +191,8 @@ struct minus< complex_t<Scalar> >
148191
template<typename Scalar>
149192
struct multiplies< complex_t<Scalar> >
150193
{
194+
using type_t = complex_t<Scalar>;
195+
151196
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
152197
{
153198
return lhs * rhs;
@@ -164,6 +209,8 @@ struct multiplies< complex_t<Scalar> >
164209
template<typename Scalar>
165210
struct divides< complex_t<Scalar> >
166211
{
212+
using type_t = complex_t<Scalar>;
213+
167214
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
168215
{
169216
return lhs / rhs;
@@ -379,6 +426,22 @@ complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
379426
return retVal;
380427
}
381428

429+
template<typename Scalar>
430+
struct ternary_operator< complex_t<Scalar> >
431+
{
432+
using type_t = complex_t<Scalar>;
433+
434+
complex_t<Scalar> operator()(bool condition, NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
435+
{
436+
const vector<Scalar, 2> lhsVector = vector<Scalar, 2>(lhs.real(), lhs.imag());
437+
const vector<Scalar, 2> rhsVector = vector<Scalar, 2>(rhs.real(), rhs.imag());
438+
const vector<Scalar, 2> resultVector = condition ? lhsVector : rhsVector;
439+
const complex_t<Scalar> result = { resultVector.x, resultVector.y };
440+
return result;
441+
}
442+
};
443+
444+
382445
}
383446
}
384447

@@ -396,4 +459,7 @@ NBL_REGISTER_OBJ_TYPE(complex_t<float64_t2>,::nbl::hlsl::alignment_of_v<float64_
396459
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t3>,::nbl::hlsl::alignment_of_v<float64_t3>)
397460
NBL_REGISTER_OBJ_TYPE(complex_t<float64_t4>,::nbl::hlsl::alignment_of_v<float64_t4>)
398461

462+
// -------------------------------------- END HLSL VERSION ---------------------------------------
399463
#endif
464+
465+
#endif
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_FFT_INCLUDED_
3+
4+
#include "nbl/builtin/hlsl/concepts.hlsl"
5+
#include "nbl/builtin/hlsl/fft/common.hlsl"
6+
7+
namespace nbl
8+
{
9+
namespace hlsl
10+
{
11+
namespace workgroup
12+
{
13+
namespace fft
14+
{
15+
// The SharedMemoryAccessor MUST provide the following methods:
16+
// * void get(uint32_t index, inout uint32_t value);
17+
// * void set(uint32_t index, in uint32_t value);
18+
// * void workgroupExecutionAndMemoryBarrier();
19+
20+
#define NBL_CONCEPT_NAME FFTSharedMemoryAccessor
21+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
22+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
23+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
24+
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
25+
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
26+
NBL_CONCEPT_BEGIN(3)
27+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
28+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
29+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
30+
NBL_CONCEPT_END(
31+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t, uint32_t>(index, val)), is_same_v, void))
32+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t, uint32_t>(index, val)), is_same_v, void))
33+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
34+
);
35+
#undef val
36+
#undef index
37+
#undef accessor
38+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
39+
40+
41+
// The Accessor (for a small FFT) MUST provide the following methods:
42+
// * void get(uint32_t index, inout complex_t<Scalar> value);
43+
// * void set(uint32_t index, in complex_t<Scalar> value);
44+
45+
#define NBL_CONCEPT_NAME SmallFFTAccessor
46+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
47+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
48+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
49+
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
50+
#define NBL_CONCEPT_PARAM_2 (val, complex_t<Scalar>)
51+
NBL_CONCEPT_BEGIN(3)
52+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
53+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
54+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
55+
NBL_CONCEPT_END(
56+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.set(index, val)), is_same_v, void))
57+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.get(index, val)), is_same_v, void))
58+
);
59+
#undef val
60+
#undef index
61+
#undef accessor
62+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
63+
64+
65+
// The Accessor MUST provide the following methods:
66+
// * void get(uint32_t index, inout complex_t<Scalar> value);
67+
// * void set(uint32_t index, in complex_t<Scalar> value);
68+
// * void memoryBarrier();
69+
70+
#define NBL_CONCEPT_NAME FFTAccessor
71+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
72+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(Scalar)
73+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
74+
NBL_CONCEPT_BEGIN(1)
75+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
76+
NBL_CONCEPT_END(
77+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.memoryBarrier()), is_same_v, void))
78+
) && SmallFFTAccessor<T, Scalar>;
79+
#undef accessor
80+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
81+
82+
}
83+
}
84+
}
85+
}
86+
87+
#endif

include/nbl/builtin/hlsl/cpp_compat/basic.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ inline To _static_cast(From v)
4141
#define NBL_CONSTEXPR_STATIC constexpr static
4242
#define NBL_CONSTEXPR_STATIC_INLINE constexpr static inline
4343
#define NBL_CONSTEXPR_INLINE_FUNC constexpr inline
44+
#define NBL_CONSTEXPR_FORCED_INLINE_FUNC NBL_FORCE_INLINE constexpr
4445
#define NBL_CONST_MEMBER_FUNC const
4546

4647
namespace nbl::hlsl
@@ -70,6 +71,7 @@ namespace nbl::hlsl
7071
#define NBL_CONSTEXPR_STATIC const static
7172
#define NBL_CONSTEXPR_STATIC_INLINE const static
7273
#define NBL_CONSTEXPR_INLINE_FUNC inline
74+
#define NBL_CONSTEXPR_FORCED_INLINE_FUNC inline
7375
#define NBL_CONST_MEMBER_FUNC
7476

7577
namespace nbl

0 commit comments

Comments
 (0)