diff --git a/utils/codegen_tl1.py b/utils/codegen_tl1.py index 4c2e7dd3..b007d785 100644 --- a/utils/codegen_tl1.py +++ b/utils/codegen_tl1.py @@ -201,10 +201,10 @@ def gen_body_core_code(bm, by): int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\ int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\ int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\ - vec_c[{6}] += vec_v_left_{0}.val[0];\n\ - vec_c[{6}] += vec_v_right_{0}.val[0];\n\ - vec_c[{7}] += vec_v_left_{0}.val[1];\n\ - vec_c[{7}] += vec_v_right_{0}.val[1];\n\ + vec_c[{6}] += vreinterpretq_s16_s8(vec_v_left_{0}.val[0];\n\ + vec_c[{6}] += vreinterpretq_s16_s8(vec_v_right_{0}.val[0]);\n\ + vec_c[{7}] += vreinterpretq_s16_s8(vec_v_left_{0}.val[1]);\n\ + vec_c[{7}] += vreinterpretq_s16_s8(vec_v_right_{0}.val[1]);\n\ ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1) all_code = "".join([all_code, core_code]) @@ -232,7 +232,7 @@ def gen_tbl_impl(pre, BM, BK, bm, k): #ifdef __ARM_NEON\n\ const int KK = BBK{0} / 2;\n\ const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\ - const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\ + const int8x16_t vec_zero = vdupq_n_s8(0x0000);\n\ int8x16_t vec_lut[2 * KK];\n\ ".format(pre, BM, BK) @@ -249,7 +249,7 @@ def gen_tbl_impl(pre, BM, BK, bm, k): for (int i = 0; i < BM{}; i += {}) {{\n\ #pragma unroll\n\ for (int i=0; i<{}; i++) {{\n\ - vec_c[i] = vandq_s16(vec_c[i], vec_zero);\n\ + vec_c[i] = vandq_s16(vec_c[i], vreinterpretq_s16_s8(vec_zero));\n\ }}\n".format(pre, bm, bm // 8) body_core_pre_code = "\n\