diff --git a/.gitignore b/.gitignore
index 965d70a..26fe091 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,16 +1,310 @@
-*.pt
-*~
-*.safetensors
-*.err
+# Created by https://www.toptal.com/developers/gitignore/api/linux,python,macos,vim,cuda,c
+# Edit at https://www.toptal.com/developers/gitignore?templates=linux,python,macos,vim,cuda,c
+
+### C ###
+# Prerequisites
+*.d
+
+# Object files
+*.o
+*.ko
+*.obj
+*.elf
+
+# Linker output
+*.ilk
+*.map
+*.exp
+
+# Precompiled Headers
+*.gch
+*.pch
+
+# Libraries
+*.lib
+*.a
+*.la
+*.lo
+
+# Shared objects (inc. Windows DLLs)
+*.dll
+*.so
+*.so.*
+*.dylib
+
+# Executables
+*.exe
 *.out
-*.json
+*.app
+*.i*86
+*.x86_64
+*.hex
+
+# Debug files
+*.dSYM/
+*.su
+*.idb
+*.pdb
+
+# Kernel Module Compile Results
+*.mod*
+*.cmd
+.tmp_versions/
+modules.order
+Module.symvers
+Mkfile.old
+dkms.conf
+
+### CUDA ###
+*.i
+*.ii
+*.gpu
+*.ptx
+*.cubin
+*.fatbin
+
+### Linux ###
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### macOS Patch ###
+# iCloud generated files
+*.icloud
+
+### Python ###
+# Byte-compiled / optimized / DLL files
 __pycache__/
-#*
+*.py[cod]
+*$py.class
+
+# C extensions
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### Vim ###
+# Swap
+[._]*.s[a-v][a-z]
+!*.svg # comment out if you don't need vector files
+[._]*.sw[a-p]
+[._]s[a-rt-v][a-z]
+[._]ss[a-gi-z]
+[._]sw[a-p]
+
+# Session
+Session.vim
+Sessionx.vim
+
+# Temporary
+.netrwhist
+# Auto-generated tag files
+tags
+# Persistent undo
+[._]*.un~
+
+# End of https://www.toptal.com/developers/gitignore/api/linux,python,macos,vim,cuda,c
+
 slurm_out/
 hfized/
-quiptools/build/
-quiptools/dist
-hadamard_cuda/build/
-hadamard_cuda/dist/
 report*.nsys-rep
 report*.sqlite
diff --git a/README.md b/README.md
index b20453c..cf9f5b6 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,6 @@ This codebase contains code that allows users to quantize and deploy their own m
 
 - Clone the repo
 - Install the requirements via `pip install -r requirements.txt`. You may want to use the official pytorch commands to get the CUDA versions.
-- Build and install the matmul CUDA kernels. (`cd quiptools && python setup.py install && cd ../`)
 
 ## Quantization
 
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..9f735e0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,40 @@
+[build-system]
+requires = [
+    "setuptools >= 61.0",
+]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "quip_sharp"
+version = "1.0.0"
+authors = [
+    { name="Albert Tseng", email="albert@cs.cornell.edu" },
+    { name="Jerry Chee", email="JerryChee@cs.cornell.edu" },
+    { name="Qingyao Sun", email="qs234@cornell.edu" },
+    { name="Volodymyr Kuleshov", email="kuleshov@cornell.edu" },
+    { name="Christopher De Sa", email="cdesa@cs.cornell.edu" },
+]
+description = "QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks"
+readme = "README.md"
+requires-python = ">=3.8"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+]
+dependencies = [
+    "datasets",
+    "fast-hadamard-transform",
+    "flash-attn",
+    "glog",
+    "lm-eval",
+    "numpy",
+    "primefac",
+    "quiptools-cuda",
+    "torch >= 2",
+    "tqdm",
+    "transformers >= 4.36.0",
+]
+
+[project.urls]
+Homepage = "https://github.com/Cornell-RelaxML/quip-sharp"
+Issues = "https://github.com/Cornell-RelaxML/quip-sharp/issues"
diff --git a/quiptools/.quiptools.cu.swp b/quiptools/.quiptools.cu.swp
deleted file mode 100644
index 7c43b35..0000000
Binary files a/quiptools/.quiptools.cu.swp and /dev/null differ
diff --git a/quiptools/benchmark_e8p.py b/quiptools/benchmark_e8p.py
deleted file mode 100644
index bfe654e..0000000
--- a/quiptools/benchmark_e8p.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import quiptools_cuda
-import torch
-
-
-def benchmark():
-    torch.manual_seed(42)
-    M = 1
-    N = 12288
-    K = 4096
-
-    x = torch.randn((M, K), dtype=torch.float32, device="cuda")
-    Qidxs = torch.randint(1 << 15, (N, K // 8),
-                          dtype=torch.int16,
-                          device="cuda")
-    codebook = torch.randint(0x7FFFFFFFFFFFFFFF, (256, ),
-                             dtype=torch.int64,
-                             device="cuda")
-
-    # start_event = torch.cuda.Event(enable_timing=True)
-    # end_event = torch.cuda.Event(enable_timing=True)
-    # start_event.record()
-    x = quiptools_cuda.decode_matmul_e8p(x, Qidxs - 0x8000, codebook)
-    # end_event.record()
-    # torch.cuda.synchronize()
-    # elapsed_time_ms = start_event.elapsed_time(end_event)
-    # print(f"Elapsed: {elapsed_time_ms:.4f}ms")
-
-
-if __name__ == "__main__":
-    benchmark()
diff --git a/quiptools/error.txt b/quiptools/error.txt
deleted file mode 100644
index a00df99..0000000
--- a/quiptools/error.txt
+++ /dev/null
@@ -1,155 +0,0 @@
-running install
-running bdist_egg
-running egg_info
-writing quiptools_cuda.egg-info/PKG-INFO
-writing dependency_links to quiptools_cuda.egg-info/dependency_links.txt
-writing top-level names to quiptools_cuda.egg-info/top_level.txt
-reading manifest file 'quiptools_cuda.egg-info/SOURCES.txt'
-writing manifest file 'quiptools_cuda.egg-info/SOURCES.txt'
-installing library code to build/bdist.linux-x86_64/egg
-running install_lib
-running build_ext
-building 'quiptools_cuda' extension
-/usr/local/cuda/bin/nvcc -I/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/torch/include -I/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/torch/include/TH -I/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/jc3464/anaconda3/envs/smoothquant/include/python3.8 -c quiptools.cu -o
build/temp.linux-x86_64-3.8/quiptools.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=quiptools_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_52,code=compute_52 -gencode=arch=compute_52,code=sm_52 -std=c++14 -/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/torch/include/c10/core/SymInt.h(84): warning: integer conversion resulted in a change of sign - -quiptools.cu(17): error: name must be a namespace name - -quiptools.cu(38): error: name followed by "::" must be a class or namespace name - -quiptools.cu(38): error: type name is not allowed - -quiptools.cu(38): error: name followed by "::" must be a class or namespace name - -quiptools.cu(38): error: identifier "a" is undefined - -quiptools.cu(39): error: name followed by "::" must be a class or namespace name - -quiptools.cu(39): error: type name is not allowed - -quiptools.cu(39): error: name followed by "::" must be a class or namespace name - -quiptools.cu(39): error: identifier "b" is undefined - -quiptools.cu(40): error: name followed by "::" must be a class or namespace name - -quiptools.cu(40): error: type name is not allowed - -quiptools.cu(40): error: identifier "c" is undefined - -quiptools.cu(41): error: identifier "fill_fragment" is undefined - -quiptools.cu(50): error: identifier "load_matrix_sync" is undefined - -quiptools.cu(52): error: identifier "mma_sync" is undefined - -quiptools.cu(55): error: name followed by "::" must be a class or namespace name - -quiptools.cu(55): error: identifier "store_matrix_sync" is undefined - -quiptools.cu(110): error: name followed by "::" must be a class or namespace name - -quiptools.cu(110): error: type name is not allowed - -quiptools.cu(110): error: name followed by "::" must be a class or namespace name - -quiptools.cu(110): error: identifier "a" is undefined - -quiptools.cu(111): error: name followed by "::" must be a class or namespace name - -quiptools.cu(111): error: type name is not allowed - -quiptools.cu(111): error: name followed by "::" must be a class or namespace name - -quiptools.cu(111): error: identifier "b" is undefined - -quiptools.cu(112): error: name followed by "::" must be a class or namespace name - -quiptools.cu(112): error: type name is not allowed - -quiptools.cu(112): error: identifier "c0" is undefined - -quiptools.cu(113): error: identifier "fill_fragment" is undefined - -quiptools.cu(115): error: name followed by "::" must be a class or namespace name - -quiptools.cu(115): error: type name is not allowed - -quiptools.cu(115): error: identifier "c1" is undefined - -quiptools.cu(125): error: identifier "load_matrix_sync" is undefined - -quiptools.cu(128): error: identifier "mma_sync" is undefined - -quiptools.cu(134): error: name followed by "::" must be a class or namespace name - -quiptools.cu(134): error: identifier "store_matrix_sync" is undefined - -quiptools.cu(135): error: name followed by "::" must be a class or namespace name - -quiptools.cu(189): error: name followed by "::" must be a class or namespace name - -quiptools.cu(189): error: type name is not allowed - -quiptools.cu(189): error: name followed by "::" must be a class or namespace name - -quiptools.cu(189): error: identifier "a" is undefined - -quiptools.cu(190): error: name 
followed by "::" must be a class or namespace name - -quiptools.cu(190): error: type name is not allowed - -quiptools.cu(190): error: name followed by "::" must be a class or namespace name - -quiptools.cu(190): error: identifier "b" is undefined - -quiptools.cu(191): error: name followed by "::" must be a class or namespace name - -quiptools.cu(191): error: type name is not allowed - -quiptools.cu(191): error: identifier "c0" is undefined - -quiptools.cu(192): error: identifier "fill_fragment" is undefined - -quiptools.cu(194): error: name followed by "::" must be a class or namespace name - -quiptools.cu(194): error: type name is not allowed - -quiptools.cu(194): error: identifier "c1" is undefined - -quiptools.cu(197): error: name followed by "::" must be a class or namespace name - -quiptools.cu(197): error: type name is not allowed - -quiptools.cu(197): error: identifier "c2" is undefined - -quiptools.cu(200): error: name followed by "::" must be a class or namespace name - -quiptools.cu(200): error: type name is not allowed - -quiptools.cu(200): error: identifier "c3" is undefined - -quiptools.cu(210): error: identifier "load_matrix_sync" is undefined - -quiptools.cu(213): error: identifier "mma_sync" is undefined - -quiptools.cu(225): error: name followed by "::" must be a class or namespace name - -quiptools.cu(225): error: identifier "store_matrix_sync" is undefined - -quiptools.cu(226): error: name followed by "::" must be a class or namespace name - -quiptools.cu(227): error: name followed by "::" must be a class or namespace name - -quiptools.cu(228): error: name followed by "::" must be a class or namespace name - -65 errors detected in the compilation of "quiptools.cu". -/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. - warnings.warn( -/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools. - warnings.warn( -/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/torch/utils/cpp_extension.py:411: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend. - warnings.warn(msg.format('we could not find ninja.')) -/home/jc3464/anaconda3/envs/smoothquant/lib/python3.8/site-packages/torch/utils/cpp_extension.py:813: UserWarning: The detected CUDA version (11.2) has a minor version mismatch with the version that was used to compile PyTorch (11.3). Most likely this shouldn't be a problem. 
- warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda)) -error: command '/usr/local/cuda/bin/nvcc' failed with exit code 1 diff --git a/quiptools/quiptools.cu b/quiptools/quiptools.cu deleted file mode 100644 index be79fab..0000000 --- a/quiptools/quiptools.cu +++ /dev/null @@ -1,676 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include - -using namespace torch::indexing; -using namespace nvcuda; - -#define FULL_MASK 0xffffffff -#define HALF_MASK 0x0000ffff - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while(false) -#define gpuErrchk(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (false) - - -__host__ static inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) -{ - if (code != cudaSuccess) - { - fprintf(stderr, "GPUassert[%s:%d]: %s\n", file, line, cudaGetErrorString(code)); - if (abort) exit(code); - } -} - - - -__global__ void cuda_lookupmatmul_d4_k8_kernel( - const c10::Half* __restrict__ X, // k x n - const uint8_t* __restrict__ YIs, // m x (n/4) - const c10::Half* __restrict__ CB, // 256 x 4 - c10::Half* __restrict__ Z, // k x m - size_t K, - size_t M, - size_t N) { - - long m1 = blockIdx.x; - long k1 = blockIdx.y; - - __shared__ c10::Half Y_cache[32*16]; - - wmma::fragment a; // 8 x 16 - wmma::fragment b; // 32 x 16 - wmma::fragment c; // 8 x 32 - fill_fragment(c, __float2half(0.0)); - - for (long jn = 0; jn < N / 16; jn++) { -# pragma unroll 4 - for (long r = 0; r < 4; r++) { - uint8_t yidxs = *(uint8_t*)(YIs + jn*(4*M) + m1*4*32 + threadIdx.x*4 + r); - ((uint64_t*)Y_cache)[threadIdx.x*4 + r] = ((uint64_t*)CB)[(yidxs & 255)]; - } - load_matrix_sync(a, (const __half*)(X + 8*N*k1 + 16*jn), N); - load_matrix_sync(b, (const __half*)Y_cache, 16); - mma_sync(c, a, b, c); - } - - store_matrix_sync((__half*)(&Z[8*M*k1 + 32*m1]), c, M, wmma::mem_row_major); -} - - -void lookupmatmul_d4_k8( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Z // k x m -) { - auto k = X.sizes()[0]; - auto m = YIs.sizes()[0]; - auto n = X.sizes()[1]; - - assert(X.dtype() == torch::kFloat16); - assert(YIs.dtype() == torch::kUInt8); - assert(CB.dtype() == torch::kFloat16); - assert(Z.dtype() == torch::kFloat16); - - assert(Z.sizes()[0] == k); - assert(YIs.sizes()[1] * 4 == n); - assert(Z.sizes()[1] == m); - - assert(k % 8 == 0); // if you want larger k, use k = 16 - assert(m % 32 == 0); - assert(n % 16 == 0); - - const dim3 threads(32); - const dim3 blocks(m/32,k/8); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_lookupmatmul_d4_k8_kernel<<>>( - X.data_ptr(), - YIs.data_ptr(), - CB.data_ptr(), - Z.data_ptr(), - k,m,n - ); -} - - - -__global__ void cuda_lookupmatmul_d4_k16_kernel( - const c10::Half* __restrict__ X, // k x n - const uint8_t* __restrict__ YIs, // m x (n/4) - const c10::Half* __restrict__ CB, // 256 x 4 - c10::Half* __restrict__ Z, // k x m - size_t K, - size_t M, - size_t N) { - - long m1 = blockIdx.x; - long k1 = blockIdx.y; - - __shared__ c10::Half Y_cache[32*16]; - - wmma::fragment a; - wmma::fragment b; - wmma::fragment c0; - fill_fragment(c0, __float2half(0.0)); - - wmma::fragment c1; - fill_fragment(c1, __float2half(0.0)); - - for 
(long jn = 0; jn < N / 16; jn++) { - for (long r = 0; r < 4; r++) { - uint8_t yidxs = *(uint8_t*)(YIs + jn*(4*M) + m1*4*32 + threadIdx.x*4 + r); - ((uint64_t*)Y_cache)[threadIdx.x*4 + r] = ((uint64_t*)CB)[(yidxs & 255)]; - } - - load_matrix_sync(a, (const __half*)(X + 16*N*k1 + 16*jn), N); - - load_matrix_sync(b, (const __half*)Y_cache, 16); - mma_sync(c0, a, b, c0); - - load_matrix_sync(b, (const __half*)Y_cache + 16*16, 16); - mma_sync(c1, a, b, c1); - } - - store_matrix_sync((__half*)(&Z[16*M*k1 + 32*m1 + 0]), c0, M, wmma::mem_row_major); - store_matrix_sync((__half*)(&Z[16*M*k1 + 32*m1 + 16]), c1, M, wmma::mem_row_major); -} - - -void lookupmatmul_d4_k16( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Z // k x m -) { - auto k = X.sizes()[0]; - auto m = YIs.sizes()[0]; - auto n = X.sizes()[1]; - - assert(X.dtype() == torch::kFloat16); - assert(YIs.dtype() == torch::kUInt8); - assert(CB.dtype() == torch::kFloat16); - assert(Z.dtype() == torch::kFloat16); - - assert(Z.sizes()[0] == k); - assert(YIs.sizes()[1] * 4 == n); - assert(Z.sizes()[1] == m); - - assert(k % 16 == 0); - assert(m % 32 == 0); - assert(n % 16 == 0); - - const dim3 threads(32); - const dim3 blocks(m/32,k/16); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_lookupmatmul_d4_k16_kernel<<>>( - X.data_ptr(), - YIs.data_ptr(), - CB.data_ptr(), - Z.data_ptr(), - k,m,n - ); -} - - -__global__ void cuda_lookupmatmul_d4_k32_kernel( - const c10::Half* __restrict__ X, // k x n - const uint8_t* __restrict__ YIs, // m x (n/4) - const c10::Half* __restrict__ CB, // 256 x 4 - c10::Half* __restrict__ Z, // k x m - size_t K, - size_t M, - size_t N) { - - long m1 = blockIdx.x; - long k1 = blockIdx.y; - - __shared__ c10::Half Y_cache[32*16]; - - wmma::fragment a; - wmma::fragment b; - wmma::fragment c0; - fill_fragment(c0, __float2half(0.0)); - - wmma::fragment c1; - fill_fragment(c1, __float2half(0.0)); - - wmma::fragment c2; - fill_fragment(c2, __float2half(0.0)); - - wmma::fragment c3; - fill_fragment(c3, __float2half(0.0)); - - for (long jn = 0; jn < N / 16; jn++) { - for (long r = 0; r < 4; r++) { - uint8_t yidxs = *(uint8_t*)(YIs + jn*(4*M) + m1*4*32 + threadIdx.x*4 + r); - ((uint64_t*)Y_cache)[threadIdx.x*4 + r] = ((uint64_t*)CB)[(yidxs & 255)]; - } - - load_matrix_sync(a, (const __half*)(X + 16*N*(2*k1+0) + 16*jn), N); - - load_matrix_sync(b, (const __half*)Y_cache, 16); - mma_sync(c0, a, b, c0); - - load_matrix_sync(b, (const __half*)Y_cache + 16*16, 16); - mma_sync(c1, a, b, c1); - - load_matrix_sync(a, (const __half*)(X + 16*N*(2*k1+1) + 16*jn), N); - mma_sync(c3, a, b, c3); - - load_matrix_sync(b, (const __half*)Y_cache, 16); - mma_sync(c2, a, b, c2); - } - - store_matrix_sync((__half*)(&Z[16*M*(2*k1+0) + 32*m1 + 0]), c0, M, wmma::mem_row_major); - store_matrix_sync((__half*)(&Z[16*M*(2*k1+0) + 32*m1 + 16]), c1, M, wmma::mem_row_major); - store_matrix_sync((__half*)(&Z[16*M*(2*k1+1) + 32*m1 + 0]), c2, M, wmma::mem_row_major); - store_matrix_sync((__half*)(&Z[16*M*(2*k1+1) + 32*m1 + 16]), c3, M, wmma::mem_row_major); -} - - -void lookupmatmul_d4_k32( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Z // k x m -) { - auto k = X.sizes()[0]; - auto m = YIs.sizes()[0]; - auto n = X.sizes()[1]; - - assert(X.dtype() == torch::kFloat16); - assert(YIs.dtype() == torch::kUInt8); - assert(CB.dtype() == torch::kFloat16); - assert(Z.dtype() == torch::kFloat16); - - assert(Z.sizes()[0] == k); - 
assert(YIs.sizes()[1] * 4 == n); - assert(Z.sizes()[1] == m); - - assert(k % 16 == 0); - assert(m % 32 == 0); - assert(n % 16 == 0); - - const dim3 threads(32); - const dim3 blocks(m/32,k/32); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_lookupmatmul_d4_k32_kernel<<>>( - X.data_ptr(), - YIs.data_ptr(), - CB.data_ptr(), - Z.data_ptr(), - k,m,n - ); -} - -#define DECOMPRESS_D4_BLOCK_SIZE 256 - -__global__ void cuda_decompress_d4_origorder_kernel( - const uint8_t* __restrict__ YIs, // m x (n/4) - const c10::Half* __restrict__ CB, // 256 x 4 - c10::Half* __restrict__ Y // m x n -) { - const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x; - - for(long r = 0; r < 4; r++) { - uint8_t yidx = ((uint8_t*)YIs)[i*4 + r]; - ((uint64_t*)Y)[i*4 + r] = ((uint64_t*)CB)[yidx & 255]; - } -} - - -void decompress_d4_origorder( - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Y // m x n -) { - size_t m = Y.sizes()[0]; - size_t n = Y.sizes()[1]; - - assert(YIs.is_contiguous()); - assert(CB.is_contiguous()); - assert(Y.is_contiguous()); - - assert(YIs.sizes()[0] == m); - assert(YIs.sizes()[1] * 4 == n); - assert(CB.sizes()[0] == 256); - assert(CB.sizes()[1] == 4); - - const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE); - const dim3 blocks(m*n/(16*DECOMPRESS_D4_BLOCK_SIZE)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_decompress_d4_origorder_kernel<<>>( - YIs.data_ptr(), - CB.data_ptr(), - Y.data_ptr() - ); -} - - -__global__ void cuda_decompress_d4_kernel( - const uint8_t* __restrict__ YIs, // m x (n/4) - const c10::Half* __restrict__ CB, // 256 x 4 - c10::Half* __restrict__ Y, // m x n - size_t M, - size_t N -) { - const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x; - - const long j = (i % (N/16))*M + (i / (N/16)); - - for(long r = 0; r < 4; r++) { - uint8_t yidx = ((uint8_t*)YIs)[j*4 + r]; - ((uint64_t*)Y)[i*4 + r] = ((uint64_t*)CB)[yidx & 255]; - } -} - - -void decompress_d4( - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Y // m x n -) { - size_t m = Y.sizes()[0]; - size_t n = Y.sizes()[1]; - - assert(YIs.is_contiguous()); - assert(CB.is_contiguous()); - assert(Y.is_contiguous()); - - assert(YIs.sizes()[0] == m); - assert(YIs.sizes()[1] * 4 == n); - assert(CB.sizes()[0] == 256); - assert(CB.sizes()[1] == 4); - - const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE); - const dim3 blocks(m*n/(16*DECOMPRESS_D4_BLOCK_SIZE)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_decompress_d4_kernel<<>>( - YIs.data_ptr(), - CB.data_ptr(), - Y.data_ptr(), - m,n - ); -} - - - - -// This is a terrible kernel, only use this to not call the pytorch version - -#define DECOMPRESS_HI4B1C_BLOCK_SIZE 128 - -__global__ void cuda_decompress_hi4b1c_packed_kernel( - const int32_t* __restrict__ YIs, // m x (n/8) - const c10::Half* __restrict__ CB, // 16 x 1 - c10::Half* __restrict__ Y // m x n -) { - const long i = threadIdx.x + DECOMPRESS_HI4B1C_BLOCK_SIZE * blockIdx.x; - - // 0 2 4 6 1 3 5 7 - uint32_t packed = YIs[i]; - Y[i*8 + 7] = CB[packed & 15]; - Y[i*8 + 5] = CB[(packed >> 4) & 15]; - Y[i*8 + 3] = CB[(packed >> 8) & 15]; - Y[i*8 + 1] = CB[(packed >> 12) & 15]; - Y[i*8 + 6] = CB[(packed >> 16) & 15]; - Y[i*8 + 4] = CB[(packed >> 20) & 15]; - Y[i*8 + 2] = CB[(packed >> 24) & 15]; - Y[i*8 + 0] = CB[(packed >> 28) & 15]; -} - - -void decompress_hi4b1c_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, - torch::Tensor &Y // m x n -) { - size_t m = 
Y.sizes()[0]; - size_t n = Y.sizes()[1]; - - assert(YIs.is_contiguous()); - assert(Y.is_contiguous()); - - assert(YIs.sizes()[0] == m); - assert(YIs.sizes()[1] * 8 == n); - - assert(CB.sizes()[0] == 16); - assert(CB.sizes()[1] == 1); - - - const dim3 threads(DECOMPRESS_HI4B1C_BLOCK_SIZE); - const dim3 blocks(m*n/(8*DECOMPRESS_HI4B1C_BLOCK_SIZE)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_decompress_hi4b1c_packed_kernel<<>>( - YIs.data_ptr(), - CB.data_ptr(), - Y.data_ptr() - ); -} - - -// This is a terrible kernel, only use this to not call the pytorch version - -#define DECOMPRESS_HI3B1C_BLOCK_SIZE 128 - -__global__ void cuda_decompress_hi3b1c_packed_kernel( - const int32_t* __restrict__ YIs, // m x (n/8) - const c10::Half* __restrict__ CB, // 16 x 1 - c10::Half* __restrict__ Y // m x n -) { - const long i = threadIdx.x + DECOMPRESS_HI3B1C_BLOCK_SIZE * blockIdx.x; - - // 0 2 4 6 1 3 5 7 - uint32_t packed = YIs[i]; - Y[i*8 + 7] = CB[packed & 15]; - Y[i*8 + 5] = CB[(packed >> 4) & 15]; - Y[i*8 + 3] = CB[(packed >> 8) & 15]; - Y[i*8 + 1] = CB[(packed >> 12) & 15]; - Y[i*8 + 6] = CB[(packed >> 16) & 15]; - Y[i*8 + 4] = CB[(packed >> 20) & 15]; - Y[i*8 + 2] = CB[(packed >> 24) & 15]; - Y[i*8 + 0] = CB[(packed >> 28) & 15]; -} - - -void decompress_hi3b1c_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, - torch::Tensor &Y // m x n -) { - size_t m = Y.sizes()[0]; - size_t n = Y.sizes()[1]; - - assert(YIs.is_contiguous()); - assert(Y.is_contiguous()); - - assert(YIs.sizes()[0] == m); - assert(YIs.sizes()[1] * 8 == n); - - assert(CB.sizes()[0] == 8); - assert(CB.sizes()[1] == 1); - - - const dim3 threads(DECOMPRESS_HI3B1C_BLOCK_SIZE); - const dim3 blocks(m*n/(8*DECOMPRESS_HI3B1C_BLOCK_SIZE)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_decompress_hi3b1c_packed_kernel<<>>( - YIs.data_ptr(), - CB.data_ptr(), - Y.data_ptr() - ); -} - -// This is a terrible kernel, only use this to not call the pytorch version - -#define DECOMPRESS_HI2B1C_BLOCK_SIZE 128 - -__global__ void cuda_decompress_hi2b1c_packed_kernel( - const int32_t* __restrict__ YIs, // m x (n/8) - const c10::Half* __restrict__ CB, // 16 x 1 - c10::Half* __restrict__ Y // m x n -) { - const long i = threadIdx.x + DECOMPRESS_HI2B1C_BLOCK_SIZE * blockIdx.x; - - // 0 2 4 6 1 3 5 7 - uint32_t packed = YIs[i]; - Y[i*8 + 7] = CB[packed & 15]; - Y[i*8 + 5] = CB[(packed >> 4) & 15]; - Y[i*8 + 3] = CB[(packed >> 8) & 15]; - Y[i*8 + 1] = CB[(packed >> 12) & 15]; - Y[i*8 + 6] = CB[(packed >> 16) & 15]; - Y[i*8 + 4] = CB[(packed >> 20) & 15]; - Y[i*8 + 2] = CB[(packed >> 24) & 15]; - Y[i*8 + 0] = CB[(packed >> 28) & 15]; -} - - -void decompress_hi2b1c_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, - torch::Tensor &Y // m x n -) { - size_t m = Y.sizes()[0]; - size_t n = Y.sizes()[1]; - - assert(YIs.is_contiguous()); - assert(Y.is_contiguous()); - - assert(YIs.sizes()[0] == m); - assert(YIs.sizes()[1] * 8 == n); - - assert(CB.sizes()[0] == 4); - assert(CB.sizes()[1] == 1); - - - const dim3 threads(DECOMPRESS_HI2B1C_BLOCK_SIZE); - const dim3 blocks(m*n/(8*DECOMPRESS_HI2B1C_BLOCK_SIZE)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_decompress_hi2b1c_packed_kernel<<>>( - YIs.data_ptr(), - CB.data_ptr(), - Y.data_ptr() - ); -} - - - -// This is a terrible kernel, only use this to not call the pytorch version - -#define DECOMPRESS_E81B_BLOCK_SIZE 4 - -__global__ void cuda_decompress_e81b_packed_kernel( - const int64_t* __restrict__ YIs, // m x 
(n/8) - const c10::Half* __restrict__ CB, // 256 x 8 - c10::Half* __restrict__ Y // m x n -) { - const long i = threadIdx.x + DECOMPRESS_E81B_BLOCK_SIZE * blockIdx.x; - - uint64_t packed = YIs[i]; - -#pragma unroll - for (long j = 0; j < 8; j++) { - uint64_t yidx = packed & 255; - ((uint64_t*)Y)[(i*8 + j)*2] = ((uint64_t*)CB)[yidx*2]; - ((uint64_t*)Y)[(i*8 + j)*2 + 1] = ((uint64_t*)CB)[yidx*2 + 1]; - packed = packed >> 8; - } - -} - -void decompress_e81b_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, - torch::Tensor &Y // m x n -) { - size_t m = Y.sizes()[0]; - size_t n = Y.sizes()[1]; - - assert(YIs.is_contiguous()); - assert(Y.is_contiguous()); - - assert(YIs.sizes()[0] == m); - assert(YIs.sizes()[1] * 64 == n); - - assert(CB.sizes()[0] == 256); - assert(CB.sizes()[1] == 8); - - at::DeviceGuard guard(CB.device()); - const dim3 threads(DECOMPRESS_E81B_BLOCK_SIZE); - const dim3 blocks(m*n/(64*DECOMPRESS_E81B_BLOCK_SIZE)); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_decompress_e81b_packed_kernel<<>>( - YIs.data_ptr(), - CB.data_ptr(), - Y.data_ptr() - ); -} - - - -__global__ void cuda_lookupmatmul_e81b_k8_kernel( - const c10::Half* __restrict__ X, // k x n - const int64_t* __restrict__ YIs, // m x (n/64) - const c10::Half* __restrict__ CB, // 256 x 8 - float* __restrict__ Z, - size_t K, - size_t M, - size_t N) { - - long m1 = blockIdx.x; - long k1 = blockIdx.y; - - __shared__ c10::Half Y_cache0[32*16]; - wmma::fragment a0; // 8 x 16 - wmma::fragment b0; // 32 x 16 - - __shared__ c10::Half Y_cache1[32*16]; - wmma::fragment a1; // 8 x 16 - wmma::fragment b1; // 32 x 16 - - wmma::fragment c; // 8 x 32 - fill_fragment(c, 0.0); - - -#pragma unroll - for (long jn = 0; jn < N / 32; jn++) { - uint32_t packed = ((uint32_t*)YIs)[(m1*32 + threadIdx.x)*(N/32) + jn]; -#pragma unroll - for (long r = 0; r < 2; r++) { - uint32_t yidx = packed & 255; - ((uint64_t*)Y_cache0)[(threadIdx.x*2 + r)*2] = ((uint64_t*)CB)[yidx*2]; - ((uint64_t*)Y_cache0)[(threadIdx.x*2 + r)*2 + 1] = ((uint64_t*)CB)[yidx*2 + 1]; - packed = packed >> 8; - } -#pragma unroll - for (long r = 0; r < 2; r++) { - uint32_t yidx = packed & 255; - ((uint64_t*)Y_cache1)[(threadIdx.x*2 + r)*2] = ((uint64_t*)CB)[yidx*2]; - ((uint64_t*)Y_cache1)[(threadIdx.x*2 + r)*2 + 1] = ((uint64_t*)CB)[yidx*2 + 1]; - packed = packed >> 8; - } - - load_matrix_sync(a0, (const __half*)(X + 8*N*k1 + 32*jn), N); - load_matrix_sync(b0, (const __half*)Y_cache0, 16); - mma_sync(c, a0, b0, c); - - load_matrix_sync(a1, (const __half*)(X + 8*N*k1 + 32*jn + 16), N); - load_matrix_sync(b1, (const __half*)Y_cache1, 16); - mma_sync(c, a1, b1, c); - - } - - store_matrix_sync(&Z[8*M*k1 + 32*m1], c, M, wmma::mem_row_major); -} - - -void lookupmatmul_e81b_k8( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/64) - torch::Tensor CB, // 256 x 8 - torch::Tensor Z // k x m -) { - auto k = X.sizes()[0]; - auto m = YIs.sizes()[0]; - auto n = X.sizes()[1]; - - assert(Z.sizes()[0] == k); - assert(YIs.sizes()[1] * 64 == n); - assert(Z.sizes()[1] == m); - - assert(k <= 8); - assert(m % 32 == 0); - assert(n % 32 == 0); - - at::DeviceGuard guard(CB.device()); - const dim3 threads(32); - const dim3 blocks(m/32, k/8); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cuda_lookupmatmul_e81b_k8_kernel<<>>( - X.data_ptr(), - YIs.data_ptr(), - CB.data_ptr(), - Z.data_ptr(), - k,m,n - ); -} - diff --git a/quiptools/quiptools_cuda.egg-info/PKG-INFO b/quiptools/quiptools_cuda.egg-info/PKG-INFO deleted file mode 100644 index 
b7b8045..0000000 --- a/quiptools/quiptools_cuda.egg-info/PKG-INFO +++ /dev/null @@ -1,3 +0,0 @@ -Metadata-Version: 2.1 -Name: quiptools-cuda -Version: 0.0.0 diff --git a/quiptools/quiptools_cuda.egg-info/SOURCES.txt b/quiptools/quiptools_cuda.egg-info/SOURCES.txt deleted file mode 100644 index d2a166b..0000000 --- a/quiptools/quiptools_cuda.egg-info/SOURCES.txt +++ /dev/null @@ -1,8 +0,0 @@ -quiptools.cu -quiptools_e8p_gemv.cu -quiptools_wrapper.cpp -setup.py -quiptools_cuda.egg-info/PKG-INFO -quiptools_cuda.egg-info/SOURCES.txt -quiptools_cuda.egg-info/dependency_links.txt -quiptools_cuda.egg-info/top_level.txt \ No newline at end of file diff --git a/quiptools/quiptools_cuda.egg-info/dependency_links.txt b/quiptools/quiptools_cuda.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/quiptools/quiptools_cuda.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/quiptools/quiptools_cuda.egg-info/top_level.txt b/quiptools/quiptools_cuda.egg-info/top_level.txt deleted file mode 100644 index 7eb4156..0000000 --- a/quiptools/quiptools_cuda.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -quiptools_cuda diff --git a/quiptools/quiptools_e8p_gemv.cu b/quiptools/quiptools_e8p_gemv.cu deleted file mode 100644 index de199a6..0000000 --- a/quiptools/quiptools_e8p_gemv.cu +++ /dev/null @@ -1,585 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include - -#include -#include - -using namespace torch::indexing; -using namespace nvcuda; - -#define FULL_MASK 0xffffffff -#define HALF_MASK 0x0000ffff - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while(false) -#define gpuErrchk(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (false) - - -__host__ static inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) -{ - if (code != cudaSuccess) - { - fprintf(stderr, "GPUassert[%s:%d]: %s\n", file, line, cudaGetErrorString(code)); - if (abort) exit(code); - } -} - -__device__ static inline uint32_t add_as_half2(uint32_t x, uint32_t y) { - uint32_t z; - asm("add.f16x2 %0,%1,%2;" : "=r"(z) : "r"(x), "r"(y)); - return z; -} - - -__device__ static inline uint32_t mask_lop3(uint32_t x, uint32_t m0, uint32_t m1) { - uint32_t y; - asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(y) : "r"(x), "r"(m0), "r"(m1)); - return y; - // return (x & m0) | m1; -} - -#define BASE_OFFSET 0xd080d080 -#define XMASK 0x00f000f0 -#define WMASK 0x50085008 - - -__global__ static void -// __launch_bounds__(1024, 1024) -decode_matvec_e8p_kernel( - float *__restrict__ output, - const uint2 *__restrict__ input, - const uint2 *__restrict__ weights_compressed, - const uint32_t *__restrict__ codebook_abs, - int N, - int K -) { - int warpId = threadIdx.y; - int laneId = threadIdx.x; - - // __shared__ float sum_scratch[16*32]; - - // __shared__ uint32_t codebook_local[256*32]; - // for (int icb = warpId; icb < 256; icb += 32) { - // codebook_local[icb*32 + laneId] = codebook_abs[icb]; - // } - // __syncthreads(); - - __shared__ uint2 shared_weights[1024*2]; - - for (int iin = blockIdx.x; iin < (N >> 4); iin += gridDim.x) { - - float z0 = 0.0; - float z1 = 0.0; - float z2 = 0.0; - float z3 = 0.0; - - // int shwo = laneId + 32*warpId; - - // 
__pipeline_memcpy_async(shared_weights + shwo, weights_compressed + laneId + 32*warpId + 1024*0 + (K >> 1)*iin, 8); - // __pipeline_commit(); - - for (int iik = warpId; iik < (K >> 6); iik += 32) { - // if (iik + 1 < (K >> 11)) { - // __pipeline_memcpy_async(shared_weights + (shwo ^ 1024), weights_compressed + laneId + 32*iik + 1024 + (K >> 1)*iin, 8); - // __pipeline_commit(); - // __pipeline_wait_prior(1); - // shwo = shwo ^ 1024; - // } - // else { - // __pipeline_wait_prior(0); - // } - - // uint2 w_compr = shared_weights[shwo]; // weights_compressed[laneId + 32*warpId + 1024*iik + (K >> 1)*iin]; - uint2 w_compr = weights_compressed[laneId + 32*iik + (K >> 1)*iin]; - uint32_t a = w_compr.x; - uint32_t b = w_compr.y; - - uint32_t s = b; - s = s ^ (s >> 4); - s = s ^ (s >> 8); - s = s ^ (s >> 16); - uint32_t sb = (s & 15); - s = b ^ sb; - sb = sb | (sb << 16); - - uint32_t input_to_warp = ((const uint32_t*)(&input[16*iik]))[laneId]; - uint32_t shifted_laneId = (laneId & 3) << 3; - - /// BLOCK 01 - { - uint32_t x = codebook_abs[(a >> 0) & 255]; - x = x ^ ((s & 0x11111111) * 14); - - uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4); - - uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - x = codebook_abs[(a >> 8) & 255]; - x = x ^ ((s & 0x22222222) * 7); - - o = BASE_OFFSET | ((sb & 0x00020002) << 3); - - uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - // uint2 x_in = input[0 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; - // uint32_t x_in0 = x_in.x; - // uint32_t x_in1 = x_in.y; - - uint32_t x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 0); - uint32_t x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 1); - - asm( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - " { %0, %1, %2, %3 }," - " { %4, %5, %6, %7 }," - " { %8, %9 }," - " { %0, %1, %2, %3 };" - : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) - : "r"(w00), "r"(w10), "r"(w01), "r"(w11), - "r"(x_in0), "r"(x_in1) - ); - - - // x_in = input[1 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; - // x_in0 = x_in.x; - // x_in1 = x_in.y; - - x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 2); - x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 3); - - asm( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - " { %0, %1, %2, %3 }," - " { %4, %5, %6, %7 }," - " { %8, %9 }," - " { %0, %1, %2, %3 };" - : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) - : "r"(w02), "r"(w12), "r"(w03), "r"(w13), - "r"(x_in0), "r"(x_in1) - ); - } - /// BLOCK 23 - { - uint32_t x = codebook_abs[(a >> 16) & 255]; - s = s >> 2; - x = x ^ ((s & 0x11111111) * 14); - - uint32_t o = BASE_OFFSET | ((sb & 0x00040004) << 2); - - uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - x = codebook_abs[(a >> 24) & 255]; - x = x ^ ((s & 0x22222222) * 7); - - o = BASE_OFFSET | ((sb & 0x00080008) << 1); - - uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w11 = 
add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - - // uint2 x_in = input[2 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; - // uint32_t x_in0 = x_in.x; - // uint32_t x_in1 = x_in.y; - - uint32_t x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 4); - uint32_t x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 5); - - asm( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - " { %0, %1, %2, %3 }," - " { %4, %5, %6, %7 }," - " { %8, %9 }," - " { %0, %1, %2, %3 };" - : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) - : "r"(w00), "r"(w10), "r"(w01), "r"(w11), - "r"(x_in0), "r"(x_in1) - ); - - - // x_in = input[3 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; - // x_in0 = x_in.x; - // x_in1 = x_in.y; - - x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 6); - x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 7); - - asm( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - " { %0, %1, %2, %3 }," - " { %4, %5, %6, %7 }," - " { %8, %9 }," - " { %0, %1, %2, %3 };" - : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) - : "r"(w02), "r"(w12), "r"(w03), "r"(w13), - "r"(x_in0), "r"(x_in1) - ); - } - } - - // we produced 16 outputs, so only 16 threads - if ((laneId & 1) == 0) { - atomicAdd(output + (iin << 4) + (laneId >> 1), (laneId & 2) ? z2 : z0); - } - - // if ((laneId & 3) == 0) { - // sum_scratch[warpId + ((laneId >> 1) + 0) * 32] = z0; - // sum_scratch[warpId + ((laneId >> 1) + 1) * 32] = z2; - // } - // __syncthreads(); - - // // load and sum - // if (warpId < 16) { - // float acc = sum_scratch[laneId + warpId*32]; - // for (int offset = 16; offset > 0; offset /= 2) { - // acc += __shfl_down_sync(FULL_MASK, acc, offset); - // } - // if (laneId == 0) { - // output[(iin << 4) + warpId] = acc; - // } - // } - } -} - - -__host__ extern torch::Tensor decode_matvec_e8p( - torch::Tensor x, - torch::Tensor weights_compressed, - torch::Tensor codebook_abs -) { - - CHECK_INPUT(x); - CHECK_INPUT(weights_compressed); - CHECK_INPUT(codebook_abs); - - TORCH_CHECK(x.dim() == 1); - TORCH_CHECK(weights_compressed.dim() == 4); - TORCH_CHECK(weights_compressed.size(3) == 4); - TORCH_CHECK(weights_compressed.size(2) == 8); - TORCH_CHECK(codebook_abs.dim() == 1); - TORCH_CHECK(x.scalar_type() == torch::kFloat16); - TORCH_CHECK(weights_compressed.scalar_type() == torch::kInt64); - TORCH_CHECK(codebook_abs.scalar_type() == torch::kInt32); - TORCH_CHECK(x.size(-1) == weights_compressed.size(1) << 6); - TORCH_CHECK(codebook_abs.size(-1) == 256); - - int64_t N = weights_compressed.size(0) * 16; - int64_t K = x.size(-1); - - TORCH_CHECK(K % 64 == 0, "K is not divisible by 64"); - TORCH_CHECK(N % 16 == 0, "N is not divisible by 16"); - - TORCH_CHECK(K < 65536, "K is not too large"); - TORCH_CHECK(N < 65536, "N is not too large"); - - at::DeviceGuard guard(x.device()); - torch::TensorOptions options = torch::TensorOptions() - .dtype(torch::kFloat32) - .layout(torch::kStrided) - .device(torch::kCUDA) - .requires_grad(false); - torch::Tensor output = torch::zeros(std::vector{N}, options); - - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, x.get_device()); - int64_t grid_size = static_cast(deviceProp.multiProcessorCount); - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); - - const dim3 block_size(32,32); - - decode_matvec_e8p_kernel<<>>( - output.data_ptr(), - (const uint2*)x.data_ptr(), - (const 
uint2*)weights_compressed.data_ptr(), - (const uint32_t*)codebook_abs.data_ptr(), - N, - K); - - gpuErrchk(cudaPeekAtLastError()); - - return output; -} - - - -__global__ static void -test_tc_kernel(float *__restrict__ output) { - int laneId = threadIdx.x; - - uint32_t w0 = (laneId == 0) ? 0x3C003C00 : 0x00000000; - uint32_t w1 = 0x00000000; - uint32_t w2 = 0x00000000; - uint32_t w3 = 0x00000000; - - uint32_t x0 = (laneId == 0) ? 0x3C003C00 : 0x00000000; - uint32_t x1 = 0x00000000; - - float z0 = 0.0; - float z1 = 0.0; - float z2 = 0.0; - float z3 = 0.0; - - asm( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - " { %0, %1, %2, %3 }," - " { %4, %5, %6, %7 }," - " { %8, %9 }," - " { %0, %1, %2, %3 };" - : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) - : "r"(w0), "r"(w1), "r"(w2), "r"(w3), - "r"(x0), "r"(x1) - ); - - output[laneId*4 + 0] = z0; - output[laneId*4 + 1] = z1; - output[laneId*4 + 2] = z2; - output[laneId*4 + 3] = z3; -} - -__host__ extern torch::Tensor test_tc() { - - torch::TensorOptions options = torch::TensorOptions() - .dtype(torch::kFloat32) - .layout(torch::kStrided) - .device(torch::kCUDA) - .requires_grad(false); - torch::Tensor output = torch::zeros(std::vector{32*4}, options); - - test_tc_kernel<<<1, 32>>>(output.data_ptr()); - - gpuErrchk(cudaPeekAtLastError()); - - return output; -} - - - - -__global__ static void -test_codebook_expand_kernel(uint32_t *__restrict__ output, const uint32_t *__restrict__ codebook_abs) { - uint32_t a = threadIdx.x; - uint32_t b = 0; - - for (int i = 0; i < 8; i++) { - b |= (((blockIdx.x >> i) & 1) << (4*i)); - } - - uint32_t s = b; - s = s ^ (s >> 4); - s = s ^ (s >> 8); - s = s ^ (s >> 16); - uint32_t sb = (s & 15); - s = b ^ sb; - sb = sb | (sb << 16); - - uint32_t x = codebook_abs[(a >> 0) & 255]; - x = x ^ ((s & 0x11111111) * 14); - - uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4); - - uint32_t w0 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w1 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w2 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w3 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - output[blockIdx.x*256*4 + threadIdx.x*4 + 0] = w0; - output[blockIdx.x*256*4 + threadIdx.x*4 + 1] = w1; - output[blockIdx.x*256*4 + threadIdx.x*4 + 2] = w2; - output[blockIdx.x*256*4 + threadIdx.x*4 + 3] = w3; -} - -__host__ extern torch::Tensor test_codebook_expand(torch::Tensor codebook_abs) { - - torch::TensorOptions options = torch::TensorOptions() - .dtype(torch::kFloat16) - .layout(torch::kStrided) - .device(torch::kCUDA) - .requires_grad(false); - torch::Tensor output = torch::zeros(std::vector{256*256,8}, options); - - test_codebook_expand_kernel<<<256, 256>>>((uint32_t*)output.data_ptr(), (const uint32_t*)codebook_abs.data_ptr()); - - gpuErrchk(cudaPeekAtLastError()); - - return output; -} - - - - -__global__ static void -// __launch_bounds__(1024, 1024) -decompress_packed_e8p_kernel( - uint32_t *__restrict__ output, - const uint2 *__restrict__ weights_compressed, - const uint32_t *__restrict__ codebook_abs, - int N, - int K -) { - int warpId = threadIdx.y; - int laneId = threadIdx.x; - - for (int iin = blockIdx.x; iin < (N >> 4); iin += gridDim.x) { - - for (int iik = warpId; iik < (K >> 6); iik += 32) { - uint2 w_compr = weights_compressed[laneId + 32*iik + (K >> 1)*iin]; - uint32_t a = w_compr.x; - uint32_t b = w_compr.y; - - uint32_t s = b; - s = s ^ (s >> 4); - s = s ^ (s >> 8); - s = s ^ (s >> 16); - uint32_t sb = (s & 15); - s = b ^ sb; - sb = sb | (sb << 16); - - /// 
BLOCK 01 - { - uint32_t x = codebook_abs[(a >> 0) & 255]; - x = x ^ ((s & 0x11111111) * 14); - - uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4); - - uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - x = codebook_abs[(a >> 8) & 255]; - x = x ^ ((s & 0x22222222) * 7); - - o = BASE_OFFSET | ((sb & 0x00020002) << 3); - - uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 0] = w00; - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 1] = w01; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 0] = w10; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 1] = w11; - - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 2] = w02; - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 3] = w03; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 2] = w12; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 3] = w13; - - } - /// BLOCK 23 - { - uint32_t x = codebook_abs[(a >> 16) & 255]; - s = s >> 2; - x = x ^ ((s & 0x11111111) * 14); - - uint32_t o = BASE_OFFSET | ((sb & 0x00040004) << 2); - - uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - x = codebook_abs[(a >> 24) & 255]; - x = x ^ ((s & 0x22222222) * 7); - - o = BASE_OFFSET | ((sb & 0x00080008) << 1); - - uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); - uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); - uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); - uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); - - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 0] = w00; - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 1] = w01; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 0] = w10; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 1] = w11; - - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 2] = w02; - output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 3] = w03; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 2] = w12; - output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 3] = w13; - } - } - } -} - - -__host__ extern torch::Tensor decompress_packed_e8p( - torch::Tensor weights_compressed, - torch::Tensor codebook_abs -) { - CHECK_INPUT(weights_compressed); - CHECK_INPUT(codebook_abs); - - TORCH_CHECK(weights_compressed.dim() == 4); - TORCH_CHECK(weights_compressed.size(3) == 4); - 
TORCH_CHECK(weights_compressed.size(2) == 8); - TORCH_CHECK(codebook_abs.dim() == 1); - TORCH_CHECK(weights_compressed.scalar_type() == torch::kInt64); - TORCH_CHECK(codebook_abs.scalar_type() == torch::kInt32); - TORCH_CHECK(codebook_abs.size(-1) == 256); - - int64_t N = weights_compressed.size(0) * 16; - int64_t K = weights_compressed.size(1) << 6; - - TORCH_CHECK(K % 64 == 0, "K is not divisible by 64"); - TORCH_CHECK(N % 16 == 0, "N is not divisible by 16"); - - TORCH_CHECK(K < 65536, "K is not too large"); - TORCH_CHECK(N < 65536, "N is not too large"); - - at::DeviceGuard guard(codebook_abs.device()); - torch::TensorOptions options = torch::TensorOptions() - .dtype(torch::kFloat16) - .layout(torch::kStrided) - .device(torch::kCUDA) - .requires_grad(false); - torch::Tensor output = torch::zeros(std::vector{N,K}, options); - - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, weights_compressed.get_device()); - int64_t grid_size = static_cast(deviceProp.multiProcessorCount); - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); - - const dim3 block_size(32,32); - - decompress_packed_e8p_kernel<<>>( - (uint32_t*)output.data_ptr(), - (const uint2*)weights_compressed.data_ptr(), - (const uint32_t*)codebook_abs.data_ptr(), - N, - K); - - gpuErrchk(cudaPeekAtLastError()); - - return output; -} \ No newline at end of file diff --git a/quiptools/quiptools_wrapper.cpp b/quiptools/quiptools_wrapper.cpp deleted file mode 100644 index 3061be7..0000000 --- a/quiptools/quiptools_wrapper.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include - -#include -#include - -void lookupmatmul_d4_k8( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Z // k x m -); - -void lookupmatmul_d4_k16( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Z // k x m -); - -void lookupmatmul_d4_k32( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Z // k x m -); - -void decompress_d4( - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Y // m x n -); - -void decompress_d4_origorder( - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Y // m x n -); - -torch::Tensor decompress_packed_e8p( - torch::Tensor weights_compressed, // m x (n/8) - torch::Tensor codebook_abs // 256 x 8 -); - -torch::Tensor decode_matvec_e8p( - torch::Tensor x, - torch::Tensor weights_compressed, - torch::Tensor codebook_abs -); - -void decompress_hi4b1c_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, // 16 x 1 - torch::Tensor &Y // m x n -); - -void decompress_hi3b1c_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, // 16 x 1 - torch::Tensor &Y // m x n -); - -void decompress_hi2b1c_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, // 16 x 1 - torch::Tensor &Y // m x n -); - -void decompress_e81b_packed( - torch::Tensor YIs, // m x (n/8) - torch::Tensor CB, // 256 x 8 - torch::Tensor &Y // m x n -); - -void lookupmatmul_e81b_k8( - torch::Tensor X, // k x n - torch::Tensor YIs, // m x (n/4) - torch::Tensor CB, // 256 x 4 - torch::Tensor Z // k x m -); - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("lookupmatmul_d4_k8", &lookupmatmul_d4_k8, "lookupmatmul_d4_k8"); - m.def("lookupmatmul_d4_k16", &lookupmatmul_d4_k16, "lookupmatmul_d4_k16"); - m.def("lookupmatmul_d4_k32", &lookupmatmul_d4_k32, "lookupmatmul_d4_k32"); - m.def("decompress_d4", &decompress_d4, 
"decompress_d4"); - m.def("decompress_d4_origorder", &decompress_d4_origorder, "decompress_d4_origorder"); - m.def("decompress_packed_e8p", &decompress_packed_e8p, "decompress_packed_e8p"); - m.def("decode_matvec_e8p", &decode_matvec_e8p, "decode_matvec_e8p"); - m.def("decompress_hi4b1c_packed", &decompress_hi4b1c_packed, "decompress_hi4b1c_packed"); - m.def("decompress_hi3b1c_packed", &decompress_hi3b1c_packed, "decompress_hi3b1c_packed"); - m.def("decompress_hi2b1c_packed", &decompress_hi2b1c_packed, "decompress_hi2b1c_packed"); - m.def("decompress_e81b_packed", &decompress_e81b_packed, "decompress_e81b_packed"); - m.def("lookupmatmul_e81b_k8", &lookupmatmul_e81b_k8, "lookupmatmul_e81b_k8"); -} - diff --git a/quiptools/setup.py b/quiptools/setup.py deleted file mode 100644 index 3a98e85..0000000 --- a/quiptools/setup.py +++ /dev/null @@ -1,15 +0,0 @@ -from setuptools import Extension, setup -from torch.utils import cpp_extension - -setup( - name='quiptools_cuda', - ext_modules=[ - cpp_extension.CUDAExtension( - 'quiptools_cuda', - ['quiptools_wrapper.cpp', 'quiptools.cu', 'quiptools_e8p_gemv.cu'], - extra_compile_args={ - 'cxx': ['-g', '-lineinfo'], - 'nvcc': ['-O2', '-g', '-Xcompiler', '-rdynamic', '-lineinfo'] - }) - ], - cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/quiptools/test_d4.py b/quiptools/test_d4.py deleted file mode 100644 index 8554ca4..0000000 --- a/quiptools/test_d4.py +++ /dev/null @@ -1,87 +0,0 @@ -import time - -import quiptools_cuda -import torch - -k = 32 * 32 -m = 8192 * 2 -n = 8192 // 2 - -x = torch.randn(k, n, dtype=torch.float16, device="cuda") -z = torch.zeros(k, m, dtype=torch.float16, device="cuda") -cb = torch.randn(256, 4, dtype=torch.float16, device="cuda") -yidxs = torch.randint(256, (m, n // 4), device="cuda").to(torch.uint8) - -yidxs_reordered = yidxs.view(m, n // (4 * 4), - 4).permute(1, 0, 2).reshape(m, - n // 4).contiguous() - -y1 = torch.zeros(m, n, dtype=torch.float16, device="cuda") - -# yidxs_reordered = yidxs.view(m//32,32,n//(4*4),4).permute(0,2,1,3).reshape(m,n//4).contiguous() - -# yidxs_reordered_k16 = yidxs.view(m//16,16,n//(4*4),4).permute(0,2,1,3).reshape(m,n//4).contiguous() - -torch.cuda.synchronize() -start = time.time() -y = cb[yidxs.view(-1).to(torch.int32), :].view(m, n) -z0 = x @ y.t() -torch.cuda.synchronize() -end = time.time() -print(f"elapsed for pure torch: {end - start}") - -torch.cuda.synchronize() -start = time.time() -quiptools_cuda.decompress_d4_origorder(yidxs, cb, y1) -torch.cuda.synchronize() -end = time.time() -print(f"elapsed for orig decompress_d4: {end - start}") - -assert ((y1 == y).all()) - -y1.zero_() -torch.cuda.synchronize() -start = time.time() -quiptools_cuda.decompress_d4(yidxs_reordered, cb, y1) -z1 = x @ y1.t() -torch.cuda.synchronize() -end = time.time() -print(f"elapsed for decompress_d4 and multiply: {end - start}") - -assert ((y1 == y).all()) - -torch.cuda.synchronize() -start = time.time() -quiptools_cuda.lookupmatmul_d4_k8(x, yidxs_reordered, cb, z) -torch.cuda.synchronize() -end = time.time() -print(f" elapsed for k8 cuda: {end - start}") - -lookupmatmul_d4_k8_err = ( - (z.to(torch.float32) - z0.to(torch.float32)).square().sum() / - ((z0.to(torch.float32)).square().sum() + 1e-10)) -print(f"lookupmatmul_d4_k8 error: {lookupmatmul_d4_k8_err}") - -torch.cuda.synchronize() -start = time.time() -quiptools_cuda.lookupmatmul_d4_k16(x, yidxs_reordered, cb, z) -torch.cuda.synchronize() -end = time.time() -print(f" elapsed for k16 cuda: {end - start}") - -lookupmatmul_d4_k16_err = ( - 
-    (z.to(torch.float32) - z0.to(torch.float32)).square().sum() /
-    ((z0.to(torch.float32)).square().sum() + 1e-10))
-print(f"lookupmatmul_d4_k16 error: {lookupmatmul_d4_k16_err}")
-
-torch.cuda.synchronize()
-start = time.time()
-quiptools_cuda.lookupmatmul_d4_k32(x, yidxs_reordered, cb, z)
-torch.cuda.synchronize()
-end = time.time()
-print(f" elapsed for k32 cuda: {end - start}")
-
-lookupmatmul_d4_k32_err = (
-    (z.to(torch.float32) - z0.to(torch.float32)).square().sum() /
-    ((z0.to(torch.float32)).square().sum() + 1e-10))
-print(f"lookupmatmul_d4_k32 error: {lookupmatmul_d4_k32_err}")
diff --git a/quiptools/test_e8p.py b/quiptools/test_e8p.py
deleted file mode 100644
index f79cfe1..0000000
--- a/quiptools/test_e8p.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import time
-
-import quiptools_cuda
-import torch
-
-torch.manual_seed(0)
-m = 8192 * 2
-n = 8192 // 2
-
-cb = torch.randn(256, 8, dtype=torch.float16, device="cuda")
-cb_even = cb.sum(dim=-1) > 0
-yidxs = torch.randint(2**16, (m, n // 8), device="cuda").to(torch.int16)
-y1 = torch.zeros(m, n, dtype=torch.float16, device="cuda")
-'''
-torch.cuda.synchronize()
-start = time.time()
-y = cb[yidxs.view(-1).to(torch.int32)+2**15,:].view(m,n)
-torch.cuda.synchronize()
-end = time.time()
-print(f"elapsed for pure torch: {end - start}")
-'''
-
-torch.cuda.synchronize()
-start = time.time()
-quiptools_cuda.decompress_e8p_origorder(yidxs, cb, cb_even, y1)
-torch.cuda.synchronize()
-end = time.time()
-print(f"elapsed for orig decompress_e8p: {end - start}")
-
-print(y1)
-
-#assert((y1 == y).all())
diff --git a/requirements.txt b/requirements.txt
index ba33db1..da037f6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -109,3 +109,4 @@ xxhash==3.4.1
 yarl==1.9.2
 zstandard==0.22.0
 fast-hadamard-transform
+quiptools-cuda
diff --git a/lib/algo/__init__.py b/src/quip_sharp/__init__.py
similarity index 100%
rename from lib/algo/__init__.py
rename to src/quip_sharp/__init__.py
diff --git a/eval/eval_ppl.py b/src/quip_sharp/eval/eval_ppl.py
similarity index 95%
rename from eval/eval_ppl.py
rename to src/quip_sharp/eval/eval_ppl.py
index 71687c9..061fdff 100644
--- a/eval/eval_ppl.py
+++ b/src/quip_sharp/eval/eval_ppl.py
@@ -9,8 +9,8 @@
 import torch
 from tqdm import tqdm
 
-from lib.utils import gptq_data_utils
-from lib.utils.unsafe_import import model_from_hf_path
+from quip_sharp.lib.utils import gptq_data_utils
+from quip_sharp.lib.utils.unsafe_import import model_from_hf_path
 
 torch.set_grad_enabled(False)
diff --git a/eval/eval_speed.py b/src/quip_sharp/eval/eval_speed.py
similarity index 95%
rename from eval/eval_speed.py
rename to src/quip_sharp/eval/eval_speed.py
index c61346b..7e73725 100644
--- a/eval/eval_speed.py
+++ b/src/quip_sharp/eval/eval_speed.py
@@ -7,7 +7,7 @@
 from torch.profiler import ProfilerActivity, profile, record_function
 from transformers import AutoTokenizer
 
-from lib.utils.unsafe_import import model_from_hf_path
+from quip_sharp.lib.utils.unsafe_import import model_from_hf_path
 
 torch.set_grad_enabled(False)
diff --git a/eval/eval_zeroshot.py b/src/quip_sharp/eval/eval_zeroshot.py
similarity index 94%
rename from eval/eval_zeroshot.py
rename to src/quip_sharp/eval/eval_zeroshot.py
index de8cbcb..672c5d8 100644
--- a/eval/eval_zeroshot.py
+++ b/src/quip_sharp/eval/eval_zeroshot.py
@@ -9,8 +9,8 @@
 from lm_eval import evaluator, tasks
 from transformers import AutoTokenizer
 
-from lib.utils import LMEvalAdaptor
-from lib.utils.unsafe_import import model_from_hf_path
+from quip_sharp.lib.utils import LMEvalAdaptor
+from
quip_sharp.lib.utils.unsafe_import import model_from_hf_path parser = argparse.ArgumentParser() parser.add_argument('--seed', default=0, type=int) diff --git a/eval/interactive_gen.py b/src/quip_sharp/eval/interactive_gen.py similarity index 95% rename from eval/interactive_gen.py rename to src/quip_sharp/eval/interactive_gen.py index f6de71e..3e65d26 100644 --- a/eval/interactive_gen.py +++ b/src/quip_sharp/eval/interactive_gen.py @@ -7,7 +7,7 @@ from torch.profiler import ProfilerActivity, profile, record_function from transformers import AutoTokenizer -from lib.utils.unsafe_import import model_from_hf_path +from quip_sharp.lib.utils.unsafe_import import model_from_hf_path torch.set_grad_enabled(False) diff --git a/lib/__init__.py b/src/quip_sharp/lib/__init__.py similarity index 100% rename from lib/__init__.py rename to src/quip_sharp/lib/__init__.py diff --git a/src/quip_sharp/lib/algo/__init__.py b/src/quip_sharp/lib/algo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/algo/finetune.py b/src/quip_sharp/lib/algo/finetune.py similarity index 99% rename from lib/algo/finetune.py rename to src/quip_sharp/lib/algo/finetune.py index e5bcb36..f98c864 100644 --- a/lib/algo/finetune.py +++ b/src/quip_sharp/lib/algo/finetune.py @@ -8,8 +8,8 @@ import torch from torch import nn -from lib import codebook, utils -from lib.linear import * +from quip_sharp.lib import codebook, utils +from quip_sharp.lib.linear import * from . import quip diff --git a/lib/algo/quip.py b/src/quip_sharp/lib/algo/quip.py similarity index 99% rename from lib/algo/quip.py rename to src/quip_sharp/lib/algo/quip.py index 7ceb211..9f1f0ee 100644 --- a/lib/algo/quip.py +++ b/src/quip_sharp/lib/algo/quip.py @@ -5,7 +5,7 @@ import torch from tqdm import tqdm -from lib import utils +from quip_sharp.lib import utils def RHT_H(H, SU): diff --git a/lib/codebook/__init__.py b/src/quip_sharp/lib/codebook/__init__.py similarity index 100% rename from lib/codebook/__init__.py rename to src/quip_sharp/lib/codebook/__init__.py diff --git a/lib/codebook/latticee8_padded12.py b/src/quip_sharp/lib/codebook/latticee8_padded12.py similarity index 96% rename from lib/codebook/latticee8_padded12.py rename to src/quip_sharp/lib/codebook/latticee8_padded12.py index a0720f1..78e1cca 100644 --- a/lib/codebook/latticee8_padded12.py +++ b/src/quip_sharp/lib/codebook/latticee8_padded12.py @@ -8,15 +8,11 @@ which makes 2^16 entries. 
This corresponds to a subset of E8 + 1/4 """ -import itertools -import math -from functools import cache - -import quiptools_cuda +import quiptools import torch from torch import nn -from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda +from quip_sharp.lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda _E8P_CODESZ = 8 @@ -220,7 +216,7 @@ def maybe_pack_idxs(self, idxs): def by_idxs(self, idxs, **kwargs): m, n = idxs.shape - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( idxs.view(m // 16, n // 2, 8, 4), self.grid_packed_abs) return W_decompressed @@ -240,7 +236,7 @@ def cache_WH(self, n, m, Qidxs_list, had_left, had_right, K_left, K_right, **kwargs): self.W = matmul_hadU_cuda( matmul_hadU_cuda( - quiptools_cuda.decompress_packed_e8p( + quiptools.decompress_packed_e8p( Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs).float() / self.scale, had_left, K_left).T, @@ -280,12 +276,12 @@ def forward(self, ABx = Bx @ A.t().to(torch.float32) if x.size(0) == 1: - x = quiptools_cuda.decode_matvec_e8p( + x = quiptools.decode_matvec_e8p( x[0].to(torch.float16), Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs).to(torch.float32) else: - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs) x = (x.to(torch.float16) @ W_decompressed.T).to(torch.float32) diff --git a/lib/codebook/latticee8_padded12_rvq3bit.py b/src/quip_sharp/lib/codebook/latticee8_padded12_rvq3bit.py similarity index 95% rename from lib/codebook/latticee8_padded12_rvq3bit.py rename to src/quip_sharp/lib/codebook/latticee8_padded12_rvq3bit.py index 0890ab5..1b3330f 100644 --- a/lib/codebook/latticee8_padded12_rvq3bit.py +++ b/src/quip_sharp/lib/codebook/latticee8_padded12_rvq3bit.py @@ -2,15 +2,11 @@ E8 3 bit. Made from 2 bit E8P + 1 bit E8 with RVQ. 
""" -import itertools -import math -from functools import cache - -import quiptools_cuda +import quiptools import torch from torch import nn -from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda +from quip_sharp.lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda _E8P_CODESZ = 8 @@ -296,7 +292,7 @@ def by_idxs(self, idxs, **kwargs): init_idxs = idxs[:, :split].contiguous() resid_idxs = idxs[:, split:].contiguous() m, n = init_idxs.shape - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( init_idxs.view(m // 16, n // 2, 8, 4), self.grid_packed_abs) W_resid_decompressed = torch.zeros( @@ -306,7 +302,7 @@ def by_idxs(self, idxs, **kwargs): dtype=torch.float16, ) - quiptools_cuda.decompress_e81b_packed(resid_idxs, + quiptools.decompress_e81b_packed(resid_idxs, self.e81b_grid.to(torch.float16), W_resid_decompressed) return W_decompressed + W_resid_decompressed / self.opt_resid_scale @@ -334,7 +330,7 @@ def cache_WH(self, K_right, resid_scale_override=-1, **kwargs): - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs) @@ -343,7 +339,7 @@ def cache_WH(self, device=Qidxs_list[1].device, dtype=torch.float16) - quiptools_cuda.decompress_e81b_packed(Qidxs_list[1], + quiptools.decompress_e81b_packed(Qidxs_list[1], self.codebook.e81b_grid, W_resid_decompressed) resid_scale = resid_scale_override if resid_scale_override > 0 else \ @@ -408,16 +404,16 @@ def forward(self, m, dtype=torch.float32, device=x_padded.device) - quiptools_cuda.lookupmatmul_e81b_k8(x_padded / resid_scale, + quiptools.lookupmatmul_e81b_k8(x_padded / resid_scale, Qidxs_list[1], self.codebook.e81b_grid, z) - x = quiptools_cuda.decode_matvec_e8p( + x = quiptools.decode_matvec_e8p( x16[0], Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs).to(torch.float32) + z[0] else: - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs) @@ -427,7 +423,7 @@ def forward(self, device=Qidxs_list[1].device, dtype=torch.float16) - quiptools_cuda.decompress_e81b_packed(Qidxs_list[1], + quiptools.decompress_e81b_packed(Qidxs_list[1], self.codebook.e81b_grid, W_resid_decompressed) diff --git a/lib/codebook/latticee8_padded12_rvq4bit.py b/src/quip_sharp/lib/codebook/latticee8_padded12_rvq4bit.py similarity index 94% rename from lib/codebook/latticee8_padded12_rvq4bit.py rename to src/quip_sharp/lib/codebook/latticee8_padded12_rvq4bit.py index 3711461..0fa1c07 100644 --- a/lib/codebook/latticee8_padded12_rvq4bit.py +++ b/src/quip_sharp/lib/codebook/latticee8_padded12_rvq4bit.py @@ -2,15 +2,11 @@ E8 4 bit. 2 2 bit E8P codebooks with RVQ. 
""" -import itertools -import math -from functools import cache - -import quiptools_cuda +import quiptools import torch from torch import nn -from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda +from quip_sharp.lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda _E8P_CODESZ = 8 @@ -251,9 +247,9 @@ def by_idxs(self, idxs, **kwargs): init_idxs = idxs[:, :idxs.shape[-1] // 2].contiguous() resid_idxs = idxs[:, idxs.shape[-1] // 2:].contiguous() m, n = init_idxs.shape - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( init_idxs.view(m // 16, n // 2, 8, 4), - self.grid_packed_abs) + quiptools_cuda.decompress_packed_e8p( + self.grid_packed_abs) + quiptools.decompress_packed_e8p( resid_idxs.view(m // 16, n // 2, 8, 4), self.grid_packed_abs) / self.opt_resid_scale return W_decompressed @@ -283,9 +279,9 @@ def cache_WH(self, **kwargs): resid_scale = resid_scale_override if resid_scale_override > 0 else \ self.codebook.opt_resid_scale - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook. - grid_packed_abs).float() + quiptools_cuda.decompress_packed_e8p( + grid_packed_abs).float() + quiptools.decompress_packed_e8p( Qidxs_list[1].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs).float() / resid_scale self.W = matmul_hadU_cuda( @@ -335,17 +331,17 @@ def forward(self, self.codebook.opt_resid_scale if x.size(0) == 1: x16 = x[0].to(torch.float16) - x = (quiptools_cuda.decode_matvec_e8p( + x = (quiptools.decode_matvec_e8p( x16, Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs) + - quiptools_cuda.decode_matvec_e8p( + quiptools.decode_matvec_e8p( x16 / resid_scale, Qidxs_list[1].view( m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs)).to(torch.float32) else: - W_decompressed = quiptools_cuda.decompress_packed_e8p( + W_decompressed = quiptools.decompress_packed_e8p( Qidxs_list[0].view(m // 16, n // 64, 8, 4), self.codebook. 
- grid_packed_abs) + quiptools_cuda.decompress_packed_e8p( + grid_packed_abs) + quiptools.decompress_packed_e8p( Qidxs_list[1].view(m // 16, n // 64, 8, 4), self.codebook.grid_packed_abs) / resid_scale x = (x.to(torch.float16) @ W_decompressed.T).to(torch.float32) diff --git a/lib/linear/__init__.py b/src/quip_sharp/lib/linear/__init__.py similarity index 100% rename from lib/linear/__init__.py rename to src/quip_sharp/lib/linear/__init__.py diff --git a/lib/linear/fused_linear.py b/src/quip_sharp/lib/linear/fused_linear.py similarity index 100% rename from lib/linear/fused_linear.py rename to src/quip_sharp/lib/linear/fused_linear.py diff --git a/lib/linear/fused_quantized_linear.py b/src/quip_sharp/lib/linear/fused_quantized_linear.py similarity index 85% rename from lib/linear/fused_quantized_linear.py rename to src/quip_sharp/lib/linear/fused_quantized_linear.py index f767187..dbf9e5b 100644 --- a/lib/linear/fused_quantized_linear.py +++ b/src/quip_sharp/lib/linear/fused_quantized_linear.py @@ -1,11 +1,4 @@ -import time - -import quiptools_cuda import torch -import torch.nn as nn - -from lib import codebook -from lib.utils import dtype_from_str, get_hadK from .quantized_linear import QuantizedLinear diff --git a/lib/linear/quantized_linear.py b/src/quip_sharp/lib/linear/quantized_linear.py similarity index 98% rename from lib/linear/quantized_linear.py rename to src/quip_sharp/lib/linear/quantized_linear.py index aee994f..e553df4 100644 --- a/lib/linear/quantized_linear.py +++ b/src/quip_sharp/lib/linear/quantized_linear.py @@ -1,11 +1,8 @@ -import time - -import quiptools_cuda import torch import torch.nn as nn -from lib import codebook -from lib.utils import clean, dtype_from_str, get_hadK +from quip_sharp.lib import codebook +from quip_sharp.lib.utils import clean, dtype_from_str, get_hadK class QuantizedLinear(nn.Module): diff --git a/lib/utils/__init__.py b/src/quip_sharp/lib/utils/__init__.py similarity index 100% rename from lib/utils/__init__.py rename to src/quip_sharp/lib/utils/__init__.py diff --git a/lib/utils/data_utils.py b/src/quip_sharp/lib/utils/data_utils.py similarity index 98% rename from lib/utils/data_utils.py rename to src/quip_sharp/lib/utils/data_utils.py index 02d1395..d27a777 100644 --- a/lib/utils/data_utils.py +++ b/src/quip_sharp/lib/utils/data_utils.py @@ -5,7 +5,7 @@ from datasets import load_dataset from torch.utils.data import DataLoader, Dataset -from lib import codebook +from quip_sharp.lib import codebook from .matmul_had import matmul_hadU @@ -57,7 +57,8 @@ def wrap_tokenizer(tokenizer, x, ctx_size): def sample_rp1t(tokenizer, size=128, ctx_size=2048, nproc=1): dataset = load_dataset('togethercomputer/RedPajama-Data-1T-Sample', - split='train') + split='train', + trust_remote_code=True) devset = torch.zeros((size, ctx_size), dtype=torch.int64) saved = 0 if nproc > 1: diff --git a/lib/utils/finetune.py b/src/quip_sharp/lib/utils/finetune.py similarity index 100% rename from lib/utils/finetune.py rename to src/quip_sharp/lib/utils/finetune.py diff --git a/lib/utils/gptq_data_utils.py b/src/quip_sharp/lib/utils/gptq_data_utils.py similarity index 100% rename from lib/utils/gptq_data_utils.py rename to src/quip_sharp/lib/utils/gptq_data_utils.py diff --git a/lib/utils/graph_wrapper.py b/src/quip_sharp/lib/utils/graph_wrapper.py similarity index 100% rename from lib/utils/graph_wrapper.py rename to src/quip_sharp/lib/utils/graph_wrapper.py diff --git a/lib/utils/lm_eval_adaptor.py b/src/quip_sharp/lib/utils/lm_eval_adaptor.py similarity index 98% rename 
from lib/utils/lm_eval_adaptor.py rename to src/quip_sharp/lib/utils/lm_eval_adaptor.py index 6edd12a..5e7d802 100644 --- a/lib/utils/lm_eval_adaptor.py +++ b/src/quip_sharp/lib/utils/lm_eval_adaptor.py @@ -1,13 +1,11 @@ -import fnmatch - import torch import transformers -from lm_eval.base import BaseLM +from lm_eval.api.model import LM # adapted from https://github.com/mit-han-lab/llm-awq/tree/main -class LMEvalAdaptor(BaseLM): +class LMEvalAdaptor(LM): def __init__(self, model_name, diff --git a/lib/utils/math_utils.py b/src/quip_sharp/lib/utils/math_utils.py similarity index 100% rename from lib/utils/math_utils.py rename to src/quip_sharp/lib/utils/math_utils.py diff --git a/lib/utils/matmul_had.py b/src/quip_sharp/lib/utils/matmul_had.py similarity index 99% rename from lib/utils/matmul_had.py rename to src/quip_sharp/lib/utils/matmul_had.py index 13c52a0..de0f0df 100644 --- a/lib/utils/matmul_had.py +++ b/src/quip_sharp/lib/utils/matmul_had.py @@ -1,7 +1,7 @@ import fast_hadamard_transform import torch -from lib import utils +from quip_sharp.lib import utils def get_hadK(n, transpose=False): diff --git a/lib/utils/matmul_kron.py b/src/quip_sharp/lib/utils/matmul_kron.py similarity index 100% rename from lib/utils/matmul_kron.py rename to src/quip_sharp/lib/utils/matmul_kron.py diff --git a/lib/utils/misc.py b/src/quip_sharp/lib/utils/misc.py similarity index 100% rename from lib/utils/misc.py rename to src/quip_sharp/lib/utils/misc.py diff --git a/lib/utils/model_version.py b/src/quip_sharp/lib/utils/model_version.py similarity index 100% rename from lib/utils/model_version.py rename to src/quip_sharp/lib/utils/model_version.py diff --git a/lib/utils/shard_model.py b/src/quip_sharp/lib/utils/shard_model.py similarity index 100% rename from lib/utils/shard_model.py rename to src/quip_sharp/lib/utils/shard_model.py diff --git a/lib/utils/unsafe_import.py b/src/quip_sharp/lib/utils/unsafe_import.py similarity index 91% rename from lib/utils/unsafe_import.py rename to src/quip_sharp/lib/utils/unsafe_import.py index 3b6449e..495c1b3 100644 --- a/lib/utils/unsafe_import.py +++ b/src/quip_sharp/lib/utils/unsafe_import.py @@ -5,7 +5,7 @@ import transformers -from model.llama import LlamaForCausalLM +from quip_sharp.model.llama import LlamaForCausalLM from . 
import graph_wrapper @@ -21,7 +21,7 @@ def maybe_wrap(use_cuda_graph): # AutoConfig fails to read name_or_path correctly bad_config = transformers.AutoConfig.from_pretrained(path) - is_quantized = hasattr(bad_config, 'quip_params') + is_quantized = hasattr(bad_config, 'quantization_config') model_type = bad_config.model_type if is_quantized: if model_type == 'llama': diff --git a/model/llama.py b/src/quip_sharp/model/llama.py similarity index 98% rename from model/llama.py rename to src/quip_sharp/model/llama.py index 135ff3d..d50211b 100644 --- a/model/llama.py +++ b/src/quip_sharp/model/llama.py @@ -61,9 +61,9 @@ _prepare_4d_causal_attention_mask = torch.fx.wrap( _prepare_4d_causal_attention_mask) -from lib.linear.fused_quantized_linear import FusedQuantizedLinear -from lib.linear.quantized_linear import QuantizedLinear -from lib.utils import check_model_version +from quip_sharp.lib.linear.fused_quantized_linear import FusedQuantizedLinear +from quip_sharp.lib.linear.quantized_linear import QuantizedLinear +from quip_sharp.lib.utils import check_model_version logger = logging.get_logger(__name__) @@ -116,7 +116,7 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, hidden_states): + def forward(self, hidden_states, **kwargs): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) @@ -158,14 +158,15 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", - emb.cos().to(dtype), - persistent=False) - self.register_buffer("sin_cached", - emb.sin().to(dtype), - persistent=False) - - def forward(self, x, seq_len=None): + if not hasattr(self, "cos_cached") and not hasattr(self, "sin_cached"): + self.register_buffer("cos_cached", + emb.cos().to(dtype), + persistent=False) + self.register_buffer("sin_cached", + emb.sin().to(dtype), + persistent=False) + + def forward(self, x, seq_len=None, **kwargs): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, @@ -287,6 +288,7 @@ class LlamaMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config + self.config.quip_params = self.config.quantization_config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.upgate_proj = FusedQuantizedLinear( @@ -325,7 +327,7 @@ def __init__(self, config): # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, x): + def forward(self, x, **kwargs): if self.config.pretraining_tp > 1: raise Exception # removed for quantization @@ -374,6 +376,7 @@ class LlamaAttention(nn.Module): def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.config.quip_params = self.config.quantization_config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( @@ -848,6 +851,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -1175,6 +1179,7 @@ def forward( output_attentions: 
Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states @@ -1354,6 +1359,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1548,6 +1554,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): diff --git a/model/mistral.py b/src/quip_sharp/model/mistral.py similarity index 99% rename from model/mistral.py rename to src/quip_sharp/model/mistral.py index b478b1b..cb1bde4 100644 --- a/model/mistral.py +++ b/src/quip_sharp/model/mistral.py @@ -52,9 +52,9 @@ _flash_supports_window_size = "window_size" in list( inspect.signature(flash_attn_func).parameters) -from lib.linear.fused_quantized_linear import FusedQuantizedLinear -from lib.linear.quantized_linear import QuantizedLinear -from lib.utils import check_model_version +from quip_sharp.lib.linear.fused_quantized_linear import FusedQuantizedLinear +from quip_sharp.lib.linear.quantized_linear import QuantizedLinear +from quip_sharp.lib.utils import check_model_version logger = logging.get_logger(__name__) @@ -188,6 +188,7 @@ class MistralMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config + self.config.quip_params = self.config.quantization_config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.upgate_proj = FusedQuantizedLinear( @@ -258,6 +259,7 @@ class MistralAttention(nn.Module): def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.config.quip_params = self.config.quantization_config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( diff --git a/quantize_llama/example.sh b/src/quip_sharp/quantize_llama/example.sh similarity index 100% rename from quantize_llama/example.sh rename to src/quip_sharp/quantize_llama/example.sh diff --git a/quantize_llama/finetune_e2e_llama.py b/src/quip_sharp/quantize_llama/finetune_e2e_llama.py similarity index 97% rename from quantize_llama/finetune_e2e_llama.py rename to src/quip_sharp/quantize_llama/finetune_e2e_llama.py index bb7573c..003a7bd 100644 --- a/quantize_llama/finetune_e2e_llama.py +++ b/src/quip_sharp/quantize_llama/finetune_e2e_llama.py @@ -19,9 +19,9 @@ from transformers.modeling_attn_mask_utils import \ _prepare_4d_causal_attention_mask -from lib import codebook, utils -from lib.algo import finetune, quip -from lib.utils.unsafe_import import model_from_hf_path +from quip_sharp.lib import codebook, utils +from quip_sharp.lib.algo import finetune, quip +from quip_sharp.lib.utils.unsafe_import import model_from_hf_path parser = argparse.ArgumentParser() parser.add_argument('--seed', default=0, type=int) diff --git a/quantize_llama/hessian_offline_llama.py b/src/quip_sharp/quantize_llama/hessian_offline_llama.py similarity index 99% rename from quantize_llama/hessian_offline_llama.py rename to src/quip_sharp/quantize_llama/hessian_offline_llama.py index 24cdf79..d49f84c 100644 --- 
a/quantize_llama/hessian_offline_llama.py
+++ b/src/quip_sharp/quantize_llama/hessian_offline_llama.py
@@ -16,7 +16,7 @@
 from transformers.modeling_attn_mask_utils import \
     _prepare_4d_causal_attention_mask
 
-from lib import utils
+from quip_sharp.lib import utils
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--seed', default=0, type=int)
diff --git a/quantize_llama/hfize_llama.py b/src/quip_sharp/quantize_llama/hfize_llama.py
similarity index 94%
rename from quantize_llama/hfize_llama.py
rename to src/quip_sharp/quantize_llama/hfize_llama.py
index 3aa5e03..40467a2 100644
--- a/quantize_llama/hfize_llama.py
+++ b/src/quip_sharp/quantize_llama/hfize_llama.py
@@ -6,10 +6,10 @@
 import torch
 from transformers import AutoTokenizer
 
-from lib import codebook, utils
-from lib.utils.unsafe_import import model_from_hf_path
-from model.llama import LlamaForCausalLM
-from lib.utils.model_version import MODEL_VERSION
+from quip_sharp.lib import codebook, utils
+from quip_sharp.lib.utils.unsafe_import import model_from_hf_path
+from quip_sharp.model.llama import LlamaForCausalLM
+from quip_sharp.lib.utils.model_version import MODEL_VERSION
 
 torch.set_grad_enabled(False)
diff --git a/quantize_llama/quantize_finetune_llama.py b/src/quip_sharp/quantize_llama/quantize_finetune_llama.py
similarity index 96%
rename from quantize_llama/quantize_finetune_llama.py
rename to src/quip_sharp/quantize_llama/quantize_finetune_llama.py
index af33222..d027fd2 100644
--- a/quantize_llama/quantize_finetune_llama.py
+++ b/src/quip_sharp/quantize_llama/quantize_finetune_llama.py
@@ -12,10 +12,10 @@
 from transformers.modeling_attn_mask_utils import \
     _prepare_4d_causal_attention_mask
 
-from lib import codebook, utils
-from lib.algo import finetune, quip
-from lib.linear import FusedLinear
-from model.llama import LlamaDecoderLayer
+from quip_sharp.lib import codebook, utils
+from quip_sharp.lib.algo import finetune, quip
+from quip_sharp.lib.linear import FusedLinear
+from quip_sharp.model.llama import LlamaDecoderLayer
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--seed', default=0, type=int)
@@ -139,6 +139,7 @@ def main(args):
     # save configs
     all_config = {'quant_args': args, 'model_config': model.config}
     quip_params = {
+        'quant_method': 'quip-sharp',
        'lora_rank': args.lora_rank,
        'rescale_WH': args.rescale_WH,
        'codebook': args.codebook,
@@ -148,7 +149,7 @@ def main(args):
        'packsz': cb.packsz,
        'resid_scale_override': args.resid_scale_override,
     }
-    all_config['model_config'].update({'quip_params': quip_params})
+    all_config['model_config'].update({'quantization_config': quip_params})
     torch.save(all_config, os.path.join(args.save_path, 'config.pt'))
     tokenizer = AutoTokenizer.from_pretrained(args.base_model)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/llama2-7b-2bit.sh b/tests/llama2-7b-2bit.sh
new file mode 100755
index 0000000..d04b2fe
--- /dev/null
+++ b/tests/llama2-7b-2bit.sh
@@ -0,0 +1,77 @@
+#!/bin/bash
+set -euxo pipefail
+
+
+CKPT=/share/desa/nfs01/qs234/quip-sharp/checkpoints
+HF=/share/desa/nfs01/qs234/quip-sharp/hfized
+HESS=/share/desa/nfs01/qs234/huggingface/hub/models--relaxml--Hessians-Llama-2-7b-6144/snapshots/cafc59c036c6416ec2a9d5790752bec51297c197/
+LOG=/share/desa/nfs01/qs234/quip-sharp/logs
+
+
+mkdir -p $CKPT
+mkdir -p $HF
+mkdir -p $LOG
+
+
+# quantize with finetuning
+python3 \
+    -m quip_sharp.quantize_llama.quantize_finetune_llama \
+    --save_path $CKPT/2_7b_2bit \
+    --codebook E8P12 \
+    --scale_override 0.9 \
+    --base_model meta-llama/Llama-2-7b-hf \
+    --hessian_path $HESS \
+    --devset_size 384 \
+    --ft_valid_size 128 \
+    --ft_epochs 8 \
+    2>&1 \
+    | tee -a $LOG/2_7b_2bit
+
+
+# convert model to hf format for end to end fine tuning
+CUDA_VISIBLE_DEVICES=0 python3 \
+    -m quip_sharp.quantize_llama.hfize_llama \
+    --quantized_path $CKPT/2_7b_2bit \
+    --hf_output_path $HF/2_7b_2bit \
+    2>&1 \
+    | tee -a $LOG/2_7b_2bit
+
+
+# end to end fine tuning
+# python3 \
+#     -m quip_sharp.quantize_llama.finetune_e2e_llama \
+#     --base_model meta-llama/Llama-2-7b-hf \
+#     --hf_path $HF/2_7b_2bit \
+#     --devset_size 384 \
+#     --ft_valid_size 128 \
+#     --ft_epochs 8 \
+#     --ft_bs 1 \
+#     --ctx_size 4096 \
+#     --ft_update_freq 2 \
+#     --ft_train_mode \
+#     --ckpt_path $CKPT/2_7b_2bit \
+#     2>&1 \
+#     | tee -a $LOG/2_7b_2bit
+
+
+# eval
+CUDA_VISIBLE_DEVICES=0 python3 \
+    -m quip_sharp.quantize_llama.hfize_llama \
+    --quantized_path $CKPT/2_7b_2bit \
+    --hf_output_path $HF/2_7b_2bit \
+    2>&1 \
+    | tee -a $LOG/2_7b_2bit
+
+CUDA_VISIBLE_DEVICES=0 python3 \
+    -m quip_sharp.eval.eval_ppl \
+    --hf_path $HF/2_7b_2bit \
+    2>&1 \
+    | tee -a $LOG/2_7b_2bit
+
+CUDA_VISIBLE_DEVICES=0 python3 \
+    -m quip_sharp.eval.eval_zeroshot \
+    --tasks arc_challenge,arc_easy,boolq,piqa,winogrande \
+    --batch_size 4 \
+    --hf_path $HF/2_7b_2bit \
+    2>&1 \
+    | tee -a $LOG/2_7b_2bit
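
Note (editorial, not part of the patch): after this restructuring the repository is imported as the quip_sharp package (src/ layout) and the CUDA kernels come from the pip-installed quiptools-cuda requirement (imported as quiptools) instead of a locally built quiptools_cuda extension. Below is a minimal sketch of the import migration, assuming the package itself has been installed into the environment (for example via an editable install); it only restates import paths that appear in the diff above.

# old, flat-repo style (before this patch):
#   import quiptools_cuda
#   from lib.utils.unsafe_import import model_from_hf_path
# new, package style (after this patch):
import quiptools  # kernels provided by the quiptools-cuda requirement
from quip_sharp.lib.utils.unsafe_import import model_from_hf_path
from quip_sharp.lib.linear.quantized_linear import QuantizedLinear

Entry points are likewise run as modules, e.g. python3 -m quip_sharp.eval.eval_ppl --hf_path <quantized model>, mirroring tests/llama2-7b-2bit.sh above.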