Skip to content

Commit 3e303b1

Browse files
Aclyggerganov
authored andcommitted
vulkan : implement ggml_roll (ggml/1290)
ggml-ci
1 parent 0c1df14 commit 3e303b1

File tree

4 files changed

+154
-95
lines changed

4 files changed

+154
-95
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 79 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ struct vk_device_struct {
432432
vk_pipeline pipeline_cos_f32;
433433
vk_pipeline pipeline_clamp_f32;
434434
vk_pipeline pipeline_pad_f32;
435+
vk_pipeline pipeline_roll_f32;
435436
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
436437
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
437438
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
@@ -694,6 +695,37 @@ struct vk_op_unary_push_constants {
694695
};
695696
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
696697

698+
static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
699+
GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
700+
ne = ne != 0 ? ne : ggml_nelements(dst);
701+
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
702+
703+
vk_op_unary_push_constants p{};
704+
p.ne = (uint32_t)ne;
705+
706+
size_t src0_tsize = ggml_type_size(src0->type);
707+
p.ne00 = (uint32_t)src0->ne[0];
708+
p.ne01 = (uint32_t)src0->ne[1];
709+
p.ne02 = (uint32_t)src0->ne[2];
710+
p.ne03 = (uint32_t)src0->ne[3];
711+
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
712+
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
713+
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
714+
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
715+
716+
size_t dst_tsize = ggml_type_size(dst->type);
717+
p.ne10 = (uint32_t)dst->ne[0];
718+
p.ne11 = (uint32_t)dst->ne[1];
719+
p.ne12 = (uint32_t)dst->ne[2];
720+
p.ne13 = (uint32_t)dst->ne[3];
721+
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
722+
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
723+
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
724+
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
725+
726+
return p; // fastdiv values and offsets are initialized later in ggml_vk_op
727+
}
728+
697729
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
698730
// Precompute mp (m' in the paper) and L such that division
699731
// can be computed using a multiply (high 32b of 64b result)
@@ -2836,6 +2868,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
28362868

28372869
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28382870

2871+
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2872+
28392873
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28402874
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28412875

@@ -6536,6 +6570,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65366570
return ctx->device->pipeline_pad_f32;
65376571
}
65386572
return nullptr;
6573+
case GGML_OP_ROLL:
6574+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6575+
return ctx->device->pipeline_roll_f32;
6576+
}
6577+
return nullptr;
65396578
case GGML_OP_REPEAT:
65406579
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
65416580
return ctx->device->pipeline_repeat_f32;
@@ -7085,6 +7124,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
70857124
case GGML_OP_COS:
70867125
case GGML_OP_CLAMP:
70877126
case GGML_OP_PAD:
7127+
case GGML_OP_ROLL:
70887128
case GGML_OP_REPEAT:
70897129
case GGML_OP_REPEAT_BACK:
70907130
case GGML_OP_CPY:
@@ -7561,117 +7601,61 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
75617601
}
75627602

75637603
static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7564-
float * op_params = (float *)dst->op_params;
7565-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7566-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7604+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7605+
p.param1 = ggml_get_op_params_f32(dst, 0);
7606+
p.param2 = ggml_get_op_params_f32(dst, 1);
75677607

7568-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
7569-
(uint32_t)ggml_nelements(src0),
7570-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7571-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7572-
0,
7573-
op_params[0], op_params[1],
7574-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7575-
}, dryrun);
7608+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun);
75767609
}
75777610

75787611
static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7579-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7580-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7581-
7582-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
7583-
(uint32_t)ggml_nelements(src0),
7584-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7585-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7586-
0,
7587-
0.0f, 0.0f,
7588-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7589-
}, dryrun);
7612+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
75907613
}
75917614

75927615
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7593-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7594-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7595-
7596-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
7597-
(uint32_t)ggml_nelements(src0),
7598-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7599-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7600-
0,
7601-
0.0f, 0.0f,
7602-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7603-
}, dryrun);
7616+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
76047617
}
76057618

76067619
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7607-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7608-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7609-
7610-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
7611-
(uint32_t)ggml_nelements(src0),
7612-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7613-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7614-
0,
7615-
0.0f, 0.0f,
7616-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7617-
}, dryrun);
7620+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun);
76187621
}
76197622

76207623
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7621-
float * op_params = (float *)dst->op_params;
7622-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7623-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7624+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7625+
p.param1 = ggml_get_op_params_f32(dst, 0);
7626+
p.param2 = ggml_get_op_params_f32(dst, 1);
76247627

7625-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
7626-
(uint32_t)ggml_nelements(src0),
7627-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7628-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7629-
0,
7630-
op_params[0], op_params[1],
7631-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7632-
}, dryrun);
7628+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun);
76337629
}
76347630

76357631
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7636-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7637-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7632+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7633+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
7634+
}
76387635

7639-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, {
7640-
(uint32_t)ggml_nelements(dst),
7641-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7642-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7643-
0,
7644-
0.0f, 0.0f,
7645-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7646-
}, dryrun);
7636+
static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7637+
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
7638+
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
7639+
const int32_t s2 = ggml_get_op_params_i32(dst, 2);
7640+
const int32_t s3 = ggml_get_op_params_i32(dst, 3);
7641+
const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
7642+
const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
7643+
7644+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
7645+
memcpy(&p.param1, &s01_packed, sizeof(float));
7646+
memcpy(&p.param2, &s23_packed, sizeof(float));
7647+
7648+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun);
76477649
}
76487650

76497651
static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7650-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7651-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7652-
7653-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, {
7654-
(uint32_t)ggml_nelements(dst),
7655-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7656-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7657-
0,
7658-
0.0f, 0.0f,
7659-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7660-
}, dryrun);
7652+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7653+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun);
76617654
}
76627655

76637656
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7664-
const uint32_t src0_type_size = ggml_type_size(src0->type);
7665-
const uint32_t dst_type_size = ggml_type_size(dst->type);
7666-
7667-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
7668-
(uint32_t)ggml_nelements(dst),
7669-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7670-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7671-
0,
7672-
0.0f, 0.0f,
7673-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7674-
}, dryrun);
7657+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
7658+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun);
76757659
}
76767660

76777661
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -7689,14 +7673,8 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
76897673
}
76907674
}
76917675

7692-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7693-
ne,
7694-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7695-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7696-
0,
7697-
0.0f, 0.0f,
7698-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7699-
}, dryrun);
7676+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
7677+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun);
77007678
}
77017679

77027680
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9033,6 +9011,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90339011
case GGML_OP_COS:
90349012
case GGML_OP_CLAMP:
90359013
case GGML_OP_PAD:
9014+
case GGML_OP_ROLL:
90369015
case GGML_OP_CPY:
90379016
case GGML_OP_SET_ROWS:
90389017
case GGML_OP_CONT:
@@ -9204,6 +9183,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
92049183
case GGML_OP_PAD:
92059184
ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun);
92069185

9186+
break;
9187+
case GGML_OP_ROLL:
9188+
ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun);
9189+
92079190
break;
92089191
case GGML_OP_CPY:
92099192
case GGML_OP_CONT:
@@ -9428,6 +9411,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
94289411
case GGML_OP_COS:
94299412
case GGML_OP_CLAMP:
94309413
case GGML_OP_PAD:
9414+
case GGML_OP_ROLL:
94319415
case GGML_OP_CPY:
94329416
case GGML_OP_SET_ROWS:
94339417
case GGML_OP_CONT:
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_unary_head.comp"
5+
6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
8+
uint wrap_idx(int i, uint ne) {
9+
if (i < 0) {
10+
return i + ne;
11+
} else if (i >= ne) {
12+
return i - ne;
13+
}
14+
return i;
15+
}
16+
17+
void main() {
18+
const uint idx = get_idx();
19+
if (idx >= p.ne) {
20+
return;
21+
}
22+
23+
const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
24+
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
25+
const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
26+
const uint i2_offset = i2*p.ne11*p.ne10;
27+
const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
28+
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
29+
30+
const uint p1 = floatBitsToUint(p.param1);
31+
const uint p2 = floatBitsToUint(p.param2);
32+
const int s0 = int(p1 >> 16) - 0x8000;
33+
const int s1 = int(p1 & 0xFFFF) - 0x8000;
34+
const int s2 = int(p2 >> 16) - 0x8000;
35+
const int s3 = int(p2 & 0xFFFF) - 0x8000;
36+
37+
const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
38+
const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
39+
const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
40+
const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
41+
42+
const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
43+
const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
44+
45+
data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
46+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,8 @@ void process_shaders() {
653653
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
654654
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
655655

656+
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
657+
656658
for (auto &c : compiles) {
657659
c.wait();
658660
}

0 commit comments

Comments
 (0)