
Commit 4b3f11d

Merge branch 'master' of https://github.com/rotemdan/llama.cpp
2 parents: 701f0dc + defe215

38 files changed (+1062, -646 lines)

.github/workflows/build.yml (2 additions, 2 deletions)

@@ -683,7 +683,7 @@ jobs:
     env:
       OPENBLAS_VERSION: 0.3.23
       SDE_VERSION: 9.33.0-2024-01-07
-      VULKAN_VERSION: 1.4.309.0
+      VULKAN_VERSION: 1.4.313.2
 
     strategy:
       matrix:
@@ -736,7 +736,7 @@ jobs:
         id: get_vulkan
         if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
         run: |
-          curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
+          curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
           & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
           Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
           Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"

common/json-schema-to-grammar.cpp (3 additions, 46 deletions)

@@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
     return result;
 }
 
-/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
-class string_view {
-    const std::string & _str;
-    const size_t _start;
-    const size_t _end;
-public:
-    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
-
-    size_t size() const {
-        return _end - _start;
-    }
-
-    size_t length() const {
-        return size();
-    }
-
-    operator std::string() const {
-        return str();
-    }
-
-    std::string str() const {
-        return _str.substr(_start, _end - _start);
-    }
-
-    string_view substr(size_t pos, size_t len = std::string::npos) const {
-        return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
-    }
-
-    char operator[](size_t pos) const {
-        auto index = _start + pos;
-        if (index >= _end) {
-            throw std::out_of_range("string_view index out of range");
-        }
-        return _str[_start + pos];
-    }
-
-    bool operator==(const string_view & other) const {
-        std::string this_str = *this;
-        std::string other_str = other;
-        return this_str == other_str;
-    }
-};
-
 static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
     auto has_min = min_value != std::numeric_limits<int>::min();
     auto has_max = max_value != std::numeric_limits<int>::max();
@@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
            }
            out << "}";
        };
-       std::function<void(const string_view &, const string_view &)> uniform_range =
-           [&](const string_view & from, const string_view & to) {
+       std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
+           [&](const std::string_view & from, const std::string_view & to) {
                size_t i = 0;
                while (i < from.length() && i < to.length() && from[i] == to[i]) {
                    i++;
                }
                if (i > 0) {
-                   out << "\"" << from.substr(0, i).str() << "\"";
+                   out << "\"" << from.substr(0, i) << "\"";
                }
                if (i < from.length() && i < to.length()) {
                    if (i > 0) {
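
This change drops the homegrown string_view shim in favor of C++17 std::string_view, which streams directly into std::ostream (hence the removed .str() call). Below is a minimal standalone sketch (not part of the commit) of the common-prefix step that uniform_range relies on; the helper name common_prefix is illustrative only.

    #include <cstddef>
    #include <iostream>
    #include <string_view>

    // Returns the longest shared prefix of two views, mirroring the loop inside
    // uniform_range(). std::string_view::substr() returns another view, so no
    // temporary std::string is built and operator<< prints it directly.
    static std::string_view common_prefix(std::string_view from, std::string_view to) {
        std::size_t i = 0;
        while (i < from.length() && i < to.length() && from[i] == to[i]) {
            i++;
        }
        return from.substr(0, i);
    }

    int main() {
        std::cout << '"' << common_prefix("1234", "1289") << "\"\n";  // prints "12"
        return 0;
    }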

convert_hf_to_gguf.py (1 addition, 1 deletion)

@@ -2193,7 +2193,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             name += ".weight"
             if "multi_modal_projector.linear_1" in name:
                 # despite the name with number postfix, this is a single fully connected layer
-                return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)]
+                return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)]
             return [(self.map_tensor_name(name), data_torch)]
         return []
 

examples/simple-chat/simple-chat.cpp (1 addition, 1 deletion)

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
-        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
+        const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;
 
         // tokenize the prompt
        const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
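
The old check compared against 0, which is itself a valid token position. Judging by this fix, llama_memory_seq_pos_max() reports -1 for a sequence that holds no tokens yet, so -1 is the correct "empty history" sentinel. A small self-contained sketch of the distinction (seq_pos_max here is a toy stand-in, not the llama.cpp function):

    #include <cstdio>
    #include <vector>

    // Toy stand-in for the "highest stored position, or -1 when empty" behaviour
    // that the fixed check relies on.
    static int seq_pos_max(const std::vector<int> & positions) {
        return positions.empty() ? -1 : positions.back();
    }

    int main() {
        std::vector<int> history;                                    // nothing decoded yet
        std::printf("first turn? %d\n", seq_pos_max(history) == -1); // 1 (true)

        history.push_back(0);                                        // one token at position 0
        std::printf("first turn? %d\n", seq_pos_max(history) == -1); // 0 (false)
        // a comparison against 0 would still report "first turn" here,
        // wrongly re-applying first-turn handling on the second call
        return 0;
    }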

ggml/src/ggml-cuda/common.cuh (36 additions, 2 deletions)

@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+    if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return true;
+    } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+        return true;
+#else
+        return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
+    } else {
+        return false;
+    }
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
@@ -252,6 +262,10 @@ static bool fp16_mma_hardware_available(const int cc) {
         GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
+static bool bf16_mma_hardware_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
@@ -362,6 +376,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
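
The new reduce_rows_f32 kernel assigns one block (one warp in the mean launch below) per row: each thread strides over the columns, the partial sums are combined with warp_reduce_sum, and lane 0 writes either the sum or the mean. For reference, a plain CPU sketch of the same per-row contract (not part of the commit), usable as a mental model or test oracle:

    #include <cstdio>
    #include <vector>

    // CPU reference for reduce_rows_f32<norm>: per row, the sum of ncols floats,
    // divided by ncols when norm is true (i.e. the mean).
    static void reduce_rows_ref(const float * x, float * dst, int nrows, int ncols, bool norm) {
        for (int row = 0; row < nrows; ++row) {
            float sum = 0.0f;
            for (int col = 0; col < ncols; ++col) {
                sum += x[row * ncols + col];
            }
            dst[row] = norm ? sum / ncols : sum;
        }
    }

    int main() {
        const std::vector<float> x = {1, 2, 3, 4, 5, 6};   // 2 rows x 3 cols
        std::vector<float> dst(2);
        reduce_rows_ref(x.data(), dst.data(), 2, 3, /*norm=*/true);
        std::printf("%f %f\n", dst[0], dst[1]);            // 2.000000 5.000000
        return 0;
    }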

ggml/src/ggml-cuda/ggml-cuda.cu (18 additions, 18 deletions)

@@ -37,6 +37,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
@@ -99,8 +100,7 @@ int ggml_cuda_get_device() {
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
     cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
@@ -118,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
             err = cudaMalloc(ptr, size);
         }
 #endif // defined(GGML_USE_HIP)
-    }
-    else
-    {
+    } else {
         err = cudaMalloc(ptr, size);
     }
     return err;
@@ -1945,16 +1943,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
     bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
     bool use_mul_mat_q     = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
-    bool any_gpus_with_slow_fp16   = false;
-    bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_with_slow_fp16 = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1965,16 +1961,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            const int cc              = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q             = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            const int cc            = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q           = use_mul_mat_q   && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            use_mul_mat_vec         = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
     } else {
-        const int cc              = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q             = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q           = use_mul_mat_q   && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        use_mul_mat_vec         = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }
 
     // debug helpers
@@ -1985,7 +1981,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
@@ -2357,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_MEAN:
+            ggml_cuda_op_mean(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cuda_op_ssm_conv(ctx, dst);
             break;
@@ -3260,6 +3259,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;

ggml/src/ggml-cuda/mean.cu (new file, 19 additions)

@@ -0,0 +1,19 @@
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}

ggml/src/ggml-cuda/mean.cuh (new file, 3 additions)

@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
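
With GGML_OP_MEAN now accepted by ggml_backend_cuda_device_supports_op and dispatched to ggml_cuda_op_mean, graphs built with ggml_mean() can run on the CUDA backend instead of falling back to the CPU. A hedged sketch of how the op enters a graph through the public ggml API; the exact ggml.h signatures are assumed from the headers current around this commit, and backend scheduling/compute calls are intentionally omitted.

    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // 4 rows of 1024 floats; ggml_mean() reduces each row to a single value.
        struct ggml_tensor * x    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 4);
        struct ggml_tensor * mean = ggml_mean(ctx, x);   // result shape: [1, 4]

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, mean);
        // Any backend that reports support for GGML_OP_MEAN (now including CUDA)
        // can execute this graph.

        ggml_free(ctx);
        return 0;
    }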
