Skip to content

Commit 241c357

Browse files
feature: add support of RISCV64_SPACEMIT_IME2
Change-Id: I07c3e0dbb9bc10a11bcb92df1bbad75077c0e06a
1 parent 48eed6e commit 241c357

File tree

3 files changed

+103
-11
lines changed

3 files changed

+103
-11
lines changed

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
2121
ggml-cpu/ggml-cpu-traits.h
2222
ggml-cpu/ggml-cpu-impl.h
2323
)
24-
24+
2525
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64")
2626
include(FetchContent)
2727
# TODO replace with git repo
2828
FetchContent_Declare(
2929
onnxruntime
3030
GIT_REPOSITORY ssh://$ENV{GERRIT_USER}@gerrit.dc.com:29418/DSA/onnxruntime
31-
GIT_TAG "c17089e2e45067e24911d95611d2196a3dd63694"
31+
GIT_TAG "7935d26a2ef0afa307e39b4c8a2ed438d281e5bd"
3232
)
3333
# FetchContent_Declare(
3434
# onnxruntime
@@ -348,6 +348,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
348348
message(STATUS "RISC-V detected")
349349
if (GGML_RVV)
350350
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
351+
list(APPEND ARCH_DEFINITIONS RISCV64_SPACEMIT_IME2)
351352
endif()
352353
else()
353354
message(STATUS "Unknown architecture")

ggml/src/ggml-cpu/ggml-cpu-riscv64-spacemit.cpp

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ void SQ4BitGemm_CompInt8(
8080

8181
size_t CountN;
8282

83+
#if defined(RISCV64_SPACEMIT_IME1)
8384
const size_t ComputeBlockCountN = RangeCountM == 1 ? RangeCountN : 16;
85+
#elif defined(RISCV64_SPACEMIT_IME2)
86+
const size_t ComputeBlockCountN = RangeCountM == 1 ? RangeCountN : 32;
87+
#endif
8488

8589
for (size_t n = 0; n < RangeCountN; n += CountN) {
8690
CountN = std::min(RangeCountN - n, ComputeBlockCountN);
@@ -279,6 +283,8 @@ struct block {
279283
};
280284

281285
// control size
286+
static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<4,32> size/padding");
287+
static_assert(sizeof(block<8, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 32, "wrong block<8,32> size/padding");
282288
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
283289
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");
284290

@@ -296,24 +302,55 @@ static block_q4_0x16 make_block_q4_0x16(block_q4_0* in, unsigned int blck_size_i
296302
for (int i = 0; i < 16; i++) {
297303
// [0, 15], in.d & 0x0F
298304
for (int j = 0; j < QK4_0 / 4; j++) {
299-
// [b0 b16] ......... [b8 b24] ......... [b15 b31]
300-
// [b0 b8] ......... [b7 b15]
305+
//src [b0 b16] ......... [b8 b24] ......... [b15 b31]
306+
//dst [b0 b8] ......... [b7 b15]
301307
out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);
302308
}
303309
}
304310

305311
for (int i = 0; i < 16; i++) {
306312
// [16, 31], in.d & 0xF0
307313
for (int j = 0; j < QK4_0 / 4; j++) {
308-
// [b0 b16] ......... [b8 b24] ......... [b15 b31]
309-
// [b16 b24] ......... [b23 b31]
314+
//src [b0 b16] ......... [b8 b24] ......... [b15 b31]
315+
//dst [b16 b24] ......... [b23 b31]
310316
out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);
311317
}
312318
}
313319

314320
return out;
315321
}
316322

323+
using block_q4_0x32 = block<4, 32>;
324+
using block_q8_0x32 = block<8, 32>;
325+
/**
 * Pack 32 Q4_0 blocks (one per interleaved row) into a single block_q4_0x32.
 *
 * Per the layout notes of the original code, a source byte holds elements
 * [b_k | b_(k+16)] in its low/high nibble. For every row the destination gets
 * QK4_0/2 bytes:
 *   - first  QK4_0/4 bytes: [b0 b1] ... [b14 b15]   (low nibbles of adjacent source bytes)
 *   - second QK4_0/4 bytes: [b16 b17] ... [b30 b31] (high nibbles of adjacent source bytes)
 *
 * @param in                    pointer to 32 consecutive source blocks (rows 0..31)
 * @param blck_size_interleave  must satisfy QK4_0 / blck_size_interleave == 1 (checked via assert)
 * @return the packed block: 32 scales followed by the re-ordered quants
 */
static block_q4_0x32 make_block_q4_0x32(block_q4_0* in, unsigned int blck_size_interleave) {
    constexpr int kRows = 32;
    constexpr int kBytesPerRow = QK4_0 / 2;  // packed quant bytes emitted per row
    constexpr int kHalfRow = QK4_0 / 4;      // bytes in each of the low / high nibble halves

    assert(QK4_0 / blck_size_interleave == 1);

    block_q4_0x32 packed;

    for (int row = 0; row < kRows; ++row) {
        // Per-row scale is copied through unchanged.
        packed.d[row] = in[row].d;

        for (int k = 0; k < kHalfRow; ++k) {
            const auto even = in[row].qs[2 * k];
            const auto odd = in[row].qs[2 * k + 1];

            // Low nibbles of two adjacent source bytes -> one byte in the first half.
            packed.qs[row * kBytesPerRow + k] = (even & 0x0F) | ((odd & 0x0F) << 4);
            // High nibbles of the same two source bytes -> one byte in the second half.
            packed.qs[row * kBytesPerRow + kHalfRow + k] = ((even & 0xF0) >> 4) | (odd & 0xF0);
        }
    }

    return packed;
}
353+
317354
static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) {
318355
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
319356
GGML_ASSERT(interleave_block == 16);
@@ -346,6 +383,38 @@ static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor* t, int interleave_block
346383
GGML_UNUSED(data_size);
347384
}
348385

386+
static int repack_q4_0_to_q4_0_32_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) {
387+
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
388+
GGML_ASSERT(interleave_block == 32); // unused
389+
390+
constexpr int nrows_interleaved = 32;
391+
392+
block_q4_0x32* dst = (block_q4_0x32*)t->data;
393+
const block_q4_0* src = (const block_q4_0*)data;
394+
block_q4_0 dst_tmp[32];
395+
int nrow = ggml_nrows(t);
396+
int nblocks = t->ne[0] / QK4_0;
397+
398+
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
399+
400+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {
401+
return -1;
402+
}
403+
404+
for (int b = 0; b < nrow; b += nrows_interleaved) {
405+
for (int64_t x = 0; x < nblocks; x++) {
406+
for (int i = 0; i < nrows_interleaved; i++) {
407+
dst_tmp[i] = src[x + i * nblocks];
408+
}
409+
*dst++ = make_block_q4_0x32(dst_tmp, interleave_block);
410+
}
411+
src += nrows_interleaved * nblocks;
412+
}
413+
return 0;
414+
415+
GGML_UNUSED(data_size);
416+
}
417+
349418
namespace ggml::cpu::riscv64_spacemit {
350419

351420
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
@@ -355,6 +424,10 @@ template <>
355424
int repack<block_q4_0, 8, 16>(struct ggml_tensor* t, const void* data, size_t data_size) {
356425
return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
357426
}
427+
template <>
428+
int repack<block_q4_0, 16, 32>(struct ggml_tensor* t, const void* data, size_t data_size) {
429+
return repack_q4_0_to_q4_0_32_bl(t, 32, data, data_size);
430+
}
358431

359432
class tensor_traits_base : public ggml::cpu::tensor_traits {
360433
public:
@@ -707,15 +780,22 @@ class tensor_traits_common : public tensor_traits_base {
707780
};
708781

709782
static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;
783+
static const tensor_traits<block_q4_0, 16, 32> q4_0_32x16_q8_0;
710784
static const tensor_traits_common rvv_impl;
711785

712786
} // namespace ggml::cpu::riscv64_spacemit
713787

714788
static const ggml::cpu::tensor_traits* ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor* cur) {
715789
if (cur->type == GGML_TYPE_Q4_0) {
716-
if (cur->ne[1] % 16 == 0) {
717-
return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
718-
}
790+
#if defined(RISCV64_SPACEMIT_IME1)
791+
if (cur->ne[1] % 16 == 0) {
792+
return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
793+
}
794+
#elif defined(RISCV64_SPACEMIT_IME2)
795+
if (cur->ne[1] % 32 == 0) {
796+
return &ggml::cpu::riscv64_spacemit::q4_0_32x16_q8_0;
797+
}
798+
#endif
719799
} else if (cur->type == GGML_TYPE_F32) {
720800
return &ggml::cpu::riscv64_spacemit::rvv_impl;
721801
}

ggml/src/ggml-cpu/onnxruntime_mlas/CMakeLists.txt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
5959
set(ARM64 TRUE)
6060
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*")
6161
set(RISCV64 TRUE)
62-
set(RISCV64_SPACEMIT_IME_SPEC RISCV64_SPACEMIT_IME1)
62+
set(RISCV64_SPACEMIT_IME_SPEC RISCV64_SPACEMIT_IME2)
6363
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$")
6464
set(X86 TRUE)
6565
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
@@ -768,7 +768,7 @@ endif()
768768
${MLAS_SRC_DIR}/layernorm_rvv.cpp
769769
)
770770

771-
if (RISCV64_SPACEMIT_IME_SPEC)
771+
if (RISCV64_SPACEMIT_IME_SPEC STREQUAL "RISCV64_SPACEMIT_IME1")
772772
target_compile_definitions(onnxruntime_mlas PRIVATE ${RISCV64_SPACEMIT_IME_SPEC})
773773
set(mlas_platform_srcs
774774
${mlas_platform_srcs}
@@ -779,6 +779,17 @@ endif()
779779
)
780780
endif()
781781

782+
if (RISCV64_SPACEMIT_IME_SPEC STREQUAL "RISCV64_SPACEMIT_IME2")
783+
target_compile_definitions(onnxruntime_mlas PRIVATE ${RISCV64_SPACEMIT_IME_SPEC})
784+
set(mlas_platform_srcs
785+
${mlas_platform_srcs}
786+
${MLAS_SRC_DIR}/qgemm_kernel_spacemit_ime2.cpp
787+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_spacemit_ime2.cpp
788+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_spacemit_ime2_int8.cpp
789+
${MLAS_SRC_DIR}/sqnbitgemm_kernel_spacemit_ime_fp32.cpp
790+
)
791+
endif()
792+
782793
if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH)
783794
set(MLAS_SOURCE_IS_NOT_SET 0)
784795
endif()

0 commit comments

Comments
 (0)