@@ -48,22 +48,20 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in
 
 FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
 {
-  vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
   vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+  vec_load_mult_mma(out, in0, inp);
+
   __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
 }
 
 FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp)
 {
-  vec_bf16 in01 = (vec_bf16)vec_load_vec(in0);
-  vec_bf16 in11 = (vec_bf16)vec_load_vec(in1);
   vec_bf16 in21 = (vec_bf16)vec_load_vec(in2);
   vec_bf16 in31 = (vec_bf16)vec_load_vec(in3);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+  vec_load_mult12a_mma(out, in0, in1, inp);
+
   __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
   __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
 }
@@ -78,17 +76,21 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *
   __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
 }
 
+FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp)
+{
+  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
+}
+
 FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
 {
   vec_bf16 in01[2], in11[2];
 
   vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0);
   vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
+  vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
+  vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
 }
 
 FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@@ -100,26 +102,22 @@ FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
   vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2);
   vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
+  vec_mult2d_mma(out + 0, in01 + 0, in11 + 0, inp + 0);
+  vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0);
+  vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1);
+  vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1);
 }
 
 FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
 {
-  vec_bf16 in0[4];
+  vec_bf16 in0[2];
 
-  vec_load_pair2(in0, in);
+  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2));
 
-  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]);
-  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]);
+  vec_load_mult2_mma(out, in + 0, inp + 0);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]);
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]);
 }
 
 FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp)
@@ -129,14 +127,16 @@ FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
   vec_load_pair2(in01, in0);
   vec_load_pair2(in11, in1);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
+  vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0);
+  vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1);
+  vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 2);
+  vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3);
+}
+
+FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp)
+{
+  vec_mult2d_mma(out + 0, in01, in11, inp);
+  vec_mult2d_mma(out + 2, in21, in31, inp);
 }
 
 FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp)
@@ -148,22 +148,10 @@ FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16
   vec_load_pair2(in21, in2);
   vec_load_pair2(in31, in3);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]);
-  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]);
-  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]);
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]);
-  __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]);
-  __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]);
+  vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0);
+  vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1);
+  vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2);
+  vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3);
 }
 
 FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
@@ -175,22 +163,20 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i
 
 FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n)
 {
-  vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
   vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
+  vec_loadN_mult_mma(out, in0, inp, n);
+
   __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
 }
 
 FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n)
 {
-  vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n);
-  vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n);
   vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n);
   vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n);
 
-  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp);
-  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp);
+  vec_loadN_mult12a_mma(out, in0, in1, inp, n);
+
   __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp);
   __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp);
 }
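
For context, here is a minimal standalone sketch (not part of the patch) of how the new vec_mult2d_mma helper slots into the usual POWER10 MMA accumulate/disassemble flow. The vec_uc8, vec_bf16, and vec_f32 typedefs and the demo_mult2d driver below are assumptions standing in for the definitions used elsewhere in the sbgemv common code; only the GCC MMA builtins (__builtin_mma_xxsetaccz, __builtin_mma_xvbf16ger2pp, __builtin_mma_disassemble_acc) are taken as given. Build with gcc -mcpu=power10.

/* Hypothetical sketch: the typedefs are stand-ins, not the file's real definitions. */
#include <altivec.h>

typedef __vector unsigned char  vec_uc8;   /* assumed: 16 x u8 view passed to the MMA builtins */
typedef __vector unsigned short vec_bf16;  /* assumed: 8 x bf16 held as unsigned short lanes   */
typedef __vector float          vec_f32;   /* assumed: 4 x f32 rows of a disassembled acc      */

/* Mirrors the helper added in this patch: one rank-2 bf16 accumulate per accumulator. */
static inline void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp)
{
  __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]);
  __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]);
}

/* Example caller: zero two accumulators, apply one paired update, read the results back. */
void demo_mult2d(vec_bf16 *a0, vec_bf16 *a1, vec_bf16 *x, vec_f32 res[2][4])
{
  __vector_quad acc[2];

  __builtin_mma_xxsetaccz(&acc[0]);
  __builtin_mma_xxsetaccz(&acc[1]);

  vec_mult2d_mma(acc, a0, a1, x);

  __builtin_mma_disassemble_acc(res[0], &acc[0]);
  __builtin_mma_disassemble_acc(res[1], &acc[1]);
}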