|
13 | 13 | // limitations under the License.
|
14 | 14 |
|
15 | 15 | #include "lite/backends/arm/math/gemm_s8.h"
|
| 16 | +#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2) |
| 17 | +#include "lite/backends/arm/math/sve/gemm_sve_i8mm.h" |
| 18 | +#endif |
16 | 19 |
|
17 | 20 | namespace paddle {
|
18 | 21 | namespace lite {
|
@@ -112,6 +115,113 @@ template void gemm_s8<int8_t>(bool is_transA,
|
112 | 115 | const operators::ActivationParam act_param,
|
113 | 116 | ARMContext* ctx);
|
114 | 117 |
|
| 118 | +#if defined(__aarch64__) && defined(LITE_WITH_ARM8_SVE2) |
| 119 | +template <typename Dtype> |
| 120 | +void gemm_sve(bool is_transA, |
| 121 | + bool is_transB, |
| 122 | + int M, |
| 123 | + int N, |
| 124 | + int K, |
| 125 | + const int8_t* A, |
| 126 | + const int8_t* B, |
| 127 | + Dtype* C, |
| 128 | + const float* bias, |
| 129 | + bool is_bias, |
| 130 | + const float* scale, |
| 131 | + const operators::ActivationParam act_param, |
| 132 | + ARMContext* ctx) { |
| 133 | + if (N == 1) { |
| 134 | + gemv_int8(A, B, C, is_transA, M, K, scale, is_bias, bias, act_param, ctx); |
| 135 | + return; |
| 136 | + } |
| 137 | + if (M == 1) { |
| 138 | +#ifdef TARGET_IOS |
| 139 | + float* bias_ptr = new float[N]; |
| 140 | + float* scale_ptr = new float[N]; |
| 141 | +#else |
| 142 | + float bias_ptr[N]; // NOLINT |
| 143 | + float scale_ptr[N]; // NOLINT |
| 144 | +#endif |
| 145 | + if (is_bias) { |
| 146 | + for (int i = 0; i < N; i++) { |
| 147 | + bias_ptr[i] = bias[0]; |
| 148 | + } |
| 149 | + } |
| 150 | + for (int i = 0; i < N; i++) { |
| 151 | + scale_ptr[i] = scale[0]; |
| 152 | + } |
| 153 | + gemv_int8(B, |
| 154 | + A, |
| 155 | + C, |
| 156 | + !is_transB, |
| 157 | + N, |
| 158 | + K, |
| 159 | + scale_ptr, |
| 160 | + is_bias, |
| 161 | + bias_ptr, |
| 162 | + act_param, |
| 163 | + ctx); |
| 164 | +#ifdef TARGET_IOS |
| 165 | + delete[] bias_ptr; |
| 166 | + delete[] scale_ptr; |
| 167 | +#endif |
| 168 | + return; |
| 169 | + } |
| 170 | + |
| 171 | + //! prepack |
| 172 | + Tensor tpackedA_sve; |
| 173 | + int hblock_sve = paddle::lite::arm::math::sve::get_hblock_int8_sve(ctx); |
| 174 | + int round_up_a_sve = ((hblock_sve + M - 1) / hblock_sve) * hblock_sve; |
| 175 | + int round_up_k_sve = 8 * ((K + 7) / 8); |
| 176 | + tpackedA_sve.Resize({round_up_a_sve * round_up_k_sve}); |
| 177 | + int lda = is_transA ? M : K; |
| 178 | + paddle::lite::arm::math::sve::prepackA_int8_sve( |
| 179 | + tpackedA_sve.mutable_data<int8_t>(), A, lda, 0, M, 0, K, is_transA, ctx); |
| 180 | + // sve |
| 181 | + lite::arm::math::sve::gemm_prepack_int8_sve<Dtype>( |
| 182 | + tpackedA_sve.data<int8_t>(), |
| 183 | + B, |
| 184 | + bias, |
| 185 | + C, |
| 186 | + M, |
| 187 | + N, |
| 188 | + K, |
| 189 | + is_bias, |
| 190 | + is_transB, |
| 191 | + scale, |
| 192 | + act_param, |
| 193 | + ctx); |
| 194 | +} |
| 195 | + |
// Explicit instantiation: float (fp32) output.
template void gemm_sve<float>(bool is_transA,
                              bool is_transB,
                              int M,
                              int N,
                              int K,
                              const int8_t* A,
                              const int8_t* B,
                              float* C,
                              const float* bias,
                              bool is_bias,
                              const float* scale,
                              const operators::ActivationParam act_param,
                              ARMContext* ctx);
| 209 | + |
// Explicit instantiation: int8 (requantized) output.
template void gemm_sve<int8_t>(bool is_transA,
                               bool is_transB,
                               int M,
                               int N,
                               int K,
                               const int8_t* A,
                               const int8_t* B,
                               int8_t* C,
                               const float* bias,
                               bool is_bias,
                               const float* scale,
                               const operators::ActivationParam act_param,
                               ARMContext* ctx);
| 223 | +#endif |
| 224 | + |
115 | 225 | } // namespace math
|
116 | 226 | } // namespace arm
|
117 | 227 | } // namespace lite
|
|
0 commit comments