
Commit f294964

[cherry-pick] fix sve backends bug(matmul_v2&conv) (#9696)
* add A510 for sdot supported (#9537)
* [ARM] add matmul_v2 sve2 backends and add 5x5s1p2 max pooling (#9653)
* [SVE] add matmul_v2 sve backends
* [ARM] add 5x5s1p2 pooling max kernel
* [sve] fix fuse leakrelu in conv (#9670)
* [OpMakerClean] fix flatten op for removing xshape
1 parent: 6bc3164 · commit: f294964

File tree: 13 files changed, +788 / -167 lines


lite/backends/arm/math/gemm_s8.cc

Lines changed: 110 additions & 0 deletions
@@ -13,6 +13,9 @@
 // limitations under the License.
 
 #include "lite/backends/arm/math/gemm_s8.h"
+#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
+#include "lite/backends/arm/math/sve/gemm_sve_i8mm.h"
+#endif
 
 namespace paddle {
 namespace lite {
@@ -112,6 +115,113 @@ template void gemm_s8<int8_t>(bool is_transA,
                               const operators::ActivationParam act_param,
                               ARMContext* ctx);
 
+#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
+template <typename Dtype>
+void gemm_sve(bool is_transA,
+              bool is_transB,
+              int M,
+              int N,
+              int K,
+              const int8_t* A,
+              const int8_t* B,
+              Dtype* C,
+              const float* bias,
+              bool is_bias,
+              const float* scale,
+              const operators::ActivationParam act_param,
+              ARMContext* ctx) {
+  if (N == 1) {
+    gemv_int8(A, B, C, is_transA, M, K, scale, is_bias, bias, act_param, ctx);
+    return;
+  }
+  if (M == 1) {
+#ifdef TARGET_IOS
+    float* bias_ptr = new float[N];
+    float* scale_ptr = new float[N];
+#else
+    float bias_ptr[N];   // NOLINT
+    float scale_ptr[N];  // NOLINT
+#endif
+    if (is_bias) {
+      for (int i = 0; i < N; i++) {
+        bias_ptr[i] = bias[0];
+      }
+    }
+    for (int i = 0; i < N; i++) {
+      scale_ptr[i] = scale[0];
+    }
+    gemv_int8(B,
+              A,
+              C,
+              !is_transB,
+              N,
+              K,
+              scale_ptr,
+              is_bias,
+              bias_ptr,
+              act_param,
+              ctx);
+#ifdef TARGET_IOS
+    delete[] bias_ptr;
+    delete[] scale_ptr;
+#endif
+    return;
+  }
+
+  //! prepack
+  Tensor tpackedA_sve;
+  int hblock_sve = paddle::lite::arm::math::sve::get_hblock_int8_sve(ctx);
+  int round_up_a_sve = ((hblock_sve + M - 1) / hblock_sve) * hblock_sve;
+  int round_up_k_sve = 8 * ((K + 7) / 8);
+  tpackedA_sve.Resize({round_up_a_sve * round_up_k_sve});
+  int lda = is_transA ? M : K;
+  paddle::lite::arm::math::sve::prepackA_int8_sve(
+      tpackedA_sve.mutable_data<int8_t>(), A, lda, 0, M, 0, K, is_transA, ctx);
+  // sve
+  lite::arm::math::sve::gemm_prepack_int8_sve<Dtype>(
+      tpackedA_sve.data<int8_t>(),
+      B,
+      bias,
+      C,
+      M,
+      N,
+      K,
+      is_bias,
+      is_transB,
+      scale,
+      act_param,
+      ctx);
+}
+
+template void gemm_sve<float>(bool is_transA,
+                              bool is_transB,
+                              int M,
+                              int N,
+                              int K,
+                              const int8_t* A,
+                              const int8_t* B,
+                              float* C,
+                              const float* bias,
+                              bool is_bias,
+                              const float* scale,
+                              const operators::ActivationParam act_param,
+                              ARMContext* ctx);
+
+template void gemm_sve<int8_t>(bool is_transA,
+                               bool is_transB,
+                               int M,
+                               int N,
+                               int K,
+                               const int8_t* A,
+                               const int8_t* B,
+                               int8_t* C,
+                               const float* bias,
+                               bool is_bias,
+                               const float* scale,
+                               const operators::ActivationParam act_param,
+                               ARMContext* ctx);
+#endif
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
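
Note on the dispatch above: when N == 1 the new gemm_sve falls straight back to gemv_int8; when M == 1 it computes the 1xN row result as a GEMV over B with the transpose flag flipped, first broadcasting the scalar bias and scale out to length N; otherwise it packs A into a padded buffer sized by rounding M up to the SVE block height and K up to a multiple of 8 before running the prepacked SVE int8 kernel. Below is a minimal standalone sketch of that round-up arithmetic; the block height of 12 is hypothetical (the real value comes from get_hblock_int8_sve(ctx)).

#include <cstdio>

// Round x up to the next multiple of block: the same formula the diff uses
// for round_up_a_sve (block = hblock_sve) and round_up_k_sve (block = 8).
static int round_up(int x, int block) {
  return ((x + block - 1) / block) * block;
}

int main() {
  const int hblock_sve = 12;  // hypothetical; really get_hblock_int8_sve(ctx)
  const int M = 50, K = 30;
  printf("packed rows: %d\n", round_up(M, hblock_sve));  // 60
  printf("packed cols: %d\n", round_up(K, 8));           // 32
  // tpackedA_sve is resized to the product: 60 * 32 = 1920 int8 elements.
  printf("buffer elems: %d\n", round_up(M, hblock_sve) * round_up(K, 8));
  return 0;
}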

lite/backends/arm/math/gemm_s8.h

Lines changed: 16 additions & 0 deletions
@@ -39,6 +39,22 @@ void gemm_s8(bool is_transA,
              const operators::ActivationParam act_param,
              ARMContext* ctx);
 
+#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
+template <typename Dtype>
+void gemm_sve(bool is_transA,
+              bool is_transB,
+              int M,
+              int N,
+              int K,
+              const int8_t* A,
+              const int8_t* B,
+              Dtype* C,
+              const float* bias,
+              bool is_bias,
+              const float* scale,
+              const operators::ActivationParam act_param,
+              ARMContext* ctx);
+#endif
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
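
For reference, a minimal calling sketch for the new entry point declared above. It assumes the PaddleLite types resolve as paddle::lite::ARMContext and paddle::lite::operators::ActivationParam, and that a default-constructed ActivationParam means no fused activation; the wrapper name and the shape arguments are illustrative only, not part of the patch.

#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2)
#include "lite/backends/arm/math/gemm_s8.h"

// Hypothetical wrapper: C(float, MxN) = A(int8, MxK) * B(int8, KxN),
// dequantized with a per-tensor scale and an optional bias.
void int8_matmul_sve(const int8_t* A,
                     const int8_t* B,
                     float* C,
                     int M, int N, int K,
                     const float* scale,
                     const float* bias,
                     paddle::lite::ARMContext* ctx) {
  paddle::lite::operators::ActivationParam act_param;  // no fused activation
  paddle::lite::arm::math::gemm_sve<float>(/*is_transA=*/false,
                                           /*is_transB=*/false,
                                           M, N, K, A, B, C,
                                           bias,
                                           /*is_bias=*/bias != nullptr,
                                           scale,
                                           act_param,
                                           ctx);
}
#endif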
