Skip to content

Commit 0ec642a

Browse files
author
devsh
committed
add octahedral.hlsl
1 parent 563de02 commit 0ec642a

File tree

4 files changed

+321
-1
lines changed

4 files changed

+321
-1
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#ifndef _NBL_HLSL_FORMAT_OCTAHEDRAL_HLSL_
2+
#define _NBL_HLSL_FORMAT_OCTAHEDRAL_HLSL_
3+
4+
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5+
#include "nbl/builtin/hlsl/type_traits.hlsl"
6+
#include "nbl/builtin/hlsl/limits.hlsl"
7+
8+
namespace nbl
9+
{
10+
namespace hlsl
11+
{
12+
namespace format
13+
{
14+
15+
// Packed storage for an octahedrally mapped direction.
// Two quantized coordinates of `Bits` bits each share a single `UintT`, the first one in the low
// bits and the second one directly above it (see the encode/decode `_static_cast_helper`s).
// By default every coordinate gets half of the bits of `UintT`.
template<typename UintT, uint16_t Bits=sizeof(UintT)*4>
struct octahedral// : enable_if_t<Bits*2>sizeof(UintT)||Bits*4<sizeof(UintT)> need a way to static_assert in SPIRV!
{
    using this_t = octahedral<UintT,Bits>;
    using storage_t = UintT;

    // number of bits occupied by each of the two packed coordinates
    NBL_CONSTEXPR_STATIC_INLINE uint16_t BitsUsed = Bits;

    // true when both objects hold the exact same bit pattern
    bool operator==(const this_t other)
    {
        return storage==other.storage;
    }
    // true when the bit patterns differ (used to wrongly repeat the `==` comparison)
    bool operator!=(const this_t other)
    {
        return storage!=other.storage;
    }

    // raw packed bits
    storage_t storage;
};
34+
35+
}
36+
37+
// https://www.shadertoy.com/view/Mtfyzl
38+
namespace impl
39+
{
40+
// TODO: remove after the `emulated_float` merge
41+
template<typename T, typename U>
42+
struct _static_cast_helper;
43+
44+
// decode: turns the two packed octahedral coordinates back into a normalized 3D vector
template<typename float_t, typename UintT, uint16_t Bits>
struct _static_cast_helper<vector<float_t,3>,format::octahedral<UintT,Bits> >
{
    using T = vector<float_t,3>;
    using U = format::octahedral<UintT,Bits>;

    T operator()(U val)
    {
        using storage_t = typename U::storage_t;
        // all-ones mask of a single field, which is also the largest quantized coordinate
        const storage_t MaxVal = (storage_t(1)<<U::BitsUsed)-1u;

        // NOTE: We Assume the top unused bits are clean! (the upper field is shifted down but never masked)
        // both fields get remapped from [0,MaxVal] to [-1,1]
        const vector<float_t,2> v = vector<float_t,2>(val.storage&MaxVal,val.storage>>U::BitsUsed) / (vector<float_t,2>(MaxVal,MaxVal)*0.5) - vector<float_t,2>(1,1);

        // Rune Stubbe's version, much faster than original
        T nor = T(v,float_t(1)-abs(v.x)-abs(v.y));
        // only non-zero when the reconstructed Z is negative
        const float_t t = max(-nor.z,float_t(0));
        // TODO: improve the copysign with `^` and a sign mask
        // pull X and Y towards zero by `t`, a coordinate that is not positive moves up instead of down
        if (nor.x>0.0)
            nor.x -= t;
        else
            nor.x += t;
        if (nor.y>0.0)
            nor.y -= t;
        else
            nor.y += t;

        return normalize(nor);
    }
};
69+
// encode: projects a direction onto the octahedron, unfolds the lower half and quantizes both coordinates
// NOTE: a zero length input divides by zero, the caller has to provide a non-degenerate vector
template<typename UintT, uint16_t Bits, typename float_t>
struct _static_cast_helper<format::octahedral<UintT,Bits>,vector<float_t,3> >
{
    using T = format::octahedral<UintT,Bits>;
    using U = vector<float_t,3>;

    T operator()(U nor)
    {
        // scale so that |x|+|y|+|z| == 1
        nor /= (abs(nor.x) + abs(nor.y) + abs(nor.z));
        if (nor.z<float_t(0)) // TODO: faster sign copy
        {
            // Lower hemisphere gets folded over the diagonals.
            // `sign()` must not be used for this, it returns 0 for a zero coordinate, which would
            // send e.g. (0,0,-1) to the center of the map and make it decode as (0,0,+1).
            // A zero coordinate is therefore treated as positive.
            const vector<float_t,2> folded = float_t(1)-abs(nor.yx);
            nor.x = (nor.x<float_t(0)) ? (-folded.x):folded.x;
            nor.y = (nor.y<float_t(0)) ? (-folded.y):folded.y;
        }

        // [-1,1] -> [0,1]
        vector<float_t,2> v = nor.xy*float_t(0.5)+vector<float_t,2>(0.5,0.5);

        using storage_t = typename T::storage_t;
        // largest value a single `BitsUsed` wide field can hold
        const storage_t MaxVal = (storage_t(1)<<T::BitsUsed)-1u;
        // the +0.5 makes the float to integer conversion round to nearest
        const vector<storage_t,2> d = vector<storage_t,2>(v*float_t(MaxVal)+vector<float_t,2>(0.5,0.5));

        // first coordinate in the low bits, second one right above it
        T retval;
        retval.storage = (d.y<<T::BitsUsed)|d.x;
        return retval;
    }
};
93+
}
94+
95+
}
96+
}
97+
#endif
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#ifndef _NBL_HLSL_FORMAT_SHARED_EXP_HLSL_
2+
#define _NBL_HLSL_FORMAT_SHARED_EXP_HLSL_
3+
4+
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
5+
#include "nbl/builtin/hlsl/type_traits.hlsl"
6+
#include "nbl/builtin/hlsl/limits.hlsl"
7+
8+
namespace nbl
9+
{
10+
namespace hlsl
11+
{
12+
13+
namespace format
14+
{
15+
16+
// Packed storage for `_Components` values which share one exponent of `_ExponentBits` bits.
// The signedness of `IntT` selects whether the format carries signs, the bits themselves
// are always kept in the unsigned counterpart of `IntT`.
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct shared_exp// : enable_if_t<_ExponentBits<16> need a way to static_assert in SPIRV!
{
    using this_t = shared_exp<IntT,_Components,_ExponentBits>;
    using storage_t = typename make_unsigned<IntT>::type;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t Components = _Components;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t ExponentBits = _ExponentBits;

    // Not even going to consider fp16 and fp64 dependence on device traits
    using decode_t = float32_t;

    // true when both objects hold the exact same bit pattern
    bool operator==(const this_t other)
    {
        return storage==other.storage;
    }
    // true when the bit patterns differ (used to wrongly repeat the `==` comparison)
    bool operator!=(const this_t other)
    {
        return storage!=other.storage;
    }

    // raw packed bits
    storage_t storage;
};
38+
39+
// all of this because DXC has bugs in partial template spec
40+
namespace impl
41+
{
42+
// Holds everything `numeric_limits` exposes for a `format::shared_exp`,
// the member names mirror the ones of `std::numeric_limits`.
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct numeric_limits_shared_exp
{
    using type = format::shared_exp<IntT,_Components,_ExponentBits>;
    using value_type = typename type::decode_t;
    using __storage_t = typename type::storage_t;

    // --- classification ---
    NBL_CONSTEXPR_STATIC_INLINE bool is_specialized = true;
    NBL_CONSTEXPR_STATIC_INLINE bool is_signed = is_signed_v<IntT>;
    NBL_CONSTEXPR_STATIC_INLINE bool is_integer = false;
    NBL_CONSTEXPR_STATIC_INLINE bool is_exact = false;

    // --- special values ---
    // infinity and NaN are not representable in shared exponent formats
    NBL_CONSTEXPR_STATIC_INLINE bool has_infinity = false;
    NBL_CONSTEXPR_STATIC_INLINE bool has_quiet_NaN = false;
    NBL_CONSTEXPR_STATIC_INLINE bool has_signaling_NaN = false;
    // shared exponent formats have no leading 1 in the mantissa, therefore denormalized values aren't really a concept, although one can argue all values are denorm then?
    NBL_CONSTEXPR_STATIC_INLINE bool has_denorm = false;
    NBL_CONSTEXPR_STATIC_INLINE bool has_denorm_loss = false;

    // --- rounding ---
    // truncation
    // NBL_CONSTEXPR_STATIC_INLINE float_round_style round_style = round_to_nearest;

    // --- general properties ---
    NBL_CONSTEXPR_STATIC_INLINE bool is_iec559 = false;
    NBL_CONSTEXPR_STATIC_INLINE bool is_bounded = true;
    NBL_CONSTEXPR_STATIC_INLINE bool is_modulo = false;

    // --- mantissa and exponent ---
    // mantissa bits per component: what remains of `IntT` after the exponent
    // (and one sign bit per component when signed), split evenly between the components
    NBL_CONSTEXPR_STATIC_INLINE int32_t digits = (sizeof(IntT)*8-(is_signed ? _Components:0)-_ExponentBits)/_Components;
    NBL_CONSTEXPR_STATIC_INLINE int32_t radix = 2;
    NBL_CONSTEXPR_STATIC_INLINE int32_t max_exponent = 1<<(_ExponentBits-1);
    NBL_CONSTEXPR_STATIC_INLINE int32_t min_exponent = 1-max_exponent;
    NBL_CONSTEXPR_STATIC_INLINE bool traps = false;

    // --- extras ---
    // `digits` low bits set, selects a single mantissa
    NBL_CONSTEXPR_STATIC_INLINE __storage_t MantissaMask = ((__storage_t(1))<<digits)-__storage_t(1);
    NBL_CONSTEXPR_STATIC_INLINE uint16_t ExponentBits = _ExponentBits;
    // `_ExponentBits` low bits set, selects the exponent field
    NBL_CONSTEXPR_STATIC_INLINE uint16_t ExponentMask = uint16_t((1<<_ExponentBits)-1);

    // TODO: functions done as vars
    // NBL_CONSTEXPR_STATIC_INLINE value_type min = base::min();
    // NOTE(review): `max` and `lowest` look like they are assembled as `value_type` bit patterns
    // (exponent field in the upper part, mantissa in the lower part) but are typed as `__storage_t` - confirm
    // shift down by 1 to get rid of explicit 1 in mantissa that is now implicit, then +1 in the exponent to compensate
    NBL_CONSTEXPR_STATIC_INLINE __storage_t max =
        ((max_exponent+1-numeric_limits<value_type>::min_exponent)<<(numeric_limits<value_type>::digits-1))|
        ((MantissaMask>>1)<<(numeric_limits<value_type>::digits-digits));
    // signed: `max` with the topmost storage bit set, unsigned: zero
    NBL_CONSTEXPR_STATIC_INLINE __storage_t lowest = is_signed ? ((__storage_t(1)<<(sizeof(__storage_t)*8-1))|max):__storage_t(0);
    /*
    NBL_CONSTEXPR_STATIC_INLINE value_type epsilon = base::epsilon();
    NBL_CONSTEXPR_STATIC_INLINE value_type round_error = base::round_error();
    */
};
88+
}
89+
90+
}
91+
92+
// specialize the limits
93+
// `numeric_limits` of any `format::shared_exp` simply inherits all of its members
// from `format::impl::numeric_limits_shared_exp` instantiated with the same parameters
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct numeric_limits<format::shared_exp<IntT,_Components,_ExponentBits> >
    : format::impl::numeric_limits_shared_exp<IntT,_Components,_ExponentBits>
{
};
97+
98+
namespace impl
99+
{
100+
// TODO: remove after the `emulated_float` merge
101+
template<typename T, typename U>
102+
struct _static_cast_helper;
103+
104+
// TODO: versions for `float16_t`
105+
106+
// decode: converts every mantissa to floating point and scales all of them by the shared exponent
// Layout read: mantissa `i` at bit `digits*i`, the exponent directly above the last mantissa and,
// for signed formats, the sign of component `i` located `i` bits below the top bit of the storage.
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct _static_cast_helper<
    vector<typename format::shared_exp<IntT,_Components,_ExponentBits>::decode_t,_Components>,
    format::shared_exp<IntT,_Components,_ExponentBits>
>
{
    using U = format::shared_exp<IntT,_Components,_ExponentBits>;
    using T = vector<typename U::decode_t,_Components>;

    T operator()(U val)
    {
        using storage_t = typename U::storage_t;
        // DXC error: error: expression class 'DependentScopeDeclRefExpr' unimplemented, doesn't matter as decode_t is always float32_t for now
        //using decode_t = typename T::decode_t;
        using decode_t = float32_t;
        // no clue why the compiler doesn't pick up the partial specialization and tries to use the general one
        using limits_t = format::impl::numeric_limits_shared_exp<IntT,_Components,_ExponentBits>;

        // unsigned integer mantissas first
        T retval;
        for (uint16_t i=0; i<_Components; i++)
            retval[i] = decode_t((val.storage>>storage_t(limits_t::digits*i))&limits_t::MantissaMask);
        // the exponent starts right after the last mantissa, so the shift depends on the component count
        // (it used to be hardcoded for 3 components)
        uint16_t exponent = uint16_t(val.storage>>storage_t(limits_t::digits*_Components));
        if (limits_t::is_signed)
        {
            // the encoder puts the sign of component 0 in the topmost storage bit and every next one a bit lower,
            // reading them in ascending order above the exponent would swap the signs between components
            for (uint16_t i=0; i<_Components; i++)
            if (((val.storage>>storage_t(sizeof(storage_t)*8-1-i))&storage_t(1))!=storage_t(0))
                retval[i] = -retval[i];
            // get rid of the sign bits which were shifted down together with the exponent
            exponent &= limits_t::ExponentMask;
        }
        // value = mantissa * 2^(exponent - digits + min_exponent)
        return retval*exp2(int32_t(exponent-limits_t::digits)+limits_t::min_exponent);
    }
};
139+
// encode (WARNING DOES NOT CHECK THAT INPUT IS IN THE RANGE!)
// Layout written: mantissa `i` at bit `digits*i`, the shared exponent directly above the last mantissa and,
// for signed formats, the sign of component `i` located `i` bits below the top bit of the storage.
// Mantissas are truncated, not rounded.
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct _static_cast_helper<
    format::shared_exp<IntT,_Components,_ExponentBits>,
    vector<typename format::shared_exp<IntT,_Components,_ExponentBits>::decode_t,_Components>
>
{
    using T = format::shared_exp<IntT,_Components,_ExponentBits>;
    using U = vector<typename T::decode_t,_Components>;

    T operator()(U val)
    {
        using storage_t = typename T::storage_t;
        // DXC error: error: expression class 'DependentScopeDeclRefExpr' unimplemented, doesn't matter as decode_t is always float32_t for now
        //using decode_t = typename T::decode_t;
        using decode_t = float32_t;
        // unsigned integer as wide as `decode_t`, used to inspect its bit pattern
        using decode_bits_t = unsigned_integer_of_size<sizeof(decode_t)>::type;
        // no clue why the compiler doesn't pick up the partial specialization and tries to use the general one
        using limits_t = format::impl::numeric_limits_shared_exp<IntT,_Components,_ExponentBits>;

        // get exponents (still with the bias of `decode_t`)
        vector<uint16_t,_Components> exponentsDecBias;
        const int32_t dec_MantissaStoredBits = numeric_limits<decode_t>::digits-1;
        for (uint16_t i=0; i<_Components; i++)
        {
            decode_t v = val[i];
            // drop the sign so it doesn't end up in the exponent
            if (limits_t::is_signed)
                v = abs(v);
            exponentsDecBias[i] = uint16_t(bit_cast<decode_bits_t>(v)>>dec_MantissaStoredBits);
        }

        // get the maximum exponent
        uint16_t sharedExponentDecBias = exponentsDecBias[0];
        for (uint16_t i=1; i<_Components; i++)
            sharedExponentDecBias = max(exponentsDecBias[i],sharedExponentDecBias);

        // NOTE: we don't consider clamping against `limits_t::max_exponent`, should be ensured by clamping the inputs against `limits_t::max` before casting!

        // we need to stop "shifting up" implicit leading 1. to farthest left position if the exponent too small
        // The bound is the smallest value for which the re-biased exponent computed below is still 0,
        // without the `+1` the re-biased exponent became -1 (and wrapped around) whenever the clamp kicked in.
        uint16_t clampedSharedExponentDecBias;
        if (limits_t::min_exponent>=numeric_limits<decode_t>::min_exponent) // if ofc its needed at all
            clampedSharedExponentDecBias = max(sharedExponentDecBias,uint16_t(limits_t::min_exponent-numeric_limits<decode_t>::min_exponent+1));
        else
            clampedSharedExponentDecBias = sharedExponentDecBias;

        // we always shift down, the question is how much
        // The component owning the shared exponent has to shrink from the mantissa width of `decode_t` to ours,
        // every other component additionally loses as many bits as its exponent is behind the shared one.
        // (the base shift used to be `-limits_t::min_exponent` which only matches for 9 mantissa and 5 exponent bits)
        const uint16_t baseMantissaShift = uint16_t(numeric_limits<decode_t>::digits-limits_t::digits);
        vector<uint16_t,_Components> mantissaShifts;
        for (uint16_t i=0; i<_Components; i++)
            mantissaShifts[i] = min(clampedSharedExponentDecBias+baseMantissaShift-exponentsDecBias[i],uint16_t(numeric_limits<decode_t>::digits));

        // finally lets re-bias our exponent (it will always be positive), note the -1 because IEEE754 floats reserve the lowest exponent values for denorm
        const uint16_t sharedExponentEncBias = int16_t(clampedSharedExponentDecBias+int16_t(-limits_t::min_exponent))-uint16_t(1-numeric_limits<decode_t>::min_exponent);

        // the exponent goes right after the last mantissa (the shift used to be hardcoded for 3 components)
        T retval;
        retval.storage = storage_t(sharedExponentEncBias)<<(limits_t::digits*_Components);
        const decode_bits_t dec_MantissaMask = (decode_bits_t(1)<<dec_MantissaStoredBits)-1;
        for (uint16_t i=0; i<_Components; i++)
        {
            decode_bits_t origBitPattern = bit_cast<decode_bits_t>(val[i])&dec_MantissaMask;
            // put the implicit 1 in (don't care about denormalized because its probably less than our `limits_t::min` (TODO: static assert it)
            origBitPattern |= decode_bits_t(1)<<dec_MantissaStoredBits;
            // shift and put in the right place
            retval.storage |= storage_t(origBitPattern>>mantissaShifts[i])<<(limits_t::digits*i);
        }
        if (limits_t::is_signed)
        {
            // doing ops on smaller integers is faster
            // (the mask is built from an unsigned 1 so that no signed literal gets shifted into its sign bit)
            decode_bits_t SignMask = decode_bits_t(1)<<(sizeof(decode_t)*8-1);
            // sign of component 0 stays in the top bit, every next component lands one bit lower
            decode_bits_t signs = bit_cast<decode_bits_t>(val[0])&SignMask;
            for (uint16_t i=1; i<_Components; i++)
                signs |= (bit_cast<decode_bits_t>(val[i])&SignMask)>>i;
            // move the signs to the top of the storage
            // NOTE(review): assumes `storage_t` is at least as wide as `decode_t`, otherwise the shift amount underflows - confirm
            retval.storage |= storage_t(signs)<<((sizeof(storage_t)-sizeof(decode_t))*8);
        }
        return retval;
    }
};
217+
}
218+
}
219+
}
220+
#endif

src/nbl/builtin/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/device_capabilities_traits.hl
269269
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/numbers.hlsl")
270270
#Complex math
271271
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/complex.hlsl")
272+
# format
273+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/format/octahedral.hlsl")
274+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/format/shared_exp.hlsl")
272275
#linear algebra
273276
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/math/linalg/fast_affine.hlsl")
274277
# TODO: rename `equations` to `polynomials` probably

0 commit comments

Comments
 (0)