Skip to content

ggml : add ggml_scale_bias #14417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,19 @@ extern "C" {
struct ggml_tensor * a,
float s);

// x = s * a + b
GGML_API struct ggml_tensor * ggml_scale_bias(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b);

GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b);

// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
struct ggml_context * ctx,
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_CLAMP:
Expand All @@ -2210,6 +2209,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
Expand Down
28 changes: 20 additions & 8 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4643,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));

// scale factor
float v;
memcpy(&v, dst->op_params, sizeof(float));
float s; // scale factor
float b; // bias

memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
memcpy(&b, (float *) dst->op_params + 1, sizeof(float));

const int ith = params->ith;
const int nth = params->nth;
Expand All @@ -4664,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(

const size_t nb1 = dst->nb[1];

for (int i1 = ir0; i1 < ir1; i1++) {
if (dst->data != src0->data) {
// src0 is same shape as dst => same indices
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
if (b == 0.0f) {
for (int i1 = ir0; i1 < ir1; i1++) {
if (dst->data != src0->data) {
// src0 is same shape as dst => same indices
// TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
}
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
}
} else {
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_mad1_f32(nc,
(float *) ((char *) dst->data + i1*nb1),
(float *) ((char *) src0->data + i1*nb1),
s, b);
}
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
}
}

Expand Down
39 changes: 39 additions & 0 deletions ggml/src/ggml-cpu/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
#endif
}

inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
// scalar ; TODO: Write SVE code
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));

GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);

GGML_F32_VEC ay[GGML_F32_ARR];

for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);

GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
}
}

// leftovers
for (int i = np; i < n; ++i) {
y[i] = x[i]*s + b;
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
#endif
}

//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)
Expand Down
14 changes: 8 additions & 6 deletions ggml/src/ggml-cuda/scale.cu
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#include "scale.cuh"

static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

dst[i] = scale * x[i];
dst[i] = scale * x[i] + bias;
}

static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
}

void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
Expand All @@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT( dst->type == GGML_TYPE_F32);

float scale;
memcpy(&scale, dst->op_params, sizeof(float));
float bias;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));

scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
}
5 changes: 4 additions & 1 deletion ggml/src/ggml-metal/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -2256,7 +2256,9 @@ static bool ggml_metal_encode_node(
GGML_ASSERT(ggml_is_contiguous(src0));

float scale;
memcpy(&scale, dst->op_params, sizeof(scale));
float bias;
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));

int64_t n = ggml_nelements(dst);

Expand All @@ -2273,6 +2275,7 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
Expand Down
6 changes: 4 additions & 2 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1014,16 +1014,18 @@ kernel void kernel_scale(
device const float * src0,
device float * dst,
constant float & scale,
constant float & bias,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
dst[tpig] = src0[tpig] * scale + bias;
}

kernel void kernel_scale_4(
device const float4 * src0,
device float4 * dst,
constant float & scale,
constant float & bias,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
dst[tpig] = src0[tpig] * scale + bias;
}

kernel void kernel_clamp(
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5587,7 +5587,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

float scale;
memcpy(&scale, dst->op_params, sizeof(scale));
float bias;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float));

ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
Expand All @@ -5602,6 +5604,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));

int n = ggml_nelements(dst)/4;

Expand Down
5 changes: 3 additions & 2 deletions ggml/src/ggml-opencl/kernels/scale.cl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ kernel void kernel_scale(
ulong offset0,
global float4 * dst,
ulong offsetd,
float scale
float scale,
float bias
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
}
14 changes: 8 additions & 6 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}

static void scale_f32(const float * x, float * dst, const float scale, const int k,
static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
Expand All @@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
return;
}

dst[i] = scale * x[i];
dst[i] = scale * x[i] + bias;
}


Expand Down Expand Up @@ -1842,15 +1842,15 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(



static void scale_f32_sycl(const float *x, float *dst, const float scale,
static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
const int k, queue_ptr stream) {
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
scale_f32(x, dst, scale, k, item_ct1);
scale_f32(x, dst, scale, bias, k, item_ct1);
});
}

Expand Down Expand Up @@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
float * dst_dd = static_cast<float *>(dst->data);

float scale;
memcpy(&scale, dst->op_params, sizeof(float));
float bias;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));

scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
/*
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this code.
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7508,7 +7508,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], 0.0f,
op_params[0], op_params[1],
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
}
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/scale.comp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void main() {
continue;
}

data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
idx += num_threads;
}
}
28 changes: 23 additions & 5 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -3069,12 +3069,14 @@ static struct ggml_tensor * ggml_scale_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b,
bool inplace) {
GGML_ASSERT(ggml_is_padded_1d(a));

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

ggml_set_op_params(result, &s, sizeof(s));
float params[2] = { s, b };
ggml_set_op_params(result, &params, sizeof(params));

result->op = GGML_OP_SCALE;
result->src[0] = a;
Expand All @@ -3086,14 +3088,30 @@ struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s) {
return ggml_scale_impl(ctx, a, s, false);
return ggml_scale_impl(ctx, a, s, 0.0, false);
}

struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s) {
return ggml_scale_impl(ctx, a, s, true);
return ggml_scale_impl(ctx, a, s, 0.0, true);
}

struct ggml_tensor * ggml_scale_bias(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b) {
return ggml_scale_impl(ctx, a, s, b, false);
}

struct ggml_tensor * ggml_scale_bias_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b) {
return ggml_scale_impl(ctx, a, s, b, true);
}

// ggml_set
Expand Down Expand Up @@ -5777,7 +5795,7 @@ static void ggml_compute_backward(
} break;
case GGML_OP_MEAN: {
if (src0_needs_grads) {
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
}
} break;
case GGML_OP_REPEAT: {
Expand Down Expand Up @@ -5854,7 +5872,7 @@ static void ggml_compute_backward(
if (src0_needs_grads) {
float s;
memcpy(&s, tensor->op_params, sizeof(float));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
}
} break;
case GGML_OP_SET: {
Expand Down
Loading
Loading