Skip to content

Commit 8d92ed2

Browse files
committed
Fixed matrix multiplication
1 parent 2b1701a commit 8d92ed2

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

include/nbl/builtin/hlsl/emulated/matrix_t.hlsl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define _NBL_BUILTIN_HLSL_EMULATED_MATRIX_T_HLSL_INCLUDED_
33

44
#include <nbl/builtin/hlsl/portable/float64_t.hlsl>
5+
#include <nbl/builtin/hlsl/emulated/vector_t.hlsl>
56
#include <nbl/builtin/hlsl/matrix_utils/matrix_traits.hlsl>
67

78
namespace nbl
@@ -63,9 +64,14 @@ struct matrix_traits<emulated_matrix<T, ROW_COUNT, COLUMN_COUNT> > \
6364
};
6465

6566
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(2, 2)
67+
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(2, 3)
68+
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(2, 4)
69+
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(3, 2)
6670
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(3, 3)
67-
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(4, 4)
6871
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(3, 4)
72+
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(4, 2)
73+
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(4, 3)
74+
DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION(4, 4)
6975

7076
#undef DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION
7177

@@ -91,18 +97,18 @@ struct mul_helper<emulated_matrix<ComponentT, N, M>, emulated_matrix<ComponentT,
9197

9298
static inline return_t __call(LhsT lhs, RhsT rhs)
9399
{
94-
typename matrix_traits<RhsT>::transposed_type rhsTransposed = rhs.getTransposed();
95-
const uint32_t outputRowCount = matrix_traits<return_t>::RowCount;
96-
const uint32_t outputColumnCount = matrix_traits<return_t>::ColumnCount;
97100
using OutputVecType = typename matrix_traits<return_t>::row_type;
101+
const uint32_t outputRowCount = vector_traits<OutputVecType>::Dimension;
98102

99-
nbl::hlsl::array_set<OutputVecType, typename vector_traits<OutputVecType>::scalar_type> setter;
103+
nbl::hlsl::array_get<typename matrix_traits<LhsT>::row_type, typename vector_traits<typename matrix_traits<LhsT>::row_type>::scalar_type> getter;
100104

101105
return_t output;
102-
for (int r = 0; r < outputRowCount; ++r)
106+
const uint32_t RHSRowCount = matrix_traits<RhsT>::RowCount;
107+
for (uint32_t rO = 0; rO < outputRowCount; ++rO)
103108
{
104-
for (int c = 0; c < outputColumnCount; ++c)
105-
setter(output.rows[r], c, dot<OutputVecType>(lhs.rows[r], rhsTransposed.rows[c]));
109+
output.rows[rO] = rhs.rows[0] * getter(lhs.rows[rO], 0);
110+
for (uint32_t rI = 1; rI < RHSRowCount; ++rI) // its also the LHS column count
111+
output.rows[rO] = output.rows[rO] + rhs.rows[rI] * getter(lhs.rows[rO], rI);
106112
}
107113

108114
return output;
@@ -132,4 +138,4 @@ struct mul_helper<emulated_matrix<ComponentT, RowCount, ColumnCount>, emulated_v
132138

133139
}
134140
}
135-
#endif
141+
#endif

0 commit comments

Comments
 (0)