2
2
#define _NBL_BUILTIN_HLSL_EMULATED_MATRIX_T_HLSL_INCLUDED_
3
3
4
4
#include <nbl/builtin/hlsl/portable/float64_t.hlsl>
5
+ #include <nbl/builtin/hlsl/emulated/vector_t.hlsl>
5
6
#include <nbl/builtin/hlsl/matrix_utils/matrix_traits.hlsl>
6
7
7
8
namespace nbl
@@ -63,9 +64,14 @@ struct matrix_traits<emulated_matrix<T, ROW_COUNT, COLUMN_COUNT> > \
63
64
};
64
65
65
66
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (2 , 2 )
67
+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (2 , 3 )
68
+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (2 , 4 )
69
+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (3 , 2 )
66
70
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (3 , 3 )
67
- DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 4 )
68
71
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (3 , 4 )
72
+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 2 )
73
+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 3 )
74
+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 4 )
69
75
70
76
#undef DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION
71
77
@@ -91,18 +97,18 @@ struct mul_helper<emulated_matrix<ComponentT, N, M>, emulated_matrix<ComponentT,
91
97
92
98
static inline return_t __call (LhsT lhs, RhsT rhs)
93
99
{
94
- typename matrix_traits<RhsT>::transposed_type rhsTransposed = rhs.getTransposed ();
95
- const uint32_t outputRowCount = matrix_traits<return_t>::RowCount;
96
- const uint32_t outputColumnCount = matrix_traits<return_t>::ColumnCount;
97
100
using OutputVecType = typename matrix_traits<return_t>::row_type;
101
+ const uint32_t outputRowCount = vector_traits<OutputVecType>::Dimension;
98
102
99
- nbl::hlsl::array_set<OutputVecType , typename vector_traits<OutputVecType >::scalar_type> setter ;
103
+ nbl::hlsl::array_get<typename matrix_traits<LhsT>::row_type , typename vector_traits<typename matrix_traits<LhsT >::row_type>:: scalar_type> getter ;
100
104
101
105
return_t output;
102
- for (int r = 0 ; r < outputRowCount; ++r)
106
+ const uint32_t RHSRowCount = matrix_traits<RhsT>::RowCount;
107
+ for (uint32_t rO = 0 ; rO < outputRowCount; ++rO)
103
108
{
104
- for (int c = 0 ; c < outputColumnCount; ++c)
105
- setter (output.rows[r], c, dot<OutputVecType>(lhs.rows[r], rhsTransposed.rows[c]));
109
+ output.rows[rO] = rhs.rows[0 ] * getter (lhs.rows[rO], 0 );
110
+ for (uint32_t rI = 1 ; rI < RHSRowCount; ++rI) // its also the LHS column count
111
+ output.rows[rO] = output.rows[rO] + rhs.rows[rI] * getter (lhs.rows[rO], rI);
106
112
}
107
113
108
114
return output;
@@ -132,4 +138,4 @@ struct mul_helper<emulated_matrix<ComponentT, RowCount, ColumnCount>, emulated_v
132
138
133
139
}
134
140
}
135
- #endif
141
+ #endif
0 commit comments