-
Notifications
You must be signed in to change notification settings - Fork 12.4k
ggml : add ggml_scale_bias #14417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ggml : add ggml_scale_bias #14417
Changes from 18 commits
50f88fc
7af3fd9
a5ccf16
e427af7
92a8738
a28df6f
782b58f
477a97a
0e51a0a
4d01953
b22708f
c8d8931
265cb43
563aca0
50c678f
0d70ca8
4ea74b0
cd1703a
ebbad77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int | |
#endif | ||
} | ||
|
||
inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, const float b) { | ||
#if defined(GGML_USE_ACCELERATE) | ||
vDSP_vsmsa(y, 1, &s, &b, y, 1, n); | ||
#elif defined(GGML_SIMD) | ||
#if defined(__ARM_FEATURE_SVE) | ||
// scalar ; TODO: Write SVE code | ||
for (int i = 0; i < n; ++i) { | ||
y[i] = y[i]*s + b; | ||
} | ||
#else | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's ok for now. I'm having some doubts about these SVE branches - might end up removing them all together. |
||
const int np = (n & ~(GGML_F32_STEP - 1)); | ||
|
||
GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); | ||
GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); | ||
|
||
GGML_F32_VEC ay[GGML_F32_ARR]; | ||
|
||
for (int i = 0; i < np; i += GGML_F32_STEP) { | ||
for (int j = 0; j < GGML_F32_ARR; j++) { | ||
ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); | ||
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); | ||
|
||
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); | ||
} | ||
} | ||
|
||
// leftovers | ||
for (int i = np; i < n; ++i) { | ||
y[i] = y[i]*s + b; | ||
} | ||
#endif | ||
#else | ||
// scalar | ||
for (int i = 0; i < n; ++i) { | ||
y[i] = y[i]*s + b; | ||
} | ||
#endif | ||
} | ||
|
||
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } | ||
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { | ||
#if defined(GGML_USE_ACCELERATE) | ||
|
Uh oh!
There was an error while loading. Please reload this page.