Commit 91e8a20

Merge branch 'concedo_experimental' into crokeso
2 parents: a1cbee1 + c4df151

30 files changed: +499, -244 lines

CMakeLists.txt

Lines changed: 6 additions & 3 deletions
@@ -170,9 +170,9 @@ if (LLAMA_CUBLAS)
     enable_language(CUDA)

     add_compile_definitions(GGML_USE_LLAMAFILE)
+    add_compile_definitions(GGML_USE_CUBLAS)
     add_compile_definitions(GGML_USE_CUDA)
-    add_compile_definitions(SD_USE_CUBLAS)
-
+
     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
         # Options are:
         # - none (not recommended)
@@ -182,6 +182,9 @@ if (LLAMA_CUBLAS)
         list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE})
     endif()

+    add_compile_definitions(SD_USE_CUBLAS)
+    add_compile_definitions(SD_USE_CUDA)
+
     add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
     add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
@@ -376,7 +379,7 @@ if (LLAMA_HIPBLAS)
     list(APPEND GGML_SOURCES_ROCM ${SRCS})
     file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_ROCM ${SRCS})
-    add_compile_definitions(GGML_USE_HIP GGML_USE_CUDA SD_USE_CUBLAS)
+    add_compile_definitions(GGML_USE_HIP GGML_USE_CUDA SD_USE_CUDA)
     add_library(ggml-rocm ${GGML_SOURCES_CUDA})
     if (LLAMA_CUDA_FORCE_DMMV)
         target_compile_definitions(ggml-rocm PUBLIC GGML_CUDA_FORCE_DMMV)

Makefile

Lines changed: 3 additions & 3 deletions
@@ -89,7 +89,7 @@ CLBLAST_FLAGS = -DGGML_USE_CLBLAST
 FAILSAFE_FLAGS = -DUSE_FAILSAFE
 VULKAN_FLAGS = -DGGML_USE_VULKAN -DSD_USE_VULKAN
 ifdef LLAMA_CUBLAS
-    CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUBLAS
+    CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUDA
 else
     CUBLAS_FLAGS =
 endif
@@ -228,7 +228,7 @@ else
 endif # LLAMA_CUDA_FA_ALL_QUANTS

 ifdef LLAMA_CUBLAS
-    CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
+    CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
     CUBLASLD_FLAGS = -lcuda -lcublas -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/local/cuda/targets/sbsa-linux/lib -L/usr/lib/wsl/lib
     CUBLAS_OBJS = ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
     CUBLAS_OBJS += $(patsubst %.cu,%.o,$(filter-out ggml/src/ggml-cuda/ggml-cuda.cu, $(wildcard ggml/src/ggml-cuda/*.cu)))
@@ -328,7 +328,7 @@ ifdef DETECT_ROCWMMA
     HIPFLAGS += -DGGML_HIP_ROCWMMA_FATTN -I$(dir $(DETECT_ROCWMMA))
 endif

-    HIPFLAGS += -DGGML_USE_HIP -DGGML_HIP_NO_VMM -DGGML_USE_CUDA -DSD_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
+    HIPFLAGS += -DGGML_USE_HIP -DGGML_HIP_NO_VMM -DGGML_USE_CUDA -DSD_USE_CUDA $(shell $(ROCM_PATH)/bin/hipconfig -C)
     HIPLDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
     HIPLDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64
     HIPLDFLAGS += -lhipblas -lamdhip64 -lrocblas

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -1680,7 +1680,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.warmup = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(

convert_hf_to_gguf.py

Lines changed: 2 additions & 2 deletions
@@ -2645,7 +2645,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)


-@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
 class Qwen2VLModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2VL

@@ -2669,7 +2669,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]


-@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
 class Qwen2VLVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

expose.h

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ struct load_model_inputs
     const int draft_quant_k = -1;
     const int draft_quant_v = -1;
     const bool check_slowness = false;
+    const bool swa_support = false;
     const bool quiet = false;
     const int debugmode = 0;
 };

ggml/include/ggml.h

Lines changed: 12 additions & 1 deletion
@@ -651,14 +651,15 @@ extern "C" {
         GGML_UNARY_OP_STEP,
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
-        GGML_UNARY_OP_RELU,
         GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_ERF,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
+        GGML_UNARY_OP_RELU,

         GGML_UNARY_OP_COUNT,
     };
@@ -1152,6 +1153,16 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // GELU using erf (error function) when possible
+    // some backends may fallback to approximation based on Abramowitz and Stegun formula
+    GGML_API struct ggml_tensor * ggml_gelu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_gelu_quick(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
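
For context, a minimal sketch of how the new operator might be exercised on the CPU backend. This is illustrative, not part of the commit: it assumes the usual ggml graph workflow, and ggml_graph_compute_with_ctx / the exact header layout vary between trees.

    #include "ggml.h"
    #include "ggml-cpu.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        float * p = (float *) a->data;
        p[0] = -2.0f; p[1] = -0.5f; p[2] = 0.5f; p[3] = 2.0f;

        // the exact-erf GELU introduced by this commit
        struct ggml_tensor * out = ggml_gelu_erf(ctx, a);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

        // out->data now holds 0.5*x*(1 + erf(x/sqrt(2))) per element
        ggml_free(ctx);
        return 0;
    }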

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 1 addition & 0 deletions
@@ -2579,6 +2579,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 } break;

             case GGML_UNARY_OP_GELU:
+            case GGML_UNARY_OP_GELU_ERF:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_SILU:
                 {

ggml/src/ggml-cpu/ops.cpp

Lines changed: 107 additions & 0 deletions
@@ -2824,6 +2824,109 @@ static void ggml_compute_forward_gelu(
     }
 }

+// ggml_compute_forward_gelu_erf
+
+static void ggml_compute_forward_gelu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_erf_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_erf_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_erf_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_gelu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_gelu_quick

 static void ggml_compute_forward_gelu_quick_f32(
@@ -8253,6 +8356,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_gelu(params, dst);
             } break;
+        case GGML_UNARY_OP_GELU_ERF:
+            {
+                ggml_compute_forward_gelu_erf(params, dst);
+            } break;
         case GGML_UNARY_OP_GELU_QUICK:
             {
                 ggml_compute_forward_gelu_quick(params, dst);
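
The per-thread work split in the new functions is a plain ceiling divide over rows. A standalone sketch of the same arithmetic, with made-up nr and nth values for illustration:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Reproduce the row ranges computed by ggml_compute_forward_gelu_erf_*:
    // dr rows per thread (rounded up), thread ith handles [ir0, ir1).
    int main(void) {
        const int nr  = 10; // total rows (hypothetical)
        const int nth = 4;  // threads   (hypothetical)
        const int dr  = (nr + nth - 1)/nth; // = 3
        for (int ith = 0; ith < nth; ith++) {
            const int ir0 = dr*ith;
            const int ir1 = MIN(ir0 + dr, nr);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0; // prints [0,3) [3,6) [6,9) [9,10)
    }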

ggml/src/ggml-cpu/vec.h

Lines changed: 16 additions & 0 deletions
@@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
 static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+static const float SQRT_2_INV      = 0.70710678118654752440084436210484f;

 inline static float ggml_gelu_f32(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static float ggml_gelu_f32(float x) {
     }
 }

+inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_FP16_TO_FP32(x[i]);
+        float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
+        y[i] = GGML_FP32_TO_FP16(res);
+    }
+}
+
 #ifdef GGML_GELU_FP16
 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     uint16_t t;
@@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif */

+inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
+    }
+}
+
 inline static float ggml_gelu_quick_f32(float x) {
     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
 }
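
For reference, the new helpers evaluate the exact GELU through the Gaussian CDF, whereas ggml_gelu_f32 above uses the tanh approximation; in the notation of the constants defined in this file:

    \mathrm{GELU}(x) = x\,\Phi(x)
                     = \frac{x}{2}\left(1 + \operatorname{erf}\left(x \cdot \texttt{SQRT\_2\_INV}\right)\right)
                     \approx \frac{x}{2}\left(1 + \tanh\left(\texttt{SQRT\_2\_OVER\_PI}\; x \left(1 + \texttt{GELU\_COEF\_A}\, x^{2}\right)\right)\right)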

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 37 additions & 0 deletions
@@ -874,6 +874,7 @@ kernel void kernel_tanh(
 constant float GELU_COEF_A     = 0.044715f;
 constant float GELU_QUICK_COEF = -1.702f;
 constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;

 kernel void kernel_gelu(
     device const float * src0,
@@ -915,6 +916,42 @@ kernel void kernel_gelu_quick_4(
     dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
 }

+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf  = 0.3275911f;
+constant float a1_erf =  0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf =  1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf =  1.061405429f;
+
+template<typename T>
+T erf_approx(T x) {
+    T sign_x = sign(x);
+    x = fabs(x);
+    T t = 1.0f / (1.0f + p_erf * x);
+    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    return sign_x * y;
+}
+
+kernel void kernel_gelu_erf(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
+}
+
+kernel void kernel_gelu_erf_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
+}
+
 kernel void kernel_silu(
     device const float * src0,
     device       float * dst,
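
Metal has no built-in erf, hence the polynomial above; the published error bound for Abramowitz and Stegun 7.1.26 is about 1.5e-7 in exact arithmetic. A quick host-side sanity check of the same polynomial against libm's erff (a hypothetical harness for illustration, not part of the commit; single-precision rounding adds some noise on top of the formula's own error):

    #include <math.h>
    #include <stdio.h>

    // Same Hastings / A&S 7.1.26 polynomial as erf_approx in the Metal kernel.
    static float erf_approx(float x) {
        const float p  = 0.3275911f;
        const float a1 =  0.254829592f, a2 = -0.284496736f, a3 = 1.421413741f;
        const float a4 = -1.453152027f, a5 =  1.061405429f;
        const float s  = x < 0.0f ? -1.0f : 1.0f;
        x = fabsf(x);
        const float t = 1.0f / (1.0f + p * x);
        const float y = 1.0f - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * expf(-x * x);
        return s * y;
    }

    int main(void) {
        float max_err = 0.0f;
        for (float x = -6.0f; x <= 6.0f; x += 1e-3f) {
            const float e = fabsf(erf_approx(x) - erff(x));
            if (e > max_err) max_err = e;
        }
        printf("max abs error vs erff: %g\n", max_err);
        return 0;
    }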
