Commit 631b115

Add ROCm HIPBLAS compatibility
commit ddcba42
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jul 1 17:13:39 2023 -0500
    realign text

commit a7f8197
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jul 1 17:11:10 2023 -0500
    small edits

commit d4c7ec0
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jul 1 17:07:42 2023 -0500
    move hipblas definitions to header files

commit 39af81c
Merge: 5e2f581 bf49a93
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jul 1 16:59:12 2023 -0500
    Merge branch 'main' into pr/LostRuins/koboldcpp/rocm-patch

commit bf49a93
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jul 1 16:38:50 2023 -0500
    move HIPBLAS definitions into ggml-cuda.h

commit 5e2f581
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Wed Jun 21 17:10:08 2023 -0500
    Replace readme.md with the version from LostRuins/koboldcpp

commit 540f4e0
Merge: 2c3b46f eda663f
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jul 1 14:58:32 2023 -0500
    Merge remote-tracking branch 'upstream/concedo'

commit 2c3b46f
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Thu Jun 29 18:43:43 2023 -0500
    changes to fix build

commit c9e1103
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Thu Jun 29 18:20:07 2023 -0500
    Update ggml_v2-cuda-legacy.cu for ROCM

commit b858fc5
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Thu Jun 29 17:49:39 2023 -0500
    changes to work with upstream

commit 69a0c25
Merge: 096f0b0 1347d3a
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Thu Jun 29 16:59:06 2023 -0500
    Merge remote-tracking branch 'upstream/concedo'

commit 096f0b0
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Wed Jun 28 15:27:02 2023 -0500
    revert unnecessary hipblas conditionals

commit d81e81a
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Wed Jun 28 14:48:23 2023 -0500
    Update Makefile hipblas nvcc correction

commit 2579ecf
Merge: abed427 d2034ce
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sun Jun 25 17:50:04 2023 -0500
    Merge branch 'LostRuins:concedo' into main

commit abed427
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jun 24 19:16:30 2023 -0500
    reorganize If statements to include proper headers

commit 06c3bf0
Merge: ea6d320 8342fe8
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sat Jun 24 16:57:20 2023 -0500
    Merge branch 'LostRuins:concedo' into main

commit ea6d320
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Fri Jun 23 01:53:28 2023 -0500
    Update README.md

commit 4d56ad8
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Thu Jun 22 16:19:43 2023 -0500
    Update README.md

commit 21f9308
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Thu Jun 22 15:42:05 2023 -0500
    kquants_iter for hipblas and add gfx803

commit b6ff890
Merge: eb094f0 e6ddb15
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Thu Jun 22 12:42:09 2023 -0500
    Merge branch 'LostRuins:concedo' into main

commit eb094f0
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Wed Jun 21 23:59:18 2023 -0500
    lowvram parameter description

commit 3a5dfeb
Merge: 665cc11 b1f00fa
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Wed Jun 21 16:53:03 2023 -0500
    Merge branch 'LostRuins:concedo' into koboldcpp-rocm

commit 665cc11
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Wed Jun 21 01:13:19 2023 -0500
    add lowvram parameter

commit 222cbbb
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Tue Jun 20 19:03:28 2023 -0500
    add additional hipblas conditions for cublas

commit e1f9581
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Tue Jun 20 16:51:59 2023 -0500
    Add hip def for cuda v2

commit 3bff5c0
Merge: a7e74b3 266d47a
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Tue Jun 20 13:38:06 2023 -0500
    Merge branch 'LostRuins:concedo' into koboldcpp-rocm

commit a7e74b3
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Mon Jun 19 22:04:18 2023 -0500
    Update README.md

commit 5e99b3c
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Mon Jun 19 22:03:42 2023 -0500
    Update Makefile

commit 9190b17
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Mon Jun 19 21:47:10 2023 -0500
    Update README.md

commit 2780ea2
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sun Jun 18 15:48:00 2023 -0500
    Update Makefile

commit 04a3e64
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sun Jun 18 14:33:39 2023 -0500
    remove extra line

commit cccbca9
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sun Jun 18 14:31:17 2023 -0500
    attempt adding ROCM hipblas

commit a44a1d4
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sun Jun 18 14:31:01 2023 -0500
    attempt adding ROCM hipblas

commit b088184
Author: YellowRoseCx <80486540+YellowRoseCx@users.noreply.github.com>
Date: Sun Jun 18 14:30:54 2023 -0500
    attempt adding ROCM hipblas

Original llama.cpp changes started with this branch https://github.com/SlyEcho/llama.cpp/tree/hipblas then were modified
1 parent ef3b8dc commit 631b115

15 files changed, 254 additions (+), 11 deletions (-)

CMakeLists.txt

Lines changed: 33 additions & 0 deletions
@@ -46,6 +46,7 @@ set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kern
 set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
 option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
 option(LLAMA_K_QUANTS "llama: use k-quants" ON)
 
 
@@ -94,6 +95,38 @@ if (LLAMA_CUBLAS)
     endif()
 endif()
 
+if (LLAMA_HIPBLAS)
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+    if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+    endif()
+    if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    find_package(hip)
+    find_package(hipblas)
+
+    if (${hipblas_FOUND} AND ${hip_FOUND})
+        message(STATUS "HIP and hipBLAS found")
+        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
+        target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::hipblas)
+
+        if (LLAMA_STATIC)
+            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+        endif()
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
+    else()
+        message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
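
The three `target_compile_definitions` calls above only pass values on the compiler command line; the source file still has to consume them. The following is a minimal sketch of how such definitions are typically picked up, not the actual contents of `ggml-cuda.cu`. The fallback values mirror the CMake cache defaults visible in the diff (32, 1 and 2). The `KernelTuning` struct, the `kernel_tuning()` helper, and the 1-or-2 constraint in the `static_assert` are assumptions made for illustration.

```cpp
#include <cassert>

// Defaults used when the build system passes no -D flag. The values mirror
// the CMake cache defaults visible in the diff (32, 1 and 2).
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
#ifndef GGML_CUDA_DMMV_Y
#define GGML_CUDA_DMMV_Y 1
#endif
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#endif

// Assumption for this sketch: only 1 or 2 iterations per thread are valid.
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2,
              "K_QUANTS_PER_ITERATION must be 1 or 2");

// Hypothetical holder for the resolved tuning values.
struct KernelTuning {
    int dmmv_x;
    int dmmv_y;
    int kquants_iter;
};

// Returns whatever the preprocessor resolved at compile time.
constexpr KernelTuning kernel_tuning() {
    return {GGML_CUDA_DMMV_X, GGML_CUDA_DMMV_Y, K_QUANTS_PER_ITERATION};
}
```

Compiled without any `-D` flags, `kernel_tuning()` reports the fallback values; a build that passes `-DGGML_CUDA_DMMV_X=64` would report 64 for `dmmv_x` instead.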

Makefile

Lines changed: 33 additions & 0 deletions
@@ -171,6 +171,39 @@ ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-l
 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) $(CUBLAS_FLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
 
+ifdef LLAMA_HIPBLAS
+ROCM_PATH ?= /opt/rocm
+CC := $(ROCM_PATH)/llvm/bin/clang
+CXX := $(ROCM_PATH)/llvm/bin/clang++
+GPU_TARGETS = gfx803 gfx900 gfx906 gfx908 gfx90a gfx1030
+LLAMA_CUDA_DMMV_X ?= 64
+LLAMA_CUDA_DMMV_Y ?= 2
+CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
+CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
+LDFLAGS += -L/opt/rocm/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64
+OBJS += ggml-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
+
+ifdef LLAMA_CUDA_KQUANTS_ITER
+CXXFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
+else
+CXXFLAGS += -DK_QUANTS_PER_ITERATION=2
+endif
+
+ggml-cuda.o: CXXFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \
+	-DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) \
+	-DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
+# DGGML_CUDA_DMMV_F16 does not currently work with AMD.
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+	$(CXX) $(CXXFLAGS) -x hip -c -o $@ $<
+
+ggml_v2-cuda.o: otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h
+	$(CXX) $(CXXFLAGS) -x hip -c -o $@ $<
+
+ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h
+	$(CXX) $(CXXFLAGS) -x hip -c -o $@ $<
+endif # LLAMA_HIPBLAS
+
+
 ifdef LLAMA_METAL
 CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 CXXFLAGS += -DGGML_USE_METAL
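
The `$(addprefix --offload-arch=,$(GPU_TARGETS))` expression in the hunk above turns the space-separated target list into one compiler flag per GPU architecture. As a rough illustration of that expansion, here is a small C++ emulation; `add_prefix` is a hypothetical helper written for this note and is not part of the commit or of GNU make.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical helper mimicking GNU make's $(addprefix PREFIX,NAMES...):
// splits NAMES on whitespace and prepends PREFIX to every word.
std::string add_prefix(const std::string& prefix, const std::string& names) {
    std::istringstream in(names);
    std::string word;
    std::string out;
    while (in >> word) {
        if (!out.empty()) {
            out += ' ';  // words are re-joined with single spaces
        }
        out += prefix + word;
    }
    return out;
}
```

Calling `add_prefix("--offload-arch=", "gfx803 gfx900")` yields `--offload-arch=gfx803 --offload-arch=gfx900`, which is the shape of the flags appended to `CXXFLAGS` for `ggml-cuda.o`.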

README.md

Lines changed: 1 addition & 1 deletion
@@ -71,4 +71,4 @@ For more information, be sure to run the program with the `--help` flag.
 - RWKV (all formats except Q4_1_O).
 - GPT-NeoX / Pythia / StableLM / Dolly / RedPajama
 - MPT models (ggjt v3)
-- Basically every single current and historical GGML format that has ever existed should be supported, except for bloomz.cpp due to lack of demand.
+- Basically every single current and historical GGML format that has ever existed should be supported, except for bloomz.cpp due to lack of demand.

ggml-cuda.cu

Lines changed: 2 additions & 1 deletion
@@ -6,10 +6,11 @@
 #include <atomic>
 #include <assert.h>
 
+#ifndef GGML_USE_HIPBLAS
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
-
+#endif
 #include "ggml-cuda.h"
 #include "ggml.h"

ggml-cuda.h

Lines changed: 55 additions & 0 deletions
@@ -1,6 +1,61 @@
 #pragma once
 
 #include "ggml.h"
+#if defined(GGML_USE_HIPBLAS)
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEvent_t hipEvent_t
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent hipStreamWaitEvent
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#endif
 
 #ifdef __cplusplus
 extern "C" {
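
The header above relies on three macro shapes: plain renames (`cudaMalloc` to `hipMalloc`), function-like macros that add an argument (`cudaMallocHost`), and function-like macros that discard the call entirely (`cublasSetMathMode`). The sketch below reproduces those shapes against a mock runtime so it compiles without ROCm installed. Every `mock*` name is a hypothetical stand-in and not the real HIP API; only the macro patterns are taken from the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// --- Mock runtime (hypothetical stand-ins, NOT the real HIP API) ---
enum mockError_t { mockSuccess = 0, mockErrorOutOfMemory = 2 };
constexpr unsigned mockHostMallocDefault = 0;

inline mockError_t mockMalloc(void** ptr, std::size_t size) {
    *ptr = std::malloc(size);
    return *ptr ? mockSuccess : mockErrorOutOfMemory;
}
inline mockError_t mockHostMalloc(void** ptr, std::size_t size, unsigned /*flags*/) {
    return mockMalloc(ptr, size);
}
inline mockError_t mockFree(void* ptr)     { std::free(ptr); return mockSuccess; }
inline mockError_t mockHostFree(void* ptr) { std::free(ptr); return mockSuccess; }

// --- Compatibility shim using the same macro shapes as the header ---
#define cudaError_t  mockError_t            // plain rename of a type
#define cudaSuccess  mockSuccess            // plain rename of a constant
#define cudaMalloc   mockMalloc             // plain rename of a function
#define cudaFree     mockFree
#define cudaFreeHost mockHostFree
// Function-like macro: the target needs a third argument the caller never passes.
#define cudaMallocHost(ptr, size) mockHostMalloc(ptr, size, mockHostMallocDefault)
// Function-like macro: the call is replaced by a constant, arguments are dropped.
#define CUBLAS_STATUS_SUCCESS 0
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS

// Caller code written purely against CUDA-style names.
inline bool cuda_style_roundtrip(std::size_t bytes) {
    void* dev = nullptr;
    void* host = nullptr;
    cudaError_t err = cudaMalloc(&dev, bytes);
    if (err != cudaSuccess) return false;
    err = cudaMallocHost(&host, bytes);
    if (err != cudaSuccess) { cudaFree(dev); return false; }
    const int status = cublasSetMathMode(nullptr, 0);  // expands to the constant 0
    cudaFreeHost(host);
    cudaFree(dev);
    return status == CUBLAS_STATUS_SUCCESS;
}
```

One consequence worth keeping in mind with the no-op shape: because `cublasSetMathMode(handle, mode)` expands to a constant, its arguments are never evaluated, so any side effects in them would silently disappear.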

ggml.c

Lines changed: 5 additions & 3 deletions
@@ -230,9 +230,11 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #endif
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
-#elif defined(GGML_USE_CUBLAS)
+#endif
+#if defined(GGML_USE_CUBLAS)
 #include "ggml-cuda.h"
-#elif defined(GGML_USE_CLBLAST)
+#endif
+#if defined(GGML_USE_CLBLAST)
 #include "ggml-opencl.h"
 #endif
 
@@ -19191,4 +19193,4 @@ int ggml_cpu_has_vsx(void) {
 #endif
 }
 
-////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////
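
The first `ggml.c` hunk replaces one `#if`/`#elif` chain with independent `#if` blocks. The commit message does not state the motivation; one plausible reading is that a GPU backend header should still be included when an earlier branch of the chain (for example a CPU BLAS) is also enabled. The sketch below shows the behavioural difference with hypothetical `DEMO_*` macros; none of these names appear in the source.

```cpp
#include <cassert>

// Hypothetical feature flags: both backends enabled in the same build.
#define DEMO_USE_OPENBLAS 1
#define DEMO_USE_CUBLAS   1

// Old shape: a single chain. Only the first matching branch is compiled,
// so the CUBLAS branch is skipped whenever OPENBLAS is also defined.
#if defined(DEMO_USE_OPENBLAS)
#define CHAIN_SAW_OPENBLAS 1
#elif defined(DEMO_USE_CUBLAS)
#define CHAIN_SAW_CUBLAS 1
#endif

// New shape: independent blocks. Each flag is tested on its own.
#if defined(DEMO_USE_OPENBLAS)
#define SPLIT_SAW_OPENBLAS 1
#endif
#if defined(DEMO_USE_CUBLAS)
#define SPLIT_SAW_CUBLAS 1
#endif

// Report which branches the preprocessor actually kept.
constexpr bool chain_includes_cublas() {
#ifdef CHAIN_SAW_CUBLAS
    return true;
#else
    return false;
#endif
}

constexpr bool split_includes_cublas() {
#ifdef SPLIT_SAW_CUBLAS
    return true;
#else
    return false;
#endif
}

static_assert(!chain_includes_cublas(), "the #elif chain stops at the first match");
static_assert(split_includes_cublas(), "independent #if blocks are all evaluated");
```

With both flags defined, the chained form never reaches the CUBLAS branch, while the split form does.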

koboldcpp.py

Lines changed: 1 addition & 1 deletion
@@ -657,7 +657,7 @@ def onDropdownChange(event):
 #load all the vars
 args.threads = int(threads_var.get())
 args.gpulayers = int(gpu_layers_var.get())
-
+
 args.stream = (stream.get()==1)
 args.smartcontext = (smartcontext.get()==1)
 args.launch = (launchbrowser.get()==1)

otherarch/ggml_v2-cuda-legacy.cu

Lines changed: 2 additions & 0 deletions
@@ -4,9 +4,11 @@
 #include <stdio.h>
 #include <atomic>
 
+#ifndef GGML_USE_HIPBLAS
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#endif
 
 #include "ggml_v2-cuda-legacy.h"
 #include "ggml_v2-cuda.h"

otherarch/ggml_v2-cuda-legacy.h

Lines changed: 56 additions & 0 deletions
@@ -1,5 +1,61 @@
 #include "ggml_v2.h"
 
+#if defined(GGML_USE_HIPBLAS)
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEvent_t hipEvent_t
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent hipStreamWaitEvent
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
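
The `__shfl_xor_sync(mask, var, laneMask, width)` macro in this header drops the participation mask and forwards to `__shfl_xor`. The sketch below is a CPU-only emulation of an XOR butterfly reduction, meant to show why discarding the mask leaves the result unchanged when every lane takes part. It assumes the kernels pass a full mask and a width of 32, which is not shown in this diff. All names here (`Warp`, `emulated_shfl_xor`, `warp_reduce_sum`) are hypothetical, and no GPU intrinsics are used.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kWarpSize = 32;  // assumed logical width for this sketch
using Warp = std::array<float, kWarpSize>;

// CPU emulation of a mask-less XOR shuffle: lane i reads the value that
// lane (i ^ lane_mask) held before the exchange. Requires lane_mask < kWarpSize.
inline Warp emulated_shfl_xor(const Warp& var, unsigned lane_mask) {
    Warp out{};
    for (std::size_t lane = 0; lane < kWarpSize; ++lane) {
        out[lane] = var[lane ^ lane_mask];
    }
    return out;
}

// Same shape as the macro in the header: the mask is accepted and discarded.
#define emulated_shfl_xor_sync(mask, var, lane_mask) emulated_shfl_xor(var, lane_mask)

// Butterfly reduction: after log2(kWarpSize) exchange steps every lane
// holds the sum of all lanes.
inline Warp warp_reduce_sum(Warp v) {
    for (unsigned m = kWarpSize / 2; m > 0; m >>= 1) {
        const Warp other = emulated_shfl_xor_sync(0xffffffffu, v, m);
        for (std::size_t lane = 0; lane < kWarpSize; ++lane) {
            v[lane] += other[lane];
        }
    }
    return v;
}
```

Filling the lanes with 1 through 32 and calling `warp_reduce_sum` leaves 528 in every lane. If a kernel ever passed a partial mask, the mask-less form would no longer be equivalent, so this shim is only a safe substitution under the full-mask assumption.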

otherarch/ggml_v2-cuda.cu

Lines changed: 3 additions & 1 deletion
@@ -4,9 +4,11 @@
 #include <stdio.h>
 #include <atomic>
 
+#ifndef GGML_USE_HIPBLAS
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#endif
 
 #include "ggml_v2-cuda.h"
 #include "ggml_v2.h"
@@ -807,4 +809,4 @@ void ggml_v2_cuda_transform_tensor(ggml_v2_tensor * tensor) {
 
     tensor->data = d_Q;
     tensor->backend = GGML_V2_BACKEND_CUDA;
-}
+}
