Skip to content

Commit fa02163

Browse files
author
devsh
committed
implement fast_affine.hlsl
1 parent bb6f9e5 commit fa02163

File tree

1 file changed

+151
-43
lines changed

1 file changed

+151
-43
lines changed

include/nbl/builtin/hlsl/math/linalg/fast_affine.hlsl

Lines changed: 151 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,70 +4,178 @@
44
#ifndef _NBL_BUILTIN_HLSL_MATH_LINALG_FAST_AFFINE_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_MATH_LINALG_FAST_AFFINE_INCLUDED_
66

7+
8+
#include <nbl/builtin/hlsl/mpl.hlsl>
9+
#include <nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl>
710
#include <nbl/builtin/hlsl/concepts.hlsl>
811

9-
#if 0 // TODO
10-
vec4 pseudoMul4x4with3x1(in mat4 m, in vec3 v)
12+
13+
namespace nbl
1114
{
12-
return m[0] * v.x + m[1] * v.y + m[2] * v.z + m[3];
13-
}
14-
vec3 pseudoMul3x4with3x1(in mat4x3 m, in vec3 v)
15+
namespace hlsl
1516
{
16-
return m[0] * v.x + m[1] * v.y + m[2] * v.z + m[3];
17-
}
18-
mat4x3 pseudoMul4x3with4x3(in mat4x3 lhs, in mat4x3 rhs) // TODO: change name to 3x4with3x4
17+
namespace math
1918
{
20-
mat4x3 result;
21-
for (int i = 0; i < 4; i++)
22-
result[i] = lhs[0] * rhs[i][0] + lhs[1] * rhs[i][1] + lhs[2] * rhs[i][2];
23-
result[3] += lhs[3];
24-
return result;
25-
}
26-
mat4 pseudoMul4x4with4x3(in mat4 proj, in mat4x3 tform)
19+
namespace linalg
2720
{
28-
mat4 result;
29-
for (int i = 0; i < 4; i++)
30-
result[i] = proj[0] * tform[i][0] + proj[1] * tform[i][1] + proj[2] * tform[i][2];
31-
result[3] += proj[3];
32-
return result;
33-
}
21+
// TODO: move to macros
22+
#ifdef __HLSL_VERSION
23+
#define NBL_UNROLL [[unroll]]
24+
#else
25+
#define NBL_UNROLL
26+
#endif
3427

35-
// useful for fast computation of a Normal Matrix (you just need to remember to normalize the transformed normal because of the missing divide by the determinant)
36-
mat3 sub3x3TransposeCofactors(in mat3 sub3x3)
28+
// Multiply matrices as-if extended to be filled with identity elements
29+
template<typename T, int N, int M, int P, int Q>
30+
matrix<T,N,M> promoted_mul(NBL_CONST_REF_ARG(matrix<T,N,P>) lhs, NBL_CONST_REF_ARG(matrix<T,Q,M>) rhs)
3731
{
38-
return mat3(
39-
cross(sub3x3[1],sub3x3[2]),
40-
cross(sub3x3[2],sub3x3[0]),
41-
cross(sub3x3[0],sub3x3[1])
42-
);
32+
matrix<T,N,M> retval;
33+
// NxM = NxR RxM
34+
// out[i][j] == dot(row[i],col[j])
35+
// out[i][j] == lhs[i][0]*col[j][0]+...+lhs[i][3]*col[j][3]
36+
// col[a][b] == (rhs^T)[b][a]
37+
// out[i][j] == lhs[i][0]*rhs[0][j]+...+lhs[i][3]*rhs[3][j]
38+
// out[i] == lhs[i][0]*rhs[0]+...+lhs[i][3]*rhs[3]
39+
NBL_UNROLL for (uint32_t i=0; i<N; i++)
40+
{
41+
vector<T,M> acc = rhs[i];
42+
// multiply if not outside of `lhs` matrix
43+
// otherwise the diagonal element is just unity
44+
if (i<P)
45+
acc *= lhs[i][i];
46+
// other elements are 0 if outside the LHS matrix
47+
NBL_UNROLL for (uint32_t j=0; j<P; j++)
48+
if (j!=i)
49+
{
50+
// inside the RHS matrix
51+
if (j<Q)
52+
acc += rhs[j]*lhs[i][j];
53+
else // outside we have an implicit e_j valued row
54+
acc[j] += lhs[i][j];
55+
}
56+
retval[i] = acc;
57+
}
58+
return retval;
4359
}
44-
// returns a signflip mask
45-
uint sub3x3TransposeCofactors(in mat3 sub3x3, out mat3 sub3x3TransposeCofactors)
60+
61+
// Multiply matrix and vector as-if extended to be filled with 1 in diagonal for matrix and last for vector
62+
template<typename T, int N, int M, int P>
63+
vector<T,N> promoted_mul(NBL_CONST_REF_ARG(matrix<T,N,M>) lhs, const vector<T,P> v)
4664
{
47-
sub3x3TransposeCofactors = sub3x3TransposeCofactors(sub3x3);
48-
return floatBitsToUint(dot(sub3x3[0],sub3x3TransposeCofactors[0]))&0x80000000u;
65+
vector<T,N> retval;
66+
// Nx1 = NxM Mx1
67+
{
68+
matrix<T,M,1> rhs;
69+
// one can safely discard elements of `v[i]` where `i<P && i>=M`, because to contribute `lhs` would need to have `M>=P`
70+
NBL_UNROLL for (uint32_t i=0; i<M; i++)
71+
{
72+
if (i<P)
73+
rhs[i] = v[i];
74+
else
75+
rhs[i] = i!=(M-1) ? T(0):T(1);
76+
}
77+
matrix<T,N,1> tmp = promoted_mul<T,N,1,M,M>(lhs,rhs);
78+
NBL_UNROLL for (uint32_t i=0; i<N; i++)
79+
retval[i] = tmp[i];
80+
}
81+
return retval;
4982
}
83+
#undef NBL_UNROLL
84+
85+
// useful for fast computation of a Normal Matrix
86+
template<typename T, int N>
87+
struct cofactors_base;
5088

51-
// use this if you anticipate flipped/mirrored models
52-
vec3 fastNormalTransform(in uint signFlipMask, in mat3 sub3x3TransposeCofactors, in vec3 normal)
89+
template<typename T>
90+
struct cofactors_base<T,3>
5391
{
54-
vec3 tmp = sub3x3TransposeCofactors*normal;
55-
const float tmpLenRcp = inversesqrt(dot(tmp,tmp));
56-
return tmp*uintBitsToFloat(floatBitsToUint(tmpLenRcp)^signFlipMask);
57-
}
58-
#endif
92+
using matrix_t = matrix<T,3,3>;
93+
using vector_t = vector<T,3>;
94+
95+
static inline cofactors_base<T,3> create(NBL_CONST_REF_ARG(matrix_t) val)
96+
{
97+
cofactors_base<T,3> retval;
98+
99+
retval.transposed = matrix_t(
100+
hlsl::cross<vector_t>(val[1],val[2]),
101+
hlsl::cross<vector_t>(val[2],val[0]),
102+
hlsl::cross<vector_t>(val[0],val[1])
103+
);
104+
105+
return retval;
106+
}
107+
108+
//
109+
inline matrix_t get() NBL_CONST_MEMBER_FUNC
110+
{
111+
return hlsl::transpose<matrix_t>(transposed);
112+
}
113+
114+
//
115+
inline vector_t normalTransform(const vector_t n) NBL_CONST_MEMBER_FUNC
116+
{
117+
const vector_t tmp = hlsl::mul<matrix_t,vector_t>(transposed,n);
118+
return hlsl::normalize<vector_t>(tmp);
119+
}
120+
121+
matrix_t transposed;
122+
};
123+
124+
// variant that cares about flipped/mirrored transforms
125+
template<typename T, int N>
126+
struct cofactors
127+
{
128+
using pseudo_base_t = cofactors_base<T,N>;
129+
using matrix_t = pseudo_base_t::matrix_t;
130+
using vector_t = pseudo_base_t::vector_t;
131+
using mask_t = unsigned_integer_of_size_t<sizeof(T)>;
132+
133+
static inline cofactors<T,3> create(NBL_CONST_REF_ARG(matrix_t) val)
134+
{
135+
cofactors<T,3> retval;
136+
retval.composed = pseudo_base_t::create(val);
137+
138+
const T det = hlsl::dot<vector_t>(val[0],retval.composed.transposed[0]);
139+
140+
const mask_t SignBit = 1;
141+
SignBit = SignBit<<(sizeof(mask_t)*8-1);
142+
retval.signFlipMask = bit_cast<mask_t>(det) & SignBit;
143+
144+
return retval;
145+
}
146+
147+
//
148+
inline vector_t normalTransform(const vector_t n) NBL_CONST_MEMBER_FUNC
149+
{
150+
const vector_t tmp = hlsl::mul<matrix_t,vector_t>(composed.transposed,n);
151+
const T rcpLen = hlsl::rsqrt<T>(hlsl::dot<vector_t>(tmp,tmp));
152+
return tmp*bit_cast<T>(bit_cast<mask_t>(rcpLen)^determinantSignMask);
153+
}
154+
155+
cofactors_base<T,N> composed;
156+
mask_t determinantSignMask;
157+
};
59158

60159
//
61-
template<typename Mat3x4> NBL_REQUIRES(is_matrix_v<Mat3x4>) // TODO: allow any matrix type AND our emulated ones
62-
Mat3x4 pseudoInverse3x4(NBL_CONST_REF_ARG(Mat3x4) tform)
160+
template<typename Mat3x4 NBL_FUNC_REQUIRES(is_matrix_v<Mat3x4>) // TODO: allow any matrix type AND our emulated ones
161+
Mat3x4 pseudoInverse3x4(NBL_CONST_REF_ARG(Mat3x4) tform, NBL_CONST_REF_ARG(matrix<scalar_type_t<Mat3x4>,3,3>) sub3x3Inv)
63162
{
64-
const matrix<scalar_type_t<Mat3x4>,3,3> sub3x3Inv = inverse(mat3(tform));
65163
Mat3x4 retval;
66164
retval[0] = sub3x3Inv[0];
67165
retval[1] = sub3x3Inv[1];
68166
retval[2] = sub3x3Inv[2];
69-
retval[3] = -sub3x3Inv*tform[3];
167+
retval[3] = -hlsl::mul(sub3x3Inv,tform[3]);
70168
return retval;
71169
}
170+
template<typename Mat3x4 NBL_FUNC_REQUIRES(is_matrix_v<Mat3x4>) // TODO: allow any matrix type AND our emulated ones
171+
Mat3x4 pseudoInverse3x4(NBL_CONST_REF_ARG(Mat3x4) tform)
172+
{
173+
return pseudoInverse3x4(tform,inverse(matrix<scalar_type_t<Mat3x4>,3,3>(tform)));
174+
}
175+
72176

177+
}
178+
}
179+
}
180+
}
73181
#endif

0 commit comments

Comments
 (0)