-
Notifications
You must be signed in to change notification settings - Fork 12.4k
ggml : add ggml_scale_bias #14417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
ggml : add ggml_scale_bias #14417
Changes from 15 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
50f88fc
ggml : add ggml_scale_bias
ngxson 7af3fd9
Merge branch 'master' into xsn/ggml_scale_bias
ngxson a5ccf16
ggml_vec_mad1_f32
ngxson e427af7
add more simd
ngxson 92a8738
add CUDA
ngxson a28df6f
sycl
ngxson 782b58f
vulkan
ngxson 477a97a
cann (placeholder)
ngxson 0e51a0a
opencl
ngxson 4d01953
will this fix cpu?
ngxson b22708f
fix cuda
ngxson c8d8931
suggestions from coderabbit
ngxson 265cb43
fix cann compile error
ngxson 563aca0
vDSP_vsmsa
ngxson 50c678f
rm __ARM_FEATURE_SVE
ngxson 0d70ca8
use memcpy for op params
ngxson 4ea74b0
make code looks more consistent
ngxson cd1703a
use scalar for __ARM_FEATURE_SVE
ngxson ebbad77
add x param to ggml_vec_mad1_f32
ngxson File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,18 @@ | ||
#include "scale.cuh" | ||
|
||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { | ||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { | ||
const int i = blockDim.x*blockIdx.x + threadIdx.x; | ||
|
||
if (i >= k) { | ||
return; | ||
} | ||
|
||
dst[i] = scale * x[i]; | ||
dst[i] = scale * x[i] + bias; | ||
} | ||
|
||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { | ||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { | ||
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; | ||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k); | ||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k); | ||
} | ||
|
||
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
|
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | |
GGML_ASSERT( dst->type == GGML_TYPE_F32); | ||
|
||
float scale; | ||
float bias; | ||
memcpy(&scale, dst->op_params, sizeof(float)); | ||
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make the this more consistent: memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); |
||
|
||
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream); | ||
scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.