diff --git a/mindnlp/quant/mindbnb/CMakeLists.txt b/mindnlp/quant/mindbnb/CMakeLists.txt new file mode 100644 index 000000000..da87ce0f3 --- /dev/null +++ b/mindnlp/quant/mindbnb/CMakeLists.txt @@ -0,0 +1,241 @@ +# This CMake config hopefully makes it easier to compile. +# Ensure the CUDA Toolkit is available on your path. Then run: +# For GCC: `cmake -B build . && cmake --build build` +# For MSVC: `cmake -B build . && cmake --build build --config Release` +# You can also use the following options and variables +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support +# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version +# is whatever CMake finds on your path. +# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. +# Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90` +# Check your compute capability here: https://developer.nvidia.com/cuda-gpus +# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler +cmake_minimum_required(VERSION 3.22.1) + +project(bitsandbytes LANGUAGES CXX) + +# If run without specifying a build type, default to using the Release configuration: +# optimizing the generated binaries for performance and also adds the `-DNDEBUG` flag, +# which turns off a bunch of asserts which seem to link to new symbols in libstdc++, +# worsening our many_linux compliance.. +# if(NOT CMAKE_BUILD_TYPE) +# set(CMAKE_BUILD_TYPE Release) +# endif() +# Set the build type to Debug to include debug information +set(CMAKE_BUILD_TYPE Debug CACHE STRING "Build type" FORCE) +set(CMAKE_BUILD_TYPE Debug) + +# Define included source files +set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) +set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +set(MPS_FILES csrc/mps_ops.mm) +set(METAL_FILES csrc/mps_kernels.metal) +# C++ sources are always included +list(APPEND SRC_FILES ${CPP_FILES}) + +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) +option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) + +if(APPLE) + set(CMAKE_OSX_DEPLOYMENT_TARGET 13.1) +endif() + +set(BNB_OUTPUT_NAME "bitsandbytes") + +message(STATUS "Configuring ${PROJECT_NAME} (Backend: ${COMPUTE_BACKEND})") + +if(${COMPUTE_BACKEND} STREQUAL "cuda") + if(APPLE) + message(FATAL_ERROR "CUDA is not supported on macOS" ) + endif() + option(NO_CUBLASLT "Disable CUBLAS" OFF) + set(BUILD_CUDA ON) + set(BUILD_MPS OFF) + message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}") +elseif(${COMPUTE_BACKEND} STREQUAL "mps") + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_MPS ON) +else() + set(BUILD_CUDA OFF) + set(BUILD_MPS OFF) +endif() + + +if(BUILD_CUDA) + # NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+. + # Workaround: use --allow-unsupported-compiler + # This needs to be added *before* we try to enable the CUDA language so CMake's compiler check passes. + if(MSVC AND MSVC_VERSION VERSION_GREATER_EQUAL 1940) + string(APPEND CMAKE_CUDA_FLAGS " --allow-unsupported-compiler") + endif() + + enable_language(CUDA) # This will fail if CUDA is not found + find_package(CUDAToolkit REQUIRED) + + # Convert the CUDA version from X.Y.z to XY. 
There's probably a shorter way of doing this + string(REGEX MATCH "^[0-9]+.[0-9]+" _CUDA_VERSION_FIRST_TWO "${CMAKE_CUDA_COMPILER_VERSION}") + string(REPLACE "." "" CUDA_VERSION_SHORT "${_CUDA_VERSION_FIRST_TWO}") + + # Expose a cache variable that the user can set to ensure the correct version of CUDA is found + set(CUDA_VERSION "${CUDA_VERSION_SHORT}" CACHE STRING "Expected CUDA Version Shortcode") + + message(STATUS "CUDA Version: ${CUDA_VERSION_SHORT} (${CMAKE_CUDA_COMPILER_VERSION})") + message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") + + # It should match the discovered version + if(NOT CUDA_VERSION STREQUAL "${CUDA_VERSION_SHORT}") + message(FATAL_ERROR "You've specified CUDA version ${CUDA_VERSION} however the CUDA compiler found is ${CUDA_VERSION_SHORT}." + " Ensure the desired CUDA compiler is the first one available on your PATH." + ) + endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "11.0") + message(FATAL_ERROR "CUDA Version < 11 is not supported") + elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") + message(FATAL_ERROR "CUDA Version > 12 is not supported") + endif() + + # CMake < 3.23.0 does not define CMAKE_CUDA_ARCHITECTURES_ALL. + if(CMAKE_VERSION VERSION_LESS "3.23.0") + message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...") + + # 11.x and 12.x both support these at a minimum. + set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80) + set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80) + + # CUDA 11.1 adds Ampere support for GA102-GA107. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.1") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 86) + endif() + + # CUDA 11.4 adds Ampere support for GA10B. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.4") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 87) + endif() + + # CUDA 11.8 adds support for Ada and Hopper. + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.8") + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL 89 90) + list(APPEND CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 90) + endif() + endif() + + string(APPEND CMAKE_CUDA_FLAGS " --use_fast_math") + + if(PTXAS_VERBOSE) + # Verbose? Outputs register usage information, and other things... + string(APPEND CMAKE_CUDA_FLAGS " -Xptxas=-v") + endif() + + foreach(capability ${CMAKE_CUDA_ARCHITECTURES_ALL}) + # Most of the items here are like: `xx-real`, so we just extract the `xx` portion + string(REGEX MATCH "[0-9]+" capability_id "${capability}") + if(capability_id GREATER 0) + list(APPEND POSSIBLE_CAPABILITIES ${capability_id}) + endif() + endforeach() + + # This can be changed via -D argument to CMake + # By default all possible capabilities are compiled + set(COMPUTE_CAPABILITY "${POSSIBLE_CAPABILITIES}" CACHE STRING "Compute Capabilities Targeted") + + message(STATUS "CUDA Capabilities Available: ${POSSIBLE_CAPABILITIES}") + message(STATUS "CUDA Capabilities Selected: ${COMPUTE_CAPABILITY}") + + # Use the "real" option to build native cubin for all selections. + # Ensure we build the PTX for the latest version. + # This behavior of adding a PTX (virtual) target for the highest architecture + # is similar to how the "all" and "all-major" options would behave in CMake >= 3.23. 
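    # As an illustration (example values, not defaults): configuring with
    #   -DCOMPUTE_CAPABILITY=75;86;89
    # is expected to pass through the logic below as
    #   CMAKE_CUDA_ARCHITECTURES = 75-real;86-real;89
    # i.e. a native cubin for each selected capability, with the newest entry
    # also carrying a PTX (virtual) target for forward compatibility.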
+ # TODO: Consider bumping CMake requirement and using CMAKE_CUDA_ARCHITECTURES=[all | native] by default + list(REMOVE_DUPLICATES COMPUTE_CAPABILITY) + list(SORT COMPUTE_CAPABILITY COMPARE NATURAL) + list(POP_BACK COMPUTE_CAPABILITY _LATEST_CAPABILITY) + list(TRANSFORM COMPUTE_CAPABILITY APPEND "-real" OUTPUT_VARIABLE CMAKE_CUDA_ARCHITECTURES) + list(APPEND CMAKE_CUDA_ARCHITECTURES ${_LATEST_CAPABILITY}) + + message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}") + + list(APPEND SRC_FILES ${CUDA_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") + if(NO_CUBLASLT) + string(APPEND BNB_OUTPUT_NAME "_nocublaslt") + endif() + add_compile_definitions(BUILD_CUDA) +elseif(BUILD_MPS) + if(NOT APPLE) + message(FATAL_ERROR "MPS is only supported on macOS" ) + endif() + + enable_language(OBJCXX) + + list(APPEND SRC_FILES ${MPS_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_mps") + add_compile_definitions(BUILD_MPS) + file(MAKE_DIRECTORY "build") + add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib" + COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES} + COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib" + DEPENDS "${METAL_FILES}" + COMMENT "Compiling Metal kernels" + VERBATIM) + add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") +else() + string(APPEND BNB_OUTPUT_NAME "_cpu") + set(GPU_SOURCES) +endif() + + +if(WIN32) + # Export all symbols + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) +endif() + +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast") +endif() + +set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX) +add_library(bitsandbytes SHARED ${SRC_FILES}) +target_compile_features(bitsandbytes PUBLIC cxx_std_14) +target_include_directories(bitsandbytes PUBLIC csrc include) + + +if(BUILD_CUDA) + target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) + if(NO_CUBLASLT) + target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT) + else() + target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt) + endif() + + set_target_properties(bitsandbytes + PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + ) +endif() +if(BUILD_MPS) + add_dependencies(bitsandbytes metallib) + target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") +endif() + +if(WIN32) + set_target_properties(bitsandbytes PROPERTIES PREFIX "lib") +endif() +set_target_properties(bitsandbytes PROPERTIES OUTPUT_NAME ${BNB_OUTPUT_NAME}) +if(MSVC) + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_RELEASE "${PROJECT_SOURCE_DIR}/bitsandbytes") + set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes") +endif() + +set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bitsandbytes") diff --git a/mindnlp/quant/mindbnb/README.md b/mindnlp/quant/mindbnb/README.md new file mode 100644 index 000000000..0f462595c --- /dev/null +++ b/mindnlp/quant/mindbnb/README.md @@ -0,0 +1,7 @@ +# MindBNB + 
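MindBNB ports the bitsandbytes 8-bit quantization kernels to MindSpore for use in MindNLP. A minimal usage sketch, assuming the native library has been built (see Setup below) and a supported CUDA GPU is available; the shapes and the threshold value are illustrative only:

```python
import mindspore
from mindspore import ops
from bitsandbytes import matmul, MatmulLtState

# LLM.int8()-style mixed-precision matmul: computes A @ B.T with 8-bit weights
A = ops.randn(4, 64).astype(mindspore.float16)    # activations
B = ops.randn(128, 64).astype(mindspore.float16)  # fp16 weights (out_features x in_features)
state = MatmulLtState()
out = matmul(A, B, state=state, threshold=6.0)    # threshold > 0 enables outlier decomposition
print(out.shape)  # (4, 128)
```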
+quantization for mindnlp + +## Setup + +bash /path/to/mindbnb/scripts/build.sh diff --git a/mindnlp/quant/mindbnb/__init__.py b/mindnlp/quant/mindbnb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/quant/mindbnb/bitsandbytes/__init__.py b/mindnlp/quant/mindbnb/bitsandbytes/__init__.py new file mode 100644 index 000000000..a72a26200 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' + mainly from bitsandbytes repo +''' +from . import utils +from .autograd._functions import ( + MatmulLtState, + matmul, +) +from .nn import modules + +__pdoc__ = { + "libbitsandbytes": False, +} diff --git a/mindnlp/quant/mindbnb/bitsandbytes/autograd/__init__.py b/mindnlp/quant/mindbnb/bitsandbytes/autograd/__init__.py new file mode 100644 index 000000000..7620b91b6 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/autograd/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +autograd +''' +from ._functions import get_inverse_transform_indices, undo_layout diff --git a/mindnlp/quant/mindbnb/bitsandbytes/autograd/_functions.py b/mindnlp/quant/mindbnb/bitsandbytes/autograd/_functions.py new file mode 100644 index 000000000..36a0e0d25 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/autograd/_functions.py @@ -0,0 +1,422 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" + _functions +""" +# pylint: disable=E0401, E0611 +from dataclasses import dataclass +from functools import reduce # Required in Python 3 +import operator +from typing import Callable, Optional, Tuple +import warnings +import mindspore +from mindspore import ops, Tensor, nn +import bitsandbytes.functional as F +from mindspore._c_expression import ( + Tensor as CTensor, +) # pylint: disable=no-name-in-module, import-error + + +def empty(*size, dtype=None): + if isinstance(size[0], (tuple, list)): + size = size[0] + out = CTensor(dtype, size) + return mindspore.Tensor(out) + + +# math.prod not compatible with python < 3.8 +def prod(iterable): + return reduce(operator.mul, iterable, 1) + + +def clone(tensor): + return tensor.copy() + + +# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: +# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py + + +""" + This class pools outlier dimensions across layers. + This is particularly important for small models where outlier features + are less systematic and occur with low frequency. +""" + + +class GlobalOutlierPooler: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.outliers = set() + self.model_dim = None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def add_outliers(self, outlier_idx, feature_dim): + if self.model_dim is None: + self.model_dim = feature_dim + if feature_dim != self.model_dim: + return # we do not encode outliers for the 2nd FFN layer + + self.outliers.update(outlier_idx.tolist()) + + def get_current_outlier_idx(self): + return mindspore.Tensor(list(self.outliers)).to(mindspore.int64) + + +def get_inverse_transform_indices( + transform_tile: Callable[[mindspore.Tensor], mindspore.Tensor], + tile_size: Tuple[int, int], +): + """ + Compute a permutation of indices that invert the specified (tiled) matrix transformation + + :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2] + :param tile_size: higher-level tile dimensions, i.e. 
(8, 32) for Turing and (32, 32) for Ampere + :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size + :example: transform_tile function for the turing layout (bitsandbytes.functional as F) + :returns: indices + """ + d1, d2 = tile_size + assert 0 < d1 * d2 < 2**64 + tile_indices = ops.arange(d1 * d2, dtype=mindspore.int64).view(d1, d2) + # encode each position in tile as a tuple of <= 8 unique bytes + permuted_tile_indices = ops.zeros_like(tile_indices) + for i in range(8): + # select i-th byte, apply transformation and trace where each index ended up + ith_dim_indices = ops.div(tile_indices, 256**i, rounding_mode="trunc") % 256 + sample_tile_i = (ith_dim_indices - 128).to(mindspore.int8).contiguous() + assert ops.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow" + permuted_tile_i = transform_tile(sample_tile_i) + ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128 + permuted_tile_indices += ith_permuted_indices * (256**i) + if d1 * d2 < 256**i: + break # if all indices fit in i bytes, stop early + return permuted_tile_indices + + +def undo_layout( + permuted_tensor: mindspore.Tensor, tile_indices: mindspore.Tensor +) -> mindspore.Tensor: + """ + Undo a tiled permutation such as turing or ampere layout + + :param permuted_tensor: mindspore tensor in a permuted layout + :param tile_indices: reverse transformation indices, from get_inverse_transform_indices + :return: contiguous row-major tensor + """ + (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape + assert ( + rows % tile_rows == cols % tile_cols == 0 + ), "tensor must contain a whole number of tiles" + tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() + outputs = Tensor( + shape=tensor.shape, dtype=tensor.dtype + ) # note: not using .index_copy because it was slower on cuda + outputs[tile_indices.flatten()] = tensor + outputs = outputs.reshape( + tile_rows, tile_cols, cols // tile_cols, rows // tile_rows + ) + outputs = outputs.permute( + 3, 0, 2, 1 + ) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols) + return outputs.reshape(rows, cols).contiguous() + + +class MatMul8bit: + @staticmethod + def construct(ctx, A, B, out=None, quant_type="vector", precision=None): + if precision is None: + precision = [8, 8, 8] + if precision[0] != 8: + output = ops.matmul(A, B) + else: + if len(B.shape) == 2: + dim = 0 + else: + dim = 1 + qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type) + qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type) + iout = F.igemm(qA, qB) + output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type) + + if A.requires_grad or B.requires_grad: + ctx.save_for_backward(A, B) + + ctx.quant_type = quant_type + ctx.precision = precision + + return output + + +def supports_igemmlt() -> bool: + """检查当前设备是否支持优化的 int8 内核""" + device_name = F.GPU_NAME + if device_name not in F.gpus_compute_capability_over_7_5: + return False + else: + nvidia16_models = ( + "NVIDIA GeForce GTX 1630", + "NVIDIA GeForce GTX 1650", + "NVIDIA GeForce GTX 1660", + ) # https://en.wikipedia.org/wiki/GeForce_16_series + if any(model_name in device_name for model_name in nvidia16_models): + return False # 这些设备在技术上是 cuda 7.5 兼容的,但缺少张量核心 + + return True + + +def _get_tile_size(format): + assert format in ( + "col_turing", + "col_ampere", + ), f"please find this assert and manually enter tile size for {format}" + return (8, 32) if format == "col_turing" else (32, 32) + + +def get_tile_inds(format, device): + def 
transform(x): + return F.transform(x, from_order="row", to_order=format)[0].to(x.device) + + return get_inverse_transform_indices(transform, _get_tile_size(format)) + + +@dataclass +class MatmulLtState: + + _tile_indices: Optional[mindspore.Tensor] = None + force_no_igemmlt: bool = False + CB = None + CxB = None + SB = None + SCB = None + + CxBt = None + SBt = None + CBt = None + + subB = None + + outlier_pool = None + has_accumulated_gradients = False + threshold = 0.0 + idx = None + is_training = True + has_fp16_weights = True + memory_efficient_backward = False + use_pool = False + formatB = F.get_special_format_str() + + def reset_grads( + self, + ): + self.CB = None + self.CxB = None + self.SB = None + self.SCB = None + + self.CxBt = None + self.SBt = None + self.CBt = None + + @property + def tile_indices(self): + if self._tile_indices is None: + self._tile_indices = get_tile_inds(self.formatB, self.CxB.device) + return self._tile_indices + + +class MatMul8bitLt(nn.Cell): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + def __init__( + self, + ): + super().__init__() + self.needs_input_grad = [False, False, False, False, False] + + def construct(self, A, B, out=None, bias=None, state=MatmulLtState): + using_igemmlt = supports_igemmlt() and not state.force_no_igemmlt + # default of pymindspore behavior if inputs are empty + self.is_empty = False + if prod(A.shape) == 0: + self.is_empty = True + self.A = A + self.B = B + self.bias = bias + if A.shape[-1] == B.shape[0]: + return empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype) + else: + return empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype) + + # 1. Quantize A + # 2. Quantize B + # 3. Matmul + # 4. Mixed-precision decomposition matmul + # 5. Save state + formatB = state.formatB + input_shape = A.shape + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + + # Cast A to fp16 + if A.dtype != mindspore.float16: + warnings.warn( + f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization" + ) + # 1. Quantize A + if len(A.shape) == 3: + A = A.reshape(-1, A.shape[-1]) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A.astype(mindspore.float16), threshold=state.threshold + ) + + if state.threshold > 0.0 and coo_tensorA is not None: + if state.has_fp16_weights: + _, idx = ops.unique(coo_tensorA.colidx) + idx.astype(mindspore.int64) + CA[:, idx] = 0 + CAt[:, idx] = 0 + subA = A[:, idx] + state.subB = B[:, idx].t() + state.idx = idx + else: + if state.CxB is None and using_igemmlt: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + else: + if not state.has_fp16_weights and state.CxB is None and using_igemmlt: + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + subA = None + # 2. 
Quantize B + if state.has_fp16_weights: + has_grad = getattr(B, "grad", None) is not None + if (state.is_training and not has_grad) or state.CxB is None: + state.reset_grads() + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B.to(mindspore.float16)) + if using_igemmlt: + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + state.CB = CB + else: + has_grad = False + + if coo_tensorA is not None and not state.has_fp16_weights: + # extract outliers + + outlier_idx, _ = ops.unique(coo_tensorA.colidx) + state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx + if state.CxB is not None: + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + else: + outliers = state.CB[:, state.idx.long()].clone() + + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().to(A.dtype) + CA[:, state.idx.long()] = 0 + CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] + + shapeB = state.SB[0] if state.SB else B.shape + + if len(input_shape) == 3: + output_shape = (input_shape[0], input_shape[1], shapeB[0]) + else: + output_shape = (input_shape[0], shapeB[0]) + # 3. Matmul + if using_igemmlt: + C32A, SA = F.transform(CA, "col32") + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + if bias is None or bias.dtype == mindspore.float16: + # we apply the fused bias here + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype) + bias + + else: + A_wo_outliers = A.copy() + if state.idx is not None: + A_wo_outliers[:, state.idx.long()] = 0 + output = ops.dense(A_wo_outliers, state.CB.to(A.dtype)) + scb = state.SCB.unsqueeze(0) + scb = scb * (1.0 / 127.0) + output = output * scb + if bias is not None: + output = output + bias + # 4. Mixed-precision decomposition matmul + if coo_tensorA is not None and subA is not None: + output += ops.matmul(subA, state.subB) + # 5. 
Save state + self.state = state + + self.formatB = formatB + self.grad_shape = input_shape + self.dtype_A, self.dtype_B, self.dtype_bias = ( + A.dtype, + B.dtype, + None if bias is None else bias.dtype, + ) + + if any(self.needs_input_grad[:2]): + self.tensors = (CAt, subA, A) + self.tensor_states = (SCAt, state.idx) + else: + self.tensors = [None, None, A] + self.tensor_states = (None, None) + + clone_func = clone if len(output_shape) == 3 else lambda x: x + + return clone_func(output.view(output_shape)) + + +matmul8bitlt = MatMul8bitLt() + + +def matmul( + A: mindspore.Tensor, + B: mindspore.Tensor, + out: Optional[mindspore.Tensor] = None, + state: Optional[MatmulLtState] = None, + threshold=0.0, + bias=None, +): + state = state or MatmulLtState() + if threshold > 0.0: + state.threshold = threshold + # return MatMul8bitLt(A, B, out, bias, state) + return matmul8bitlt(A, B, out, bias, state) diff --git a/mindnlp/quant/mindbnb/bitsandbytes/bnbop.py b/mindnlp/quant/mindbnb/bitsandbytes/bnbop.py new file mode 100644 index 000000000..e586105df --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/bnbop.py @@ -0,0 +1,130 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' + custom ops +''' +# pylint: disable=E0401 +import mindspore +from mindspore import ops + +# from mindspore.ops import custom_info_register, CustomRegOp, DataType + +from bitsandbytes.lib import lib_path + + +cget_col_row_stats = ops.Custom( + f"{lib_path}:custom_cget_col_row_stats", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +cdouble_rowcol_quant = ops.Custom( + f"{lib_path}:custom_cdouble_rowcol_quant", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +ctransform_row2col32T = ops.Custom( + f"{lib_path}:custom_ctransform_row2col32T", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +ctransform_row2col32 = ops.Custom( + f"{lib_path}:custom_ctransform_row2col32", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +ctransform_row2turingT = ops.Custom( + f"{lib_path}:custom_ctransform_row2turingT", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +ctransform_row2turing = ops.Custom( + f"{lib_path}:custom_ctransform_row2turing", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +ctransform_row2ampereT = ops.Custom( + f"{lib_path}:custom_ctransform_row2ampereT", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +ctransform_row2ampere = ops.Custom( + f"{lib_path}:custom_ctransform_row2ampere", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +cextractOutliers_turing = ops.Custom( + f"{lib_path}:custom_cextractOutliers_turing", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +cextractOutliers_ampere = ops.Custom( + f"{lib_path}:custom_cextractOutliers_ampere", + out_shape=([1]), + 
out_dtype=mindspore.int32, + func_type="aot", +) + +cigemmlt_turing_32 = ops.Custom( + f"{lib_path}:custom_cigemmlt_turing_32", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +cigemmlt_turing_8 = ops.Custom( + f"{lib_path}:custom_cigemmlt_turing_64", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +cigemmlt_ampere_32 = ops.Custom( + f"{lib_path}:custom_cigemmlt_ampere_32", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +cigemmlt_ampere_8 = ops.Custom( + f"{lib_path}:custom_cigemmlt_ampere_64", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) + +cdequant_mm_int32_fp16 = ops.Custom( + f"{lib_path}:custom_cdequant_mm_int32_fp16", + out_shape=([1]), + out_dtype=mindspore.int32, + func_type="aot", +) diff --git a/mindnlp/quant/mindbnb/bitsandbytes/consts.py b/mindnlp/quant/mindbnb/bitsandbytes/consts.py new file mode 100644 index 000000000..da199ecf4 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/consts.py @@ -0,0 +1,28 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +path info +''' +from pathlib import Path +import platform + +DYNAMIC_LIBRARY_SUFFIX = { + "Darwin": ".dylib", + "Linux": ".so", + "Windows": ".dll", +}.get(platform.system(), ".so") + +PACKAGE_DIR = Path(__file__).parent +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" diff --git a/mindnlp/quant/mindbnb/bitsandbytes/cuda_specs.py b/mindnlp/quant/mindbnb/bitsandbytes/cuda_specs.py new file mode 100644 index 000000000..7d33219e8 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/cuda_specs.py @@ -0,0 +1,52 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +''' + cuda_specs +''' +import dataclasses +from typing import Optional, Tuple +import subprocess +import re +from mindspore import context + + +@dataclasses.dataclass(frozen=True) +class CUDASpecs: + cuda_version_string: str + cuda_version_tuple: Tuple[int, int] + + +def get_cuda_version_tuple() -> Tuple[int, int]: + result = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE, text=True, check=True) + match = re.search(r"V(\d+)\.(\d+)", result.stdout) + if match: + major, minor = map(int, match.groups()) + return major, minor + return 0, 0 + + +def get_cuda_version_string() -> str: + major, minor = get_cuda_version_tuple() + return f"{major}{minor}" + + +def get_cuda_specs() -> Optional[CUDASpecs]: + if not context.get_context("device_target") == "GPU": + return None + + return CUDASpecs( + cuda_version_string=(get_cuda_version_string()), + cuda_version_tuple=get_cuda_version_tuple(), + ) diff --git a/mindnlp/quant/mindbnb/bitsandbytes/functional.py b/mindnlp/quant/mindbnb/bitsandbytes/functional.py new file mode 100644 index 000000000..958b87534 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/functional.py @@ -0,0 +1,1304 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" + mindbnb functions +""" +# pylint: disable=E0611, E0401 +from functools import reduce # Required in Python 3 +import itertools +import operator +import subprocess +from typing import Any, Dict +import ctypes as ct +import numpy as np +import mindspore +from mindspore import Tensor +from mindspore import ops +from mindspore import context +from mindspore._c_expression import ( + Tensor as CTensor, +) # pylint: disable=no-name-in-module, import-error + +from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict + +from bitsandbytes import bnbop + +gpus_compute_capability_over_7_5 = [ + # Compute Capability 7.5 + "Tesla T4", + "Quadro T1000", + "Quadro T2000", + "Quadro RTX 3000", + "Quadro RTX 4000", + "Quadro RTX 5000", + "Quadro RTX 6000", + "Quadro RTX 8000", + "NVIDIA GeForce GTX 1650", + "NVIDIA GeForce GTX 1660", + "NVIDIA GeForce GTX 1660 Ti", + # Compute Capability 8.0 + "NVIDIA A100-SXM4-40GB", + "NVIDIA A100-SXM4-80GB", + "NVIDIA A100-PCIe-40GB", + "NVIDIA A100-PCIe-80GB", + "NVIDIA GeForce RTX 3070", + "NVIDIA GeForce RTX 3080", + "NVIDIA GeForce RTX 3090", + "NVIDIA GeForce RTX 3080 Ti", + "NVIDIA GeForce RTX 3090 Ti", + "NVIDIA RTX A40", + "NVIDIA RTX A10", + # Compute Capability 8.6 + "NVIDIA GeForce RTX 3050", + "NVIDIA GeForce RTX 3060", + "NVIDIA GeForce RTX 3060 Ti", + "NVIDIA GeForce RTX 3070 Ti", + # Compute Capability 8.7 + "NVIDIA GeForce RTX 4080", + "NVIDIA GeForce RTX 4090", + "NVIDIA RTX A4500", + "NVIDIA RTX A5500", + "NVIDIA RTX A6000", + # Compute Capability 9.0 + "NVIDIA H100-SXM5-80GB", + "NVIDIA H100-PCIe-80GB", + # Compute Capability 9.1 + "NVIDIA RTX 6000 Ada Generation", + # Compute Capability 9.2 + "NVIDIA RTX 4000 Ada Generation", + "NVIDIA RTX 5000 Ada Generation", +] + +turing_gpus = [ + # GeForce RTX 20 Series + "NVIDIA GeForce RTX 2080 Ti", + "NVIDIA GeForce RTX 2080 Super", + "NVIDIA GeForce RTX 2080", + "NVIDIA GeForce RTX 2070 Super", + "NVIDIA GeForce RTX 2070", + "NVIDIA GeForce RTX 2060 Super", + "NVIDIA GeForce RTX 2060", + # GeForce GTX 16 Series + "NVIDIA GeForce GTX 1660 Ti", + "NVIDIA GeForce GTX 1660 Super", + "NVIDIA GeForce GTX 1660", + "NVIDIA GeForce GTX 1650 Super", + "NVIDIA GeForce GTX 1650", + # Quadro RTX Series + "Quadro RTX 8000", + "Quadro RTX 6000", + "Quadro RTX 5000", + "Quadro RTX 4000", + # Titan RTX + "Titan RTX", +] + +ampere_gpus = [ + # GeForce RTX 30 Series + "NVIDIA GeForce RTX 3090", + "NVIDIA GeForce RTX 3080 Ti", + "NVIDIA GeForce RTX 3080", + "NVIDIA GeForce RTX 3070 Ti", + "NVIDIA GeForce RTX 3070", + "NVIDIA GeForce RTX 3060 Ti", + "NVIDIA GeForce RTX 3060", + # Quadro RTX Series + "Quadro RTX A6000", + "Quadro RTX A5000", + "Quadro RTX A4000", +] + + +def empty(*size, dtype): + if isinstance(size[0], (tuple, list)): + size = size[0] + out = CTensor(dtype, size) + return mindspore.Tensor(out) + + +def get_gpu_name(device: int): + + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + if result.returncode != 0: + raise RuntimeError(f"nvidia-smi error: {result.stderr}") + gpu_name = result.stdout.strip() + return gpu_name.split("\n")[device] + except FileNotFoundError: + return ( + "nvidia-smi command not found. Make sure you have NVIDIA drivers installed." 
+ ) + + +GPU_NAME = get_gpu_name(context.get_context("device_id")) + + +def get_special_format_str(): + device_target = context.get_context("device_target") + if device_target == "CPU": + return "col_turing" + + device_name = GPU_NAME + if device_name in turing_gpus: + return "col_turing" + if device_name in ampere_gpus: + return "col_ampere" + + return "col_turing" + + +# math.prod not compatible with python < 3.8 +def prod(iterable): + return reduce(operator.mul, iterable, 1) + + +name2qmap = {} + + +class GlobalPageManager: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.paged_tensors = [] + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + +class Cusparse_Context: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.context = ct.c_void_p(bnbop.get_cusparse()) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + +dtype2bytes = {} +dtype2bytes[mindspore.float32] = 4 +dtype2bytes[mindspore.float16] = 2 +dtype2bytes[mindspore.bfloat16] = 2 +dtype2bytes[mindspore.uint8] = 1 +dtype2bytes[mindspore.int8] = 1 + + +def frombuffer(buffer, dtype, count, shape): + # 使用 numpy.frombuffer 创建一个 NumPy 数组 + np_array = np.frombuffer(buffer, dtype=dtype, count=count) + # 将 NumPy 数组转换为 MindSpore Tensor + tensor = Tensor(np_array).reshape(shape) + return tensor + + +def get_paged(*shape, dtype=mindspore.float32): + num_bytes = dtype2bytes[dtype] * prod(shape) + cuda_ptr = bnbop.cget_managed_ptr(ct.c_size_t(num_bytes)) + c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) + new_array = np.ctypeslib.as_array(c_ptr, shape=shape) + out = frombuffer(new_array, dtype=dtype, count=prod(shape), shape=shape) + out.is_paged = True + return out + + +def create_linear_map(signed=True, total_bits=8, add_zero=True): + sign = -1.0 if signed else 0.0 + total_values = 2**total_bits + if add_zero or total_bits < 8: + # add a zero + # since we simulate less bits by having zeros in the data type, we + # we need to center the quantization around zero and as such lose + # a single value + total_values = 2**total_bits if not signed else 2**total_bits - 1 + + values = ops.linspace(sign, 1.0, total_values) + gap = 256 - values.numel() + if gap == 0: + return values + else: + l = values.numel() // 2 # noqa: E741 + return mindspore.Tensor(values[:l].tolist() + [0] * gap + values[l:].tolist()) + + +def create_normal_map(offset=0.9677083, use_extra_value=True): + try: + from scipy.stats import norm + except ImportError as ie: + raise ImportError( + "Scipy is required for `create_normal_map`. 
Install `bitsandbytes` with the `[test]` extra.", + ) from ie + + if use_extra_value: + # one more positive value, this is an asymmetric type + v1 = norm.ppf(ops.linspace(offset, 0.5, 9)[:-1]).tolist() + v2 = [0] * (256 - 15) ## we have 15 non-zero values in this data type + v3 = (-norm.ppf(ops.linspace(offset, 0.5, 8)[:-1])).tolist() + else: + v1 = norm.ppf(ops.linspace(offset, 0.5, 8)[:-1]).tolist() + v2 = [0] * (256 - 14) ## we have 14 non-zero values in this data type + v3 = (-norm.ppf(ops.linspace(offset, 0.5, 8)[:-1])).tolist() + + v = v1 + v2 + v3 + + values = mindspore.Tensor(v) + values = values.sort().values + values /= values.max() + + assert values.numel() == 256 + + return values + + +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): + e = exponent_bits + p = precision_bits + has_sign = 1 if signed else 0 + assert e + p == total_bits - has_sign + # the exponent is biased to 2^(e-1) -1 == 0 + evalues = [] + pvalues = [] + for i, val in enumerate( + range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1) + ): + evalues.append(2**val) + + values = [] + lst = list(itertools.product([0, 1], repeat=precision_bits)) + # for ev in evalues: + bias = 2 ** (exponent_bits - 1) + for evalue in range(2 ** (exponent_bits)): + for bit_pattern in lst: + value = 1 if evalue != 0 else 0 + for i, pval in enumerate(list(bit_pattern)): + value += pval * (2 ** -(i + 1)) + if evalue == 0: + # subnormals + value = value * 2 ** -(bias) + else: + # normals + value = value * 2 ** -(evalue - bias - 1) + values.append(value) + if signed: + values.append(-value) + + assert len(values) == 2**total_bits + values.sort() + if total_bits < 8: + gap = 256 - len(values) + for i in range(gap): + values.append(0) + values.sort() + code = mindspore.Tensor(values) + code /= code.max() + + return code + + +def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): + """ + Creates the dynamic quantiztion map. + + The dynamic data type is made up of a dynamic exponent and + fraction. As the exponent increase from 0 to -7 the number + of bits available for the fraction shrinks. + + This is a generalization of the dynamic type where a certain + number of the bits and be reserved for the linear quantization + region (the fraction). n determines the maximum number of + exponent bits. 
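    For example, with the default arguments (signed=True, max_exponent_bits=7,
    total_bits=8) the map holds 256 entries: 0, 1.0, and 127 paired
    positive/negative magnitudes, so

        code = create_dynamic_map()
        assert code.numel() == 256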
+ + For more details see + (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] + """ + + data = [] + # these are additional items that come from the case + # where all the exponent bits are zero and no + # indicator bit is present + non_sign_bits = total_bits - (1 if signed else 1) + additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 + for i in range(max_exponent_bits): + fraction_items = int( + ( + 2 ** (i + non_sign_bits - max_exponent_bits) + 1 + if signed + else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1 + ), + ) + boundaries = ops.linspace(0.1, 1, fraction_items) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + + if additional_items > 0: + boundaries = ops.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + + data.append(0) + data.append(1.0) + + assert len(data) == 2**total_bits + + gap = 256 - len(data) + for i in range(gap): + data.append(0) + + data.sort() + return Tensor(data) + + +def get_transform_func(dtype, orderA, orderOut, transpose=False): + name = f'ctransform_{(8 if dtype == mindspore.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' + if not hasattr(bnbop, name): + print(name) + raise ValueError( + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}", + ) + else: + return getattr(bnbop, name) + + +def get_transform_buffer(shape, dtype, to_order, from_order="row", transpose=False): + init_func = ops.zeros + dims = len(shape) + + rows = shape[0] + if dims == 3: + rows = rows * shape[1] + cols = shape[-1] + + state = (shape, to_order) + if transpose: + # swap dims + rows, cols = cols, rows + state = (shape[::-1], to_order) + + if to_order in ("row", "col"): + return ( + init_func( + shape, + dtype=dtype, + ), + state, + ) + elif to_order == "col32": + # blocks of 32 columns (padded) + cols = 32 * ((cols + 31) // 32) + return ( + init_func( + (rows, cols), + dtype=dtype, + ), + state, + ) + elif to_order == "col_turing": + # blocks of 32 columns and 8 rows + cols = 32 * ((cols + 31) // 32) + rows = 8 * ((rows + 7) // 8) + return ( + init_func( + (rows, cols), + dtype=dtype, + ), + state, + ) + elif to_order == "col_ampere": + # blocks of 32 columns and 32 rows + cols = 32 * ((cols + 31) // 32) + rows = 32 * ((rows + 31) // 32) + return ( + init_func( + (rows, cols), + dtype=dtype, + ), + state, + ) + else: + raise NotImplementedError(f"To_order not supported: {to_order}") + + +def nvidia_transform( + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer(state[0], A.dtype, to_order, state[1]) + else: + new_state = (state[1], to_order) + func = get_transform_func(A.dtype, from_order, to_order, transpose) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + elif ld is not None: + n = prod(shape) + dim1 = prod([shape[i] for i in ld]) + dim2 = ct.c_int32(n // dim1) + dim1 = ct.c_int32(dim1) + else: + dim1 = ct.c_int32(shape[0] * shape[1]) + dim2 = ct.c_int32(shape[2]) + + return 
out, new_state + + +class QuantState: + """container for quantization state components to work with Params4bit and similar classes""" + + valid_quant_types = ("fp4", "nf4") + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = [ + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): + self.absmax = absmax + self.shape = shape + self.code = code + self.dtype = dtype + self.blocksize = blocksize + self.quant_type = quant_type + self.offset = offset + self.state2 = state2 + self.nested = state2 is not None + + def __get_item__(self, idx): + """ + ensures compatibility with older quant state scheme with nested lists. + assumes the following layout: + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + """ + if self.nested: + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] + else: + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + None, + self.quant_type, + ] + return list_repr[idx] + + @classmethod + def from_dict(cls, qs_dict: Dict[str, Any]) -> "QuantState": + """ + unpacks components of state_dict into QuantState + where necessary, convert into strings, mindspore.dtype, ints, etc. + + qs_dict: based on state_dict, with only relevant keys, striped of prefixes. + + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. 
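        A minimal round-trip sketch (the `qs_dict` name is illustrative; it is
        assumed to hold the relevant state_dict items with prefixes already stripped):

            quant_state = QuantState.from_dict(qs_dict)
            packed = quant_state.as_dict(packed=True)  # tensors plus packed non-tensor metadata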
+ """ + + # unpacking tensor with non-tensor components + qs_key = [ + k + for k, v in qs_dict.items() + if "quant_state" in k and isinstance(v, mindspore.Tensor) + ] + if len(qs_key) == 0 and "quant_type" not in qs_dict: + raise ValueError( + "Expected packed or unpacked quant_state items, found neither" + ) + elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) + + # unpacking minor and non-tensor quant state items if necessary + if len(qs_key) == 1: + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) + + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes + assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) + + if "nested_absmax" in qs_dict: + offset = mindspore.tensor(float(qs_dict["nested_offset"])) + state2 = cls( + absmax=qs_dict["nested_absmax"], + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"], + dtype=getattr(mindspore, qs_dict["nested_dtype"]), + ) + else: + offset, state2 = None, None + + quant_state = cls( + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"], + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"], + dtype=getattr(mindspore, qs_dict["dtype"]), + shape=( + mindspore.Size(qs_dict["shape"]) + if qs_dict["shape"] is not None + else None + ), + offset=offset, + state2=state2, + ) + return quant_state + + def as_dict(self, packed=False): + """ + returns dict of tensors and strings to use in serialization via _save_to_state_dict() + param: packed -- returns dict[str, mindspore.Tensor] for state_dict fit for safetensors saving + """ + qs_dict = { + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("mindspore."), + "shape": tuple(self.shape), + } + if self.nested: + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + "nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("mindspore."), + "nested_offset": self.offset.item(), + }, + ) + if not packed: + return qs_dict + + # packed format allows serialization of non-tensor components, critical for saving in safetensors format + qs_packed_dict = { + k: v for k, v in qs_dict.items() if isinstance(v, mindspore.Tensor) + } + non_tensor_dict = { + k: v for k, v in qs_dict.items() if not isinstance(v, mindspore.Tensor) + } + qs_packed_dict["quant_state." 
+ "bitsandbytes__" + self.quant_type] = ( + pack_dict_to_tensor(non_tensor_dict) + ) + return qs_packed_dict + + def __eq__(self, other): + if not isinstance(other, QuantState): + return False + + return ( + np.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and np.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) + ) + + +def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=mindspore.int8): + if context.get_context("device_target") != "GPU": + context.set_context(device_target="GPU") + + # 检查数据类型 + if A.dtype != expected_type or B.dtype != expected_type: + raise TypeError( + f"Expected {expected_type} input tensors A and B, but got {A.dtype} and {B.dtype}" + ) + + sA = A.shape + sB = B.shape + tA = transposed_A + tB = transposed_B + + correct = True + + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[0] != B.shape[0]: + correct = False + elif tA and tB and A.shape[0] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB and A.shape[2] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[1]: + correct = False + elif len(sA) == 3 and len(sB) == 3: + if not tA and not tB and A.shape[2] != B.shape[1]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[1]: + correct = False + elif tA and tB and A.shape[1] != B.shape[2]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[2]: + correct = False + + if out is not None: + sout = out.shape + # special case common in backprop + if not correct and len(sA) == 3 and len(sB) == 3: + if ( + sout[0] == sA[2] + and sout[1] == sB[2] + and sA[0] == sB[0] + and sA[1] == sB[1] + ): + correct = True + else: + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB: + sout = (sA[0], sB[1]) + elif tA and tB: + sout = (sA[1], sB[0]) + elif tA and not tB: + sout = (sA[1], sB[1]) + elif not tA and tB: + sout = (sA[0], sB[0]) + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB: + sout = (sA[0], sA[1], sB[1]) + elif tA and tB: + sout = (sA[0], sA[2], sB[0]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[1]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[0]) + elif len(sA) == 3 and len(sB) == 3: + if not tA and not tB: + sout = (sA[0], sA[1], sB[2]) + elif tA and tB: + sout = (sA[0], sA[2], sB[1]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[2]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[1]) + + if not correct: + raise ValueError( + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.", + ) + + return sout + + +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=mindspore.int32): + shapeA = SA[0] + shapeB = SB[0] + dimsA = len(shapeA) + dimsB = len(shapeB) + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" + + m = shapeA[0] + + if dimsA == 3: + m = m * 
shapeA[1] + + rows = n = shapeB[0] + assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return empty((0, shapeB[0]), dtype=mindspore.float16) + elif shapeA[1] == 0 and dimsA == 3: + return empty(tuple(shapeA[:2] + [shapeB[0]]), dtype=mindspore.float16) + + if dimsA == 2 and out is None: + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, "col32", "row") + elif dimsA == 3 and out is None: + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, "col32", "row" + ) + + assert dimsB != 3, "len(B.shape)==3 not supported" + assert context.get_context("device_target") == "GPU" + assert A.dtype == mindspore.int8 + assert B.dtype == mindspore.int8 + assert out.dtype == dtype + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" + assert ( + shapeA[-1] == shapeB[-1] + ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" + formatB = SB[1] + + # ptr = CUBLAS_Context.get_instance().get_context() + + k = shapeA[-1] + lda = m * 32 + if formatB == "col_turing": + # turing: tiles with rows filled up to multiple of 8 rows by 32 columns + # n = rows + ldb = ((rows + 7) // 8) * 8 * 32 + else: + # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns + # n = rows + ldb = ((rows + 31) // 32) * 32 * 32 + + ldc = m * 32 + + has_error = 0 + ptrRowScale = None + if formatB == "col_turing": + if dtype == mindspore.int32: + bnbop.cigemmlt_turing_32( + m, n, k, A, B, out, ptrRowScale, lda, ldb, ldc, has_error + ) + else: + bnbop.cigemmlt_turing_8( + m, n, k, A, B, out, ptrRowScale, lda, ldb, ldc, has_error + ) + elif formatB == "col_ampere": + if dtype == mindspore.int32: + bnbop.cigemmlt_ampere_32( + m, n, k, A, B, out, ptrRowScale, lda, ldb, ldc, has_error + ) + else: + bnbop.cigemmlt_ampere_8( + m, n, k, A, B, out, ptrRowScale, lda, ldb, ldc, has_error + ) + + if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + raise NotImplementedError( + "igemmlt not available (probably built with NO_CUBLASLT)" + ) + + if has_error: + print( + f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}" + ) + raise Exception("cublasLt ran into an error!") + + return out, Sout + + +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None, +): + assert A.dtype == mindspore.int32 + if bias is not None: + assert bias.dtype == mindspore.float16 + out_shape = quant_state[0] + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + if out is None: + out = empty(out_shape, dtype=mindspore.float16) + if new_row_stats is None: + new_row_stats = empty( + out_shape[0], + dtype=mindspore.float32, + ) + if new_col_stats is None: + new_col_stats = empty( + out_shape[1], + dtype=mindspore.float32, + ) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" + + numRows = out_shape[0] + numCols = out_shape[1] + bnbop.cdequant_mm_int32_fp16( + A, + row_stats, + col_stats, + out, + new_row_stats, + new_col_stats, + bias, + numRows, + numCols, + ) + + return out + + +def get_colrow_absmax( + A, row_stats=None, 
col_stats=None, nnz_block_ptr=None, threshold=0.0 +): + assert A.dtype == mindspore.float16 + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + col_tiles = (cols + 255) // 256 + tiled_rows = ((rows + 15) // 16) * 16 + + if row_stats is None: + row_stats = empty( + (rows,), + dtype=mindspore.float32, + ).fill(-50000.0) + if col_stats is None: + col_stats = empty( + (cols,), + dtype=mindspore.float32, + ).fill(-50000.0) + # if nnz_block_ptr is None and threshold > 0.0: + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = ops.zeros( + ((tiled_rows * col_tiles) + 1,), + dtype=mindspore.int32, + ) + + bnbop.cget_col_row_stats( + A, row_stats, col_stats, nnz_block_ptr, threshold, rows, cols + ) + + if threshold > 0.0: + nnz_block_ptr = nnz_block_ptr.cumsum(axis=0) + + return row_stats, col_stats, nnz_block_ptr + + +class COOSparseTensor: + def __init__(self, rows, cols, nnz, rowidx, colidx, values): + assert rowidx.dtype == mindspore.int32 + assert colidx.dtype == mindspore.int32 + assert values.dtype == mindspore.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colidx.numel() == nnz + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowidx = rowidx + self.colidx = colidx + self.values = values + + +class CSRSparseTensor: + def __init__(self, rows, cols, nnz, rowptr, colidx, values): + assert rowptr.dtype == mindspore.int32 + assert colidx.dtype == mindspore.int32 + assert values.dtype == mindspore.float16 + assert values.numel() == nnz + assert colidx.numel() == nnz + assert rowptr.numel() == rows + 1 + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowptr = rowptr + self.colidx = colidx + self.values = values + + +class CSCSparseTensor: + def __init__(self, rows, cols, nnz, colptr, rowidx, values): + assert colptr.dtype == mindspore.int32 + assert rowidx.dtype == mindspore.int32 + assert values.dtype == mindspore.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colptr.numel() == cols + 1 + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.colptr = colptr + self.rowidx = rowidx + self.values = values + + +def coo_zeros(rows, cols, nnz, dtype=mindspore.half): + rowidx = ops.zeros( + (nnz,), + dtype=mindspore.int32, + ) + colidx = ops.zeros( + (nnz,), + dtype=mindspore.int32, + ) + values = ops.zeros( + (nnz,), + dtype=dtype, + ) + return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) + + +def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): + assert A.dtype == mindspore.half + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + + if out_col is None: + out_col = ops.zeros(A.shape, dtype=mindspore.int8) + if out_row is None: + out_row = ops.zeros(A.shape, dtype=mindspore.int8) + + coo_tensor = None + + if threshold > 0.0: + nnz = nnz_row_ptr[-1].item() + if nnz > 0: + coo_tensor = coo_zeros( + A.shape[0], + A.shape[1], + nnz_row_ptr[-1].item(), + ) + row_idx = coo_tensor.rowidx + col_idx = coo_tensor.colidx + val = coo_tensor.values + + bnbop.cdouble_rowcol_quant( + A, + row_stats, + col_stats, + out_col, + out_row, + row_idx, + col_idx, + val, + nnz_row_ptr, + threshold, + rows, + cols, + ) + val, idx = ops.sort(coo_tensor.rowidx) + coo_tensor.rowidx = val + coo_tensor.colidx = 
coo_tensor.colidx[idx] + coo_tensor.values = coo_tensor.values[idx] + else: + bnbop.cdouble_rowcol_quant( + A, + row_stats, + col_stats, + out_col, + out_row, + None, + None, + None, + None, + 0.0, + rows, + cols, + ) + else: + bnbop.cdouble_rowcol_quant( + A, + row_stats, + col_stats, + out_col, + out_row, + None, + None, + None, + None, + threshold, + rows, + cols, + ) + + return out_row, out_col, row_stats, col_stats, coo_tensor + + +def transform( + A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, to_order, state[1], transpose + ) + else: + new_state = (state[0], to_order) # (shape, order) + + shape = state[0] + if len(shape) == 2: + dim1 = shape[0] + dim2 = shape[1] + else: + dim1 = shape[0] * shape[1] + dim2 = shape[2] + + if to_order == "col32": + if transpose: + bnbop.ctransform_row2col32T(A, out, dim1, dim2) + else: + bnbop.ctransform_row2col32(A, out, dim1, dim2) + elif to_order == "col_turing": + if transpose: + bnbop.ctransform_row2turingT(A, out, dim1, dim2) + else: + bnbop.ctransform_row2turing(A, out, dim1, dim2) + elif to_order == "col_ampere": + if transpose: + bnbop.ctransform_row2ampereT(A, out, dim1, dim2) + else: + bnbop.ctransform_row2ampere(A, out, dim1, dim2) + elif to_order == "row": + if from_order == "col_turing": + bnbop.ctransform_turing2row(A, out, dim1, dim2) + elif from_order == "col_ampere": + bnbop.ctransform_ampere2row(A, out, dim1, dim2) + else: + raise NotImplementedError( + f"Transform function not implemented: From {from_order} to {to_order}" + ) + + return out, new_state + + +C = 127.0 + + +def vectorwise_quant(x, axis=1, quant_type="vector"): + if quant_type == "linear": + max1 = ops.abs(x).max().float() + xq = ops.round(x / max1 * 127).astype(mindspore.int8) + return xq, max1 + elif quant_type in ["vector", "row"]: + max1 = ops.amax(ops.abs(x), axis=axis, keepdims=True) + xq = ops.round(x * (C / max1)).astype(mindspore.int8) + return xq, max1 + elif quant_type == "zeropoint": + dtype = x.dtype + x = x.float() + dyna = x.max() - x.min() + if dyna == 0: + dyna = 1 + qx = 255.0 / dyna + minx = x.min() + zpx = ops.round(minx * qx) + x = ops.round(qx * x - zpx) + zpx + return x, qx + elif quant_type in ["vector-zeropoint", "row-zeropoint"]: + dtype = x.dtype + x = x.float() + dyna = ops.amax(x, axis=axis, keepdims=True) - ops.amin(x, axis=axis, keepdims=True) + dyna[dyna == 0] = 1 + qx = 255.0 / dyna + minx = ops.amin(x, axis=axis, keepdims=True) + zpx = ops.round(minx * qx) + x = ops.round(qx * x - zpx) + zpx + return x, qx + elif quant_type == "truncated-vector": + absx = ops.abs(x) + max1 = ops.amax(absx, axis=axis, keepdims=True) + max1 = max1 * 0.7 + idx = absx > max1.expand_as(absx) + sign = ops.sign(x[idx]) + x[idx] = max1.expand_as(absx)[idx] * sign + xq = ops.round(x / max1 * C).astype(mindspore.int8) + return xq, max1 + else: + return None + + +def vectorwise_dequant(xq, max1, quant_type="vector"): + if quant_type == "vector": + x = (xq / C * max1).astype(mindspore.float32) + return x + else: + return None + + +def vectorwise_mm_dequant(xq, S1, S2, dtype=mindspore.half, quant_type="vector"): + if quant_type == "linear": + norm = S1 * S2 / (C * C) + # double cast needed to prevent overflows + return (xq.float() * norm).to(dtype) + elif quant_type == "zeropoint": + norm = 1.0 / (S1 * S2) + return (xq.float() * norm).to(dtype) + elif quant_type == 
"row-zeropoint": + norm = 1.0 / (S1 * S2) + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= norm + else: + x *= norm + return x.to(dtype) + elif quant_type == "vector-zeropoint": + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= 1.0 / S1 + else: + x *= 1.0 / S1 + x *= 1.0 / S2.t() + return x.to(dtype) + elif quant_type == "row": + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= S1 * S2 / (C * C) + else: + x *= S1 * S2 / (C * C) + return x.to(dtype) + elif quant_type in ["truncated-vector", "vector"]: + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= S1 / C + else: + x *= S1 / C + x *= S2 / C + return x.to(dtype) + else: + return None + + +def dequant_min_max(xq, A, B, SA, SB, dtype=mindspore.half): + offset = B.float().t().sum(0) * (SA[0] + SA[1]) + x = xq.float() + if len(xq.shape) == 2 and len(SB.shape) == 3: + SB = SB.squeeze(0) + if len(SB.shape) == 2: + x *= SB.t() / 127 + else: + x *= SB / 127 + x *= SA[1] / 127 + x += offset + return x.to(dtype) + + +def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ["col_turing", "col_ampere"] + + out = ops.zeros((shapeA[0], idx.numel()), dtype=mindspore.int8) + + idx_size = idx.numel() + rows = shapeA[0] + cols = shapeA[1] + + if formatA == "col_turing": + bnbop.cextractOutliers_turing(A, idx, out, idx_size, rows, cols) + elif formatA == "col_ampere": + bnbop.cextractOutliers_ampere(A, idx, out, idx_size, rows, cols) + + return out diff --git a/mindnlp/quant/mindbnb/bitsandbytes/lib.py b/mindnlp/quant/mindbnb/bitsandbytes/lib.py new file mode 100644 index 000000000..884203474 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/lib.py @@ -0,0 +1,140 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? 
+- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiple) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" +# pylint: disable=E0401 +import logging +import os +from pathlib import Path +import ctypes as ct + +from mindspore import context + +from bitsandbytes.consts import ( + DYNAMIC_LIBRARY_SUFFIX, + PACKAGE_DIR, +) +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs + +logger = logging.getLogger(__name__) + + +class BNBNativeLibrary: + _lib: ct.CDLL + compiled_with_cuda = False + + def __init__(self, lib: ct.CDLL): + self._lib = lib + + def __getattr__(self, item): + return getattr(self._lib, item) + + +class CudaBNBNativeLibrary(BNBNativeLibrary): + compiled_with_cuda = True + + def __init__(self, lib: ct.CDLL): + super().__init__(lib) + lib.get_context.restype = ct.c_void_p + lib.get_cusparse.restype = ct.c_void_p + lib.cget_managed_ptr.restype = ct.c_void_p + + +def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: + """ + Get the disk path to the CUDA BNB native library specified by the + given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable. + + The library is not guaranteed to exist at the returned path. + """ + library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" + library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}" + + override_value = os.environ.get("BNB_CUDA_VERSION") + if override_value: + library_name_stem, _, library_name_ext = library_name.rpartition(".") + # `library_name_stem` will now be e.g. `libQ4M_cuda118`; + # let's remove any trailing numbers: + library_name_stem = library_name_stem.rstrip("0123456789") + # `library_name_stem` will now be e.g. `libQ4M_cuda`; + # let's tack the new version number and the original extension back on. 
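+        # Illustrative example (not from the original sources): with BNB_CUDA_VERSION=122,
+        # a default name such as "libbitsandbytes_cuda118.so" would be rewritten to
+        # "libbitsandbytes_cuda122.so" before the loader looks it up on disk.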
+ library_name = f"{library_name_stem}{override_value}.{library_name_ext}" + logger.warning( + f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" + "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" + "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" + "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" + "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + self.weight.data, + requires_grad=has_fp16_weights, + has_fp16_weights=has_fp16_weights, + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + + # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data + scb_name = "SCB" + + # case 1: .cuda was called, SCB is in self.weight + param_from_weight = getattr(self.weight, scb_name) + # case 2: self.init_8bit_state was called, SCB is in self.state + param_from_state = getattr(self.state, scb_name) + # case 3: SCB is in self.state, weight layout reordered after first forward() + layout_reordered = self.state.CxB is not None + + key_name = prefix + f"{scb_name}" + format_name = prefix + "weight_format" + + if not self.state.has_fp16_weights: + if param_from_weight is not None: + destination[key_name] = ( + param_from_weight if keep_vars else param_from_weight.detach() + ) + destination[format_name] = mindspore.tensor(0, dtype=mindspore.uint8) + elif param_from_state is not None and not layout_reordered: + destination[key_name] = ( + param_from_state if keep_vars else param_from_state.detach() + ) + destination[format_name] = mindspore.tensor(0, dtype=mindspore.uint8) + elif param_from_state is not None: + destination[key_name] = ( + param_from_state if keep_vars else param_from_state.detach() + ) + weights_format = self.state.formatB + # At this point `weights_format` is an str + if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: + raise ValueError(f"Unrecognized weights format {weights_format}") + + weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format] + + destination[format_name] = mindspore.tensor( + weights_format, dtype=mindspore.uint8 + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + unexpected_copy = list(unexpected_keys) + + for key in unexpected_copy: + input_name = key[len(prefix) :] + if input_name == "SCB": + if self.weight.SCB is None: + # buffers not yet initialized, can't access them directly without quantizing first + raise RuntimeError( + "Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. 
Please call module.cuda() before module.load_state_dict()", + ) + + input_param = state_dict[key] + self.weight.SCB.copy_(input_param) + + if self.state.SCB is not None: + self.state.SCB = self.weight.SCB + + unexpected_keys.remove(key) + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def quant( + self, + ): + for key, param in self.parameters_dict().items(): + if param is None: + continue + if key == "weight": + self.cuda(self.weight) + return self + + def cuda(self, param): + if param.has_fp16_weights: + param.data.astype(mindspore.float16) + return self + else: + # we store the 8-bit rows-major weight + # we convert this weight to the turning/ampere weight during the first inference pass + B = param.data.astype(mindspore.float16) + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + del CBt + del SCBt + param.assign_value(CB) + param.CB = CB + param.SCB = SCB + + return self + + def forward(self, x: mindspore.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias = mindspore.Parameter( + self.bias.astype(x.dtype), requires_grad=self.bias.requires_grad + ) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.assign_value(self.state.CxB) + return out diff --git a/mindnlp/quant/mindbnb/bitsandbytes/utils.py b/mindnlp/quant/mindbnb/bitsandbytes/utils.py new file mode 100644 index 000000000..2dc86a286 --- /dev/null +++ b/mindnlp/quant/mindbnb/bitsandbytes/utils.py @@ -0,0 +1,208 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +''' + utils +''' +import json +import shlex +import subprocess +from typing import Tuple + +import mindspore +from mindspore import ops + +from mindnlp.core import nn + + +def outlier_hook(module, input): + assert isinstance(module, nn.Linear) + tracer = OutlierTracer.get_instance() + hvalue = tracer.get_hvalue(module.weight) + if hvalue not in tracer.hvalue2outlier_idx: + outlier_idx = find_outlier_dims(module.weight) + tracer.outliers.append(outlier_idx) + tracer.hvalues.append(hvalue) + if len(tracer.outliers) > 1: + # assign the current layer the outlier idx found from the weight + # of the previous linear layer + if tracer.outliers[-1].numel() > 0: + assert tracer.outliers[-1].max() < module.weight.shape[1] + tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1] + + else: + # first layer, we cannot use the weight for outlier detection + # we follow a mixed approach: + # (1) zscore test of std of hidden dimension + # (2) magnitude > 6 test + merged = input[0].view(-1, input[0].shape[-1]) + # (1) zscore test of std of hidden dimension + outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) + # (2) magnitude > 6 test + dims = (ops.abs([0]) > 6).sum(dim=list(range(len(input[0].shape) - 1))) + outlier_idx2 = ops.nonzero(dims > 0)[0] + outlier_idx = ops.cat([outlier_idx, outlier_idx2]).unique() + tracer.hvalue2outlier_idx[hvalue] = outlier_idx + else: + for hook in tracer.hooks: + hook.remove() + + +class OutlierTracer: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self, model): + self.last_w = None + self.current_outlier_dims = None + self.hvalues = [] + self.outliers = [] + self.hvalue2outlier_idx = {} + self.initialized = True + self.hooks = [] + + for n, m in model.named_modules(): + if isinstance(m, nn.Linear): + self.hooks.append(m.register_forward_pre_hook(outlier_hook)) + + def is_initialized(self): + return getattr(self, "initialized", False) + + def get_hvalue(self, weight): + return weight.data.storage().data_ptr() + + def get_outliers(self, weight): + if not self.is_initialized(): + print("Outlier tracer is not initialized...") + return None + hvalue = self.get_hvalue(weight) + if hvalue in self.hvalue2outlier_idx: + return self.hvalue2outlier_idx[hvalue] + else: + return None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + return cls._instance + + +def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): + if rdm: + return ops.randint( + 0, weight.shape[1], size=(topk,) + ).long() + + m = weight.mean(reduction_dim) + mm = m.mean() + mstd = m.std() + zm = (m - mm) / mstd + + std = weight.std(reduction_dim) + stdm = std.mean() + stdstd = std.std() + + zstd = (std - stdm) / stdstd + + if topk is not None: + val, idx = ops.topk(std.abs(), k=topk, dim=0) + else: + idx = ops.nonzero(zstd > zscore)[0] + + return idx + + +def execute_and_return(command_string: str) -> Tuple[str, str]: + def _decode(subprocess_err_out_tuple): + return tuple( + to_decode.decode("UTF-8").strip() for to_decode in subprocess_err_out_tuple + ) + + def execute_and_return_decoded_std_streams(command_string): + return _decode( + subprocess.Popen( + shlex.split(command_string), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate(), + ) + + std_out, std_err = execute_and_return_decoded_std_streams(command_string) + return std_out, std_err + + +def 
replace_linear( + model, + linear_replacement, + skip_modules=("lm_head",), + copy_weights=False, + post_processing_function=None, +): + + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear( + module, + linear_replacement, + skip_modules, + copy_weights, + post_processing_function, + ) + + if isinstance(module, nn.Linear) and name not in skip_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight = old_module.weight + model._modules[name].bias = old_module.bias + + if post_processing_function is not None: + func = getattr(module, post_processing_function, None) + if func is not None: + func(module) + return model + + +def pack_dict_to_tensor(source_dict): + json_str = json.dumps(source_dict) + json_bytes = json_str.encode("utf-8") + tensor_data = mindspore.tensor(list(json_bytes), dtype=mindspore.uint8) + + return tensor_data + + +def unpack_tensor_to_dict(tensor_data): + json_bytes = bytes(tensor_data.cpu().numpy()) + json_str = json_bytes.decode("utf-8") + unpacked_dict = json.loads(json_str) + + return unpacked_dict + + +LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = { + "row": 0, + "col32": 1, + "col_turing": 2, + "col_ampere": 3, +} +INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = { + val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items() +} diff --git a/mindnlp/quant/mindbnb/csrc/common.cpp b/mindnlp/quant/mindbnb/csrc/common.cpp new file mode 100644 index 000000000..0a9601689 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/common.cpp @@ -0,0 +1,35 @@ +#include +#include + +void quantize_block(const quantize_block_args& args) { + // 1. find absmax in block + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + // 4. check minimal distance + // 5. store index + + // 1. find absmax in block + float absmax_block = -FLT_MAX; + for (long long i = args.block_idx; i < args.block_end; i++) + absmax_block = fmax(absmax_block, fabs(args.A[i])); + + args.absmax[args.block_idx / args.blocksize] = absmax_block; + + for (long long i = args.block_idx; i < args.block_end; i++) { + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + float normed_value = args.A[i] / absmax_block; + long long idx = args.bin_searcher->scalar(normed_value); + + // 4. check minimal distance + // The binary search returns always the value to the left, which might not be the closest value + if (idx < 255) { + float dist_left = fabs(normed_value - (args.code[idx])); + float dist_right = fabs(normed_value - (args.code[idx + 1])); + if (dist_right < dist_left) { idx += 1; } + } + + // 5. 
store index + args.out[i] = (unsigned char) idx; + } +} diff --git a/mindnlp/quant/mindbnb/csrc/common.h b/mindnlp/quant/mindbnb/csrc/common.h new file mode 100644 index 000000000..e513f2875 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/common.h @@ -0,0 +1,25 @@ +#include + +#ifndef common +#define common + +using namespace BinSearch; + +#define BLOCK_SIZE 16384 + +struct quantize_block_args { + BinAlgo *bin_searcher; + float *code; + float *A; + float *absmax; + unsigned char *out; + long long block_end; + long long block_idx; + long long threadidx; + long long blocksize; +}; + + +void quantize_block(const quantize_block_args& args); + +#endif diff --git a/mindnlp/quant/mindbnb/csrc/cpu_ops.cpp b/mindnlp/quant/mindbnb/csrc/cpu_ops.cpp new file mode 100644 index 000000000..e67135360 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/cpu_ops.cpp @@ -0,0 +1,63 @@ +#include +#include +#include + +using namespace BinSearch; + +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + for (long long i = block_idx; i < block_end; i++) + out[i] = code[A[i]] * absmax[block_idx / blocksize]; + } +} + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) +{ + + // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below + code[0] = -1.0f; + + long long num_blocks = n / blocksize; + num_blocks += n % blocksize == 0 ? 0 : 1; + + const uint32 elements_code = 256; + BinAlgo bin_searcher(code, elements_code); + + int thread_wave_size = 256; + // we chunk the threads into waves of 256 since the max limit is + // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) + for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) + { + long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; + std::vector threads(valid_chunks); + std::vector args(valid_chunks); + + int chunks_processed = 0; + for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) + { + long long valid_items = n - block_idx >= blocksize ? 
blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + + struct quantize_block_args& arg = args[chunks_processed]; + arg.bin_searcher = &bin_searcher; + arg.code = code; + arg.A = A; + arg.absmax = absmax; + arg.out = out; + arg.block_end = block_end; + arg.block_idx = block_idx; + arg.threadidx = block_idx / blocksize; + arg.blocksize = blocksize; + + threads[chunks_processed] = std::thread([arg] { quantize_block(arg); }); + chunks_processed += 1; + if(chunks_processed == valid_chunks){ break; } + } + + for (int i = 0; i < valid_chunks; i++) + threads[i].join(); + } + +} diff --git a/mindnlp/quant/mindbnb/csrc/cpu_ops.h b/mindnlp/quant/mindbnb/csrc/cpu_ops.h new file mode 100644 index 000000000..2ddf81e49 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/cpu_ops.h @@ -0,0 +1,10 @@ +#ifndef BITSANDBYTES_CPU_OPS_H +#define BITSANDBYTES_CPU_OPS_H + +#include +#include + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); + +#endif diff --git a/mindnlp/quant/mindbnb/csrc/kernels.cu b/mindnlp/quant/mindbnb/csrc/kernels.cu new file mode 100644 index 000000000..1bf753176 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/kernels.cu @@ -0,0 +1,4076 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +__device__ float atomicMax(float* address, float val) { + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS( + reinterpret_cast(address), assumed, + __float_as_int(fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} + +__device__ float atomicMin(float* address, float val) { + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS( + reinterpret_cast(address), assumed, + __float_as_int(fminf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} + +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +__device__ float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +__device__ half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val 
& 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabsf(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabsf(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? 
quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +template +__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper) +{ + int lower_pivot = QUADRANT*16-1 - 0; + int pivot = QUADRANT*16-1 + 16; + int upper_pivot = QUADRANT*16-1 + 31; + + float val = midpoint; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 16; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) +{ + const int tid = threadIdx.x + (blockDim.x*blockIdx.x); + const int numThreads = blockDim.x*gridDim.x; + + for(int i = tid; i < n; i+=numThreads) + { + int idx = (index1[i]*maxidx1) + index2[i]; + atomicAdd(&histogram[idx], src[i]); + } +} + +template +__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n) +{ + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage; + typedef cub::BlockLoad LoadT; + __shared__ typename LoadT::TempStorage loadt; + + const int warp_idx = threadIdx.x/32; + const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE); + + // BLOCK_SIZE/32 == number of warps + __shared__ int smem_max_indices[8*BLOCK_SIZE/32]; + __shared__ float smem_max_values[8*BLOCK_SIZE/32]; + + T values[8]; + T max1 = -64000.0f; + T max2 = -64000.0f; + int max_idx1 = -1; + int max_idx2 = -1; + int sign1 = -1; + int sign2 = -1; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f); + #pragma unroll 8 + for(int i = 0; i < 8; i++) + { + T absval = fabsf(values[i]); + if(absval > max1) + { + max1 = values[i]; + sign1 = signbit(values[i]); + max_idx1 = 8*threadIdx.x + i; + } + else if(absval > max2) + { + max2 = values[i]; + sign2 = signbit(values[i]); + max_idx2 = 8*threadIdx.x + i; + } + } + + float warp_max; + for(int i = 0; i < 8; i++) + { + // 3. do warp reduction + broadcast back + warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max()); + warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff); + + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + if(warp_max == max1) + { + smem_max_values[warp_idx*8 + i] = sign1 != 0 ? 
-max1 : max1; + smem_max_indices[warp_idx*8 + i] = max_idx1; + + sign1 = sign2; + max1 = max2; + max_idx1 = max_idx2; + + max2 = -64000.0f; + } + __syncwarp(); + } + + if(threadIdx.x % 32 < 8) + { + // offset: 8 values per 256 input values + // + int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8; + } + +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + +template +__launch_bounds__(THREADS_ESTIMATE, 1) +__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; + const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + T vals[NUM_ESTIMATE]; + + typedef cub::BlockRadixSort BlockRadixSort; + typedef cub::BlockLoad LoadFloat; + + __shared__ union { + typename LoadFloat::TempStorage loadf; + typename BlockRadixSort::TempStorage sort; + int smem_qidx[BLOCK_ESTIMATE]; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + + __syncthreads(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + + __syncthreads(); + for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) + temp_storage.smem_qidx[j] = -1; + + __syncthreads(); + + if(threadIdx.x < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); + temp_storage.smem_qidx[local_idx] = threadIdx.x; + } + + __syncthreads(); + + for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) + { + if(temp_storage.smem_qidx[i] != -1) + atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + } + } +} + + +__launch_bounds__(TH, 4) +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x+1 == gridDim.x) ? 
n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreChar; + + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; + + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } +} + +template +//__launch_bounds__(TH, 4) +__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) +{ + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef cub::BlockReduce BlockReduce; + typedef cub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + + // 1. compute local max + // 2. broadcast local max + // 3. 
normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); + + if(threadIdx.x == 0) + smem_absmax_value[0] = local_abs_max; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax[i/BLOCK_SIZE] = local_abs_max; + else + local_abs_max = smem_absmax_value[0]; + + __syncwarp(); + + local_abs_max = 1.0f/local_abs_max; + + if(STOCHASTIC) + { + local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + __syncthreads(); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + } +} + +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) +{ + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + { + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + } +} + +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) +{ + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + __shared__ float smem_code[256]; + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + } + + __syncthreads(); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - powf(beta1, step)); + const float correction2 = 1.0f/(1.0f - powf(beta2, step)); + + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } +} + + + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + } +} + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadUInt8; + typedef cub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if(threadIdx.x < 256) + { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items); + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + } + + if(threadIdx.x == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS2, 1) +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 512) + { + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? 
NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadUInt8; + typedef cub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + } + +} + +template +__global__ void +__launch_bounds__(1024, 1) +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef cub::BlockReduce BlockReduce; + typedef cub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if(threadIdx.x == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } + else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f -__powf(beta2, step)); + const float step_size = __fdividef(-lr*correction2,correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef cub::BlockReduce BlockReduce1; + typedef cub::BlockReduce BlockReduce2; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max()); + + if(threadIdx.x == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + } + + __syncthreads(); + + if(threadIdx.x == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, 
+ float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef cub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); + + if(threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + 
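+        // Added note (comment only, not part of the original kernel logic): summary of the
+        // parameter update applied in the branches below, using the state s1 updated above:
+        //   MOMENTUM: p <- p - lr * s1
+        //   LION:     p <- p - g_vals[j], where g_vals[j] already holds lr * sign(beta1*s1_old + (1-beta1)*g)
+        //   RMSPROP:  p <- p - lr * g / (sqrt(s1) + eps)
+        //   ADAGRAD:  p <- p - lr * g / (sqrt(s1) + eps)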
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transation + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockReduce BlockRowReduce; + typedef cub::BlockReduce BlockRowSum; + typedef cub::BlockExchange BlockExchange; + + __shared__ union { + typename BlockExchange::TempStorage exchange; + typename BlockRowReduce::TempStorage rowreduce; + typename BlockRowSum::TempStorage rowsum; + typename LoadT::TempStorage loadt; + } temp_storage; + + __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; + __shared__ int smem_row_nnz_values[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + // smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; + } + + #pragma unroll TILE_ROWS + for (int j = 0; j < TILE_ROWS; j++) { + smem_row_nnz_values[j] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + __syncthreads(); + + int valid_items = cols - base_col > items_per_load ? 
items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + __syncthreads(); + LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f)); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = fabsf(local_data[j]); + + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); + + // 3. compute row max (per block); store in smem to accumulate full global mem transation + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + __syncthreads(); + + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); + if(SPARSE_DECOMP) + { + __syncthreads(); + local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if(threadIdx.x == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + __syncthreads(); + + } + + // 4. store data via atomicMax + // to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arrangement: [0, 8, 16, 24, ..] 
for t0 + __syncthreads(); + BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+threadIdx.x+(j*THREADS) < cols) + { + float val = colStats[base_col+(threadIdx.x+(j*THREADS))]; + if(val < local_col_absmax_values[j]) + atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]); + } + + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_row+threadIdx.x+(j*THREADS) < rows) + { + float val = rowStats[base_row+(threadIdx.x+(j*THREADS))]; + if (val < smem_row_absmax_values[threadIdx.x + (j * THREADS)]) + atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]); + } + + if(SPARSE_DECOMP) + if(threadIdx.x < TILE_ROWS) + nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x]; +} + +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n) +{ + + // Strategy: To dequantize we need to load col/row statistics. This can be very expensive + // since different row/col stats need to be loaded with each thread. + // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure + // and would lead to low global load utilization. + // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads + // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. + // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. + // This allows for efficient row/col loading from shared memory within the tile. + // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has + // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts + // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the + // shared memory loads. + + // data is in 32 column-tile major with tile width 32 columns and numRows rows + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) + // C2. Compute normalization values and store col values in register + // S1. Store C1 into 16-bit output + // S2. Store col/row statistics of new buffer in shared memory + + // We allow for sub-tiles to span multiple col32 tiles. This is okay + // since the items per thread only rely on a single column statistic. + + + const int n_out = numRows*numCols; + + int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 
0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; + __shared__ float smem_rowStats[SUBTILE_ROWS]; + + typedef cub::BlockLoad LoadInt32; + typedef cub::BlockExchange ExchangeInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + __shared__ typename ExchangeInt32::TempStorage exchangeint32; + + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + float colStat = col >= numCols ? 0.0f : colStats[col]; + float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); + // no block loads for rows for now -- keep it simple + for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = rowStats[row]; + } + __syncthreads(); + + + // each block processes SUBTILE_ROWS*32 elements + const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. 
Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + } +} + + +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadHalf; + __shared__ typename LoadHalf::TempStorage loadhalf; + typedef cub::BlockStore StoreInt8; + __shared__ typename StoreInt8::TempStorage storeint8; + + __shared__ float smem_row_stats[TILE_ROWS]; + __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. 
Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols) + local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]); + + for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x) + { + if(base_row + i < rows) + smem_row_stats[i] = rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i]; + } + __syncthreads(); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + + LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); + float row_stat = __fdividef(127.0f, smem_row_stats[row]); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if(fabsf((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX); + + rowidx[old_idx] = base_row+row; + colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j; + val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + } + else + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); + } + + __syncthreads(); + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + + } +} + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; + char local_data[ITEMS_PER_THREAD]; + typedef cub::BlockExchange BlockExchange; + + // we load row after row from the base_position + // Load data row by row + int warps = blockDim.x/32; + int warp_id = threadIdx.x/32; + int warp_lane = threadIdx.x % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. 
transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + __syncthreads(); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consecutive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] 
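+              // Added worked example (comment only, no logic change): inside one 8*32 Turing
+              // tile the 4x4 sub-tile indexing used in both branches below places the even
+              // rows in the first 128 slots and the odd rows in the last 128, e.g.
+              //   (row_in_tile=2, col_in_tile=5):   0 + (5/4)*16 + (5%4) + (2*2)     = 21
+              //   (row_in_tile=3, col_in_tile=5): 128 + (5/4)*16 + (5%4) + ((3-1)*2) = 149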
+ if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happens every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] 
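+                // Added worked example (comment only): the permutation below maps
+                // subrow 0..9 -> 0, 1, 8, 9, 16, 17, 24, 25, 2, 3, matching the
+                // Ampere row-pair interleaving described above.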
+ int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +{ + int local_colidx = idx[blockIdx.x]; + + if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] 
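+    // Worked example (added for clarity, not part of the original kernel), assuming the layout
+    // described above and the offset arithmetic used in the loop below, restricted to a single
+    // 8x32 tile (i.e. tile_offset_rows == tile_offset_cols == 0):
+    //   (row 0, col 0) -> 0      even rows fill the first 128 elements of the 256-element tile
+    //   (row 2, col 1) -> 5      even half: (1/4)*16 + (1%4) + (2*2)
+    //   (row 1, col 0) -> 128    odd rows start in the second 128-element half
+    //   (row 3, col 5) -> 149    odd half: 128 + (5/4)*16 + (5%4) + ((3-1)*2)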
+ + // each thread reads 1 element = 1 row + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + char val = A[offset]; + + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } + } + else if(FORMAT == COL_AMPERE) + { + + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } + } +} + + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with cub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggregate files of C into shared memory block C +//// 8. sum (7) +//// 9. 
write outputs to matmul output matrix +//} + +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +#define WARPS 3 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + + +template __device__ void printnonzero(T *A, int num_values, const char * strval) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%s %i %f\n", strval, i, (float)A[i]); +} + +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); + +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = 
quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, wmma::mem_row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; +#endif +} + +#define num_values_4bit 32 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS/32)*blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for(int i = threadIdx.x; i < 16; i++) + quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit) + { + int inner_idx_halved = inner_idx/2; + int offset_B = ldb*row_B; + int absidx = ((2*offset_B)+inner_idx)/blocksize; + local_absmax = __ldg(&(absmax[absidx])); + + if(row_B < M) + { + if((inner_idx_halved + num_values_8bit) < (K/2)) + { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx_halved) + j < (K/2)) + local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for(int i = 0; i < 4; i++) + { + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) + { + #if __CUDA_ARCH__ >= 800 + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = 
quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multipliation not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif + } + + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); + + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if __CUDA_ARCH__ >= 800 + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multipliation not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = T(local_C); + +} + + +//#define ROWS 2 +//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +//{ +//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +//// 1. Load dataB into register +//// 2. Dequantize B +//// 3. Fetch data from A and multiply +// +// typedef cub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef cub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef cub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. 
B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const 
A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, 
int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, 
const float max_val, const int n); +template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_Optimizer32bit1State(MOMENTUM, half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, __nv_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16) + +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float 
*unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit2State(ADAM, half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); +template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, 
General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ 
void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8) + + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/mindnlp/quant/mindbnb/csrc/kernels.cuh b/mindnlp/quant/mindbnb/csrc/kernels.cuh new file mode 100644 index 000000000..a7fe3d700 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/kernels.cuh @@ -0,0 +1,132 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include +#include + +#ifndef kernels +#define kernels + +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); + +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); + +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n); + + + +template +__global__ void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, 
const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n); + +template __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); + +template __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n); + + +template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); + + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, + half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + +#endif diff --git a/mindnlp/quant/mindbnb/csrc/mps_kernels.metal b/mindnlp/quant/mindbnb/csrc/mps_kernels.metal new file mode 100644 index 000000000..63b3bf78c --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/mps_kernels.metal @@ -0,0 +1,117 @@ +#include +using namespace metal; + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + 
+template +static unsigned char quantize_scalar( + float rand, + device float* code, + float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = code[pivot]; + } + + if(upper_pivot == 255) + upper = code[upper_pivot]; + if(lower_pivot == 0) + lower = code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabs(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabs(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +kernel void quantize(device float* code [[buffer(0)]], + device float* A [[buffer(1)]], + device uchar* out [[buffer(2)]], + constant uint& n [[buffer(3)]], + uint id [[thread_position_in_grid]]) { + const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; + const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); + + float vals[NUM]; + uchar qvals[NUM]; + + for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { + valid_items = n - i > NUM_BLOCK ? 
NUM_BLOCK : n - i; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint j = 0; j < valid_items; j++) { + vals[j] = A[i + j]; + } + + for (uint j = 0; j < valid_items; j++) { + qvals[j] = quantize_scalar(0.0f, code, vals[j]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint j = 0; j < valid_items; j++) { + out[i + j] = qvals[j]; + } + } +} diff --git a/mindnlp/quant/mindbnb/csrc/mps_ops.h b/mindnlp/quant/mindbnb/csrc/mps_ops.h new file mode 100644 index 000000000..e69de29bb diff --git a/mindnlp/quant/mindbnb/csrc/mps_ops.mm b/mindnlp/quant/mindbnb/csrc/mps_ops.mm new file mode 100644 index 000000000..d198b3552 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/mps_ops.mm @@ -0,0 +1,67 @@ +#import + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +static inline MPSGraph* get_graph() +{ + static MPSGraph* cur = nil; + if(!cur) { + cur = [[MPSGraph alloc] init]; + } + return cur; +} + +static inline id get_device() +{ + NSError *error = nil; + static id device = nil; + if(!device) { + device = MTLCreateSystemDefaultDevice(); + } + if(!device) { + NSLog(@"Failed to get MPS device"); + abort(); + } + return device; +} + +static inline id get_library() +{ + NSError *error = nil; + static id library = nil; + if(!library) { + library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; + } + if(!library) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } + return library; +} + +/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) +{ + id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 dataType:MPSDataTypeInt8 axis:0 name:@"out"]; + return out; +}*/ + + +// MPSGraph function for quantize +extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) +{ + id device = get_device(); + id library = get_library(); + static id kernel = nil; + if(!kernel) { + kernel = [library newFunctionWithName:@"quantize"]; + if(!kernel) { + NSLog(@"Failed to load bitsandbytes.metallib"); + abort(); + } + } + NSLog(@"Not implemented"); + return nil; +} diff --git a/mindnlp/quant/mindbnb/csrc/ops.cu b/mindnlp/quant/mindbnb/csrc/ops.cu new file mode 100644 index 000000000..44dce914c --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/ops.cu @@ -0,0 +1,863 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include + +#define ERR_NOT_IMPLEMENTED 100 + + +using namespace BinSearch; +using std::cout; +using std::endl; + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); + kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kQuantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void dequantize(float *code, unsigned char *A, float *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kDequantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if(blocksize == 4096) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 64) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + if(DATA_TYPE > 0) + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + else + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + + +//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +// int num_blocks = (colsB+32-1)/32; +// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); +// CUDA_CHECK_RETURN(cudaPeekAtLastError()); +//} + + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + switch(OPTIMIZER) + { + case ADAM: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case LION: + // in lion, the momentum update after the parameter update + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + break; + } +} + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 2048 +#define NUM_2STATE 8 +#define BLOCKSIZE_1STATE 2048 +#define NUM_1STATE 8 + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) +{ + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + } +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? 
num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + kPercentileClipping<<>>(g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + cublasStatus_t status; + + status = cublasGemmEx(context->m_handle, + transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, + transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, + m, n, k, + alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, + C, CUDA_R_32I, ldc, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + if (status != CUBLAS_STATUS_SUCCESS) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + cublasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = cublasGemmStridedBatchedEx(context->m_handle, + transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, + transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, + m, n, k, + alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, + C, CUDA_R_32I, ldc, (long long int)strideC, batchCount, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT); + + if (status != CUBLAS_STATUS_SUCCESS) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + + +#ifdef NO_CUBLASLT +#else +template cublasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return CUBLASLT_ORDER_ROW; + break; + case COL: + return CUBLASLT_ORDER_COL; + break; + case COL32: + return CUBLASLT_ORDER_COL32; + break; + case COL_TURING: + return CUBLASLT_ORDER_COL4_4R2_8C; + break; + case COL_AMPERE: + return CUBLASLT_ORDER_COL32_2R_4R4; + break; + default: + break; + } + + return CUBLASLT_ORDER_ROW; +} + +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +#endif + + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + } +} + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) +{ +#ifdef NO_CUBLASLT +#else + cublasLtOrder_t orderA = get_order(); + cublasLtOrder_t orderOut = get_order(); + int ldA = get_leading_dim(dim1, dim2); + int ldOut = get_leading_dim(dim1, dim2); + + cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; + cublasLtMatrixTransformDesc_t A2Out_desc = NULL; + 
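+  // With transformAlpha = 1 and transformBeta = 0, the cublasLtMatrixTransform call
+  // below is a pure layout conversion: A is copied into `out` while being converted
+  // from the source memory order (orderA) to the target order (orderOut), with an
+  // optional transpose when the `transpose` template flag is set.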
cublasOperation_t opTranspose = CUBLAS_OP_T; + float transformAlpha = 1.0f, transformBeta = 0.0f; + + + if(DTYPE == 8) + { + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut)); + } + else if(DTYPE == 32) + { + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut)); + } + else + { + printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); + } + + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); + + checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F)); + + if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } + + checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); + + if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); + if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); + if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif +} + +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) +{ +#ifdef NO_CUBLASLT + return ERR_NOT_IMPLEMENTED; +#else + int has_error = 0; + cublasLtMatmulDesc_t matmulDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cublasOperation_t opT = CUBLAS_OP_T; + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; + cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb)); + + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(FORMATB == COL_TURING) + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + else + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); + + if(DTYPE_OUT == 32) + { + has_error 
|= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + int alpha = 1, beta = 0; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(!SCALE_ROWS) + { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + } + + + if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); + if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); + if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); + if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + if(has_error == 1) + printf("error detected"); + + return has_error; +#endif // NO_CUBLASLT +} + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) +{ + int threads = 512; + int tileCols = fill_up_to_nearest_multiple(numCols, 32); + int n = numRows*tileCols; + int subtile_rows = 128; + int tilesize = 32*subtile_rows; + int num_blocks = numRows/subtile_rows; + num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + num_blocks = num_blocks*(tileCols/32); + assert(threads <= tilesize); + + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +#define STATS_THREADS 64 +#define STATS_ITEMS 4 +#define STATS_ROWS 16 +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) +{ + int tile_cols = STATS_THREADS * STATS_ITEMS; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? 
col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + if(nnz_threshold == 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + else if(nnz_threshold != 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) +{ + int threads = 64; + int items_per_thread = 4; + int tile_cols = threads*items_per_thread; + int tile_rows = 16; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + + if(threshold > 0.0f) + kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + else + kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void transformRowToFormat(char * A, char *out, int rows, int cols) +{ + int threads = 256; + int items_per_thread = 8; + // we load 128 column values per warp + int tile_cols = 32*items_per_thread; + int tile_rows = 32; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? 
col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); + } + else + { + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } + } + + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + +#ifdef NO_CUBLASLT +#else + + cusparseSpMatDescr_t descA; + cusparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) ); + // Create dense matrix C + CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_CUSPARSE( cusparseSpMM_bufferSize( + handle, + CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, CUDA_R_32F, + CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_CUSPARSE( cusparseSpMM(handle, + CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? 
CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, CUDA_R_32F, + CUSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + + // destroy matrix/vector descriptors + CHECK_CUSPARSE( cusparseDestroySpMat(descA) ); + CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); + CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); + CUDA_CHECK_RETURN( cudaFree(dBuffer) ); +#endif +} + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + kspmm_coo_very_sparse_naive<<>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) +{ + int threads = 256; + // we load 128 column values per warp + int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); + int tiledRows = 0; + + int num_blocks = idx_size; + + if(FORMAT == COL_TURING) + { + tiledRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + tiledRows = fill_up_to_nearest_multiple(rows, 32); + } + + kExtractOutliers<<>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + + + + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+3)/4; + + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); 
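+  // (m + 3) / 4 launches ceil(m / 4) blocks of 128 threads (four warps each);
+  // the kernel body itself is defined in kernels.cu.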
+ CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + kfunc<<>>(A, B, value, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int 
rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +template void estimateQuantiles(half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 1, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); + 
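+// Illustrative host-side use of the blockwise (de)quantization launchers above.
+// This is only a sketch: d_code, d_A, d_absmax and d_q are hypothetical device
+// pointers assumed to be allocated with cudaMalloc, and d_code is assumed to hold
+// the 256-entry quantization table used by the General8bit path.
+//
+//   const int n = 1 << 20, blocksize = 4096;
+//   quantizeBlockwise<half, 0, General8bit>(d_code, d_A, d_absmax, d_q, NULL, 0, blocksize, n);
+//   dequantizeBlockwise<half, General8bit>(d_code, d_q, d_absmax, d_A, blocksize, n);
+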
+#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, __nv_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, __nv_bfloat16) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); diff --git a/mindnlp/quant/mindbnb/csrc/ops.cuh b/mindnlp/quant/mindbnb/csrc/ops.cuh new file mode 100644 index 000000000..8b9a4f449 --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/ops.cuh @@ -0,0 +1,202 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
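+// Host-side declarations for the CUDA backend: kernel launch wrappers
+// (quantization, optimizers, percentile clipping, GEMM and sparse helpers) plus the
+// cuBLAS/cuBLASLt/cuSPARSE context classes wrapped by pythonInterface.cpp.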
+ + +#ifndef ops_H +#define ops_H + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +#define CUDA_CHECK_RETURN(value) { \ + cudaError_t _m_cudaStat = value; \ + if (_m_cudaStat != cudaSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + +#define THREADS_PER_BLOCKS (512) + +#define CHECK_CUSPARSE(value) { \ + cusparseStatus_t _m_cudaStat = value; \ + if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + +#define THREADS_PER_BLOCKS (512) + + +inline void checkCudaStatus(cudaError_t status) { + if (status != cudaSuccess) { + printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); + throw std::logic_error("cuda API failed"); + } +} + +inline int checkCublasStatus(cublasStatus_t status) { + if (status != CUBLAS_STATUS_SUCCESS) { + printf("cuBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + cublasHandle_t m_handle; + + Context() + { + cublasHandle_t handle; + cublasCreate_v2(&handle); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + cublasLtHandle_t m_handle; + + ContextLt() + { + cublasLtHandle_t handle; + cublasLtCreate(&handle); + m_handle = handle; + } + +}; + +class ContextCusparse +{ + public: + cusparseHandle_t m_handle; + + ContextCusparse() + { + cusparseHandle_t handle; + cusparseCreate(&handle); + m_handle = handle; + } + +}; + + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool 
skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, + int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template void func(T *A, T *B, T value, long n); + +#endif diff --git a/mindnlp/quant/mindbnb/csrc/pythonInterface.cpp b/mindnlp/quant/mindbnb/csrc/pythonInterface.cpp new file mode 100644 index 000000000..ce30b672f --- /dev/null +++ b/mindnlp/quant/mindbnb/csrc/pythonInterface.cpp @@ -0,0 +1,933 @@ +// Copyright 2024 Huawei Technologies Co., Ltd + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================ + +#if BUILD_CUDA +#include +#endif +#if BUILD_MPS +// #include +#endif +#include + +// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. +// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to +// maintain all that boilerplate +//=================================================================================== +// UNMANGLED CALLS +//=================================================================================== + +#if BUILD_CUDA + +Context *CUBLAS_CONTEXT = nullptr; + +void estimateQuantiles_fp32(float *A, float *code, float offset, int n) { estimateQuantiles(A, code, offset, n); } +void estimateQuantiles_fp16(half *A, float *code, float offset, int n) { estimateQuantiles(A, code, offset, n); } + +// void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) +//{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } +void gemm_host_fp16(int M, int N, int K, half *A, half *B, half *out, int lda, int ldb, int ldc) +{ + gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); +} + +void gemm_4bit_inference(int m, int n, int k, half *A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize) +{ + gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +void gemm_4bit_inference_naive_fp16(int m, int n, int k, half *A, unsigned char *B, float *absmax, float *datatype, half *out, int lda, int ldb, int ldc, int blocksize) +{ + gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); +} + +void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, float *datatype, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize) +{ + gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); +} + +void gemm_4bit_inference_naive_fp32(int m, int n, int k, float *A, unsigned char *B, float *absmax, float *datatype, float *out, int lda, int ldb, int ldc, int blocksize) +{ + gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); +} + +#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void fname##_##type_name(ctype *A, ctype *B, ctype value, long n) { func(A, B, value, n); } + +MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) +MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) +MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) +MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + +#define MAKE_FUNC32(fname, oname, gtype, gbits) \ + void fname##32bit_grad_##gbits(gtype *g, gtype *p, \ + float *state1, float *state2, float *unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ + { \ + optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); \ + } + +MAKE_FUNC32(momentum, MOMENTUM, float, 32) +MAKE_FUNC32(momentum, MOMENTUM, half, 16) +MAKE_FUNC32(adam, ADAM, float, fp32) +MAKE_FUNC32(adam, ADAM, half, fp16) +MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) +MAKE_FUNC32(rmsprop, RMSPROP, float, 32) +MAKE_FUNC32(rmsprop, RMSPROP, half, 16) +MAKE_FUNC32(lion, LION, float, fp32) +MAKE_FUNC32(lion, LION, half, fp16) 
+MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
+MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
+MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
+
+#define MAKE_FUNC8(fname, oname, gtype, gbits) \
+    void fname##_static_8bit_grad_##gbits(gtype *p, gtype *g, unsigned char *state1, unsigned char *state2, \
+                                          float *unorm, float max_unorm, float param_norm, \
+                                          float beta1, float beta2, \
+                                          float eps, int step, float lr, \
+                                          float *quantiles1, float *quantiles2, \
+                                          float *max1, float *max2, float *new_max1, float *new_max2, \
+                                          float weight_decay, float gnorm_scale, int n) \
+    { \
+        optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
+                                          quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
+    }
+
+MAKE_FUNC8(adam, ADAM, float, 32)
+MAKE_FUNC8(adam, ADAM, half, 16)
+MAKE_FUNC8(momentum, MOMENTUM, float, 32)
+MAKE_FUNC8(momentum, MOMENTUM, half, 16)
+MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
+MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
+MAKE_FUNC8(lion, LION, float, 32)
+MAKE_FUNC8(lion, LION, half, 16)
+
+#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
+    void fname##_8bit_blockwise_grad_##gbits(gtype *p, gtype *g, \
+        unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr, \
+        float *quantiles1, float *quantiles2, float *absmax1, float *absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
+    { \
+        optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); \
+    }
+
+MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
+MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
+MAKE_BLOCKWISE8(lion, LION, half, fp16)
+MAKE_BLOCKWISE8(lion, LION, float, fp32)
+MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
+
+void percentileClipping_g32(float *g, float *gnorm_vec, int step, const int n) { percentileClipping<float>(g, gnorm_vec, step, n); }
+void percentileClipping_g16(half *g, float *gnorm_vec, int step, const int n) { percentileClipping<half>(g, gnorm_vec, step, n); }
+
+void quantizeBlockwise_fp16(float *code, half *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp16_fp4(float *code, half *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp16_nf4(float *code, half *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+
+void quantizeBlockwise_bf16(float *code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<__nv_bfloat16, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_bf16_fp4(float *code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<__nv_bfloat16, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_bf16_nf4(float *code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<__nv_bfloat16, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+
+void quantizeBlockwise_fp32(float *code, float *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp32_fp4(float *code, float *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp32_nf4(float *code, float *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+
+void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n) { dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n) { dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n) { dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); }
+
+void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n) { dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n) { dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n) { dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
+
+void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n) { dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n) { dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n) { dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); }
+
+#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
+    void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
+    { \
+        transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
+    }
+
+MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
+MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
+MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
+MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
+MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
+
+void transform_row2col32(char *A, char *out, int rows, int cols) { transformRowToFormat<COL32, 0>(A, out, rows, cols); }
+void transform_row2col32T(char *A, char *out, int rows, int cols) { transformRowToFormat<COL32, 1>(A, out, rows, cols); }
+void transform_row2turing(char *A, char *out, int rows, int cols) { transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
+void transform_row2turingT(char *A, char *out, int rows, int cols) { transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
+void transform_row2ampere(char *A, char *out, int rows, int cols) { transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
+void transform_row2ampereT(char *A, char *out, int rows, int cols) { transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
+
+void extractOutliers_turing(char *A, int *idx, char *out, int idx_size, int rows, int cols) { extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
+void extractOutliers_ampere(char *A, int *idx, char *out, int idx_size, int rows, int cols) { extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
+
+int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+    return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+}
+
+int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+    return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+}
+
+int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+    return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+}
+
+int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+    return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+}
+
+int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+    return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+}
+
+int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+    return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
+}
+
+void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{
+    spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB);
+}
+
+void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{
+    spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB);
+}
+#endif
+
+extern "C"
+{
+#if BUILD_CUDA
+    void cestimate_quantiles_fp32(float *A, float *code, float offset, int n) { estimateQuantiles_fp32(A, code, offset, n); }
+    void cestimate_quantiles_fp16(half *A, float *code, float offset, int n) { estimateQuantiles_fp16(A, code, offset, n); }
+    void cquantize(float *code, float *A, unsigned char *out, int n) { quantize(code, A, out, n); }
+    void cdequantize(float *code, unsigned char *A, float *out, int n) { dequantize(code, A, out, n); }
+
+    void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n) { dequantizeBlockwise_fp16_fp4(code,
A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n) { dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n) { dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + + void cquantize_blockwise_fp16(float *code, half *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_fp4(float *code, half *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_nf4(float *code, half *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + + void cquantize_blockwise_fp32(float *code, float *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_fp4(float *code, float *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_nf4(float *code, float *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + + void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n) { dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n) { dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n) { dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + + void cquantize_blockwise_bf16(float *code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_bf16_fp4(float *code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_bf16_nf4(float *code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n) { quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } + + void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n) { dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n) { dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n) { dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } + +#define MAKE_CFUNC32(name, gtype, gbits) \ + void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ + float *state1, float *state2, float *unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float 
weight_decay, \ + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ + { \ + name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); \ + } + + MAKE_CFUNC32(adam, float, fp32) + MAKE_CFUNC32(adam, half, fp16) + MAKE_CFUNC32(adam, __nv_bfloat16, bf16) + MAKE_CFUNC32(momentum, float, 32) + MAKE_CFUNC32(momentum, half, 16) + MAKE_CFUNC32(rmsprop, float, 32) + MAKE_CFUNC32(rmsprop, half, 16) + MAKE_CFUNC32(lion, float, fp32) + MAKE_CFUNC32(lion, half, fp16) + MAKE_CFUNC32(lion, __nv_bfloat16, bf16) + MAKE_CFUNC32(adagrad, float, 32) + MAKE_CFUNC32(adagrad, half, 16) + +#define MAKE_CFUNC8(name, gtype, gbits) \ + void c##name##_static_8bit_grad_##gbits(gtype *p, gtype *g, unsigned char *state1, unsigned char *state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float *quantiles1, float *quantiles2, \ + float *max1, float *max2, float *new_max1, float *new_max2, \ + float weight_decay, float gnorm_scale, int n) \ + { \ + name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + } + + MAKE_CFUNC8(adam, float, 32) + MAKE_CFUNC8(adam, half, 16) + MAKE_CFUNC8(momentum, float, 32) + MAKE_CFUNC8(momentum, half, 16) + MAKE_CFUNC8(rmsprop, float, 32) + MAKE_CFUNC8(rmsprop, half, 16) + MAKE_CFUNC8(lion, float, 32) + MAKE_CFUNC8(lion, half, 16) + +#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ + void c##fname##_8bit_blockwise_grad_##gbits(gtype *p, gtype *g, \ + unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr, \ + float *quantiles1, float *quantiles2, float *absmax1, float *absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ + { \ + fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); \ + } + + MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) + MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) + MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(lion, LION, half, fp16) + MAKE_CBLOCKWISE8(lion, LION, float, fp32) + MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) + + void cpercentile_clipping_g32(float *g, float *gnorm_vec, int step, const int n) { percentileClipping_g32(g, gnorm_vec, step, n); } + void cpercentile_clipping_g16(half *g, float *gnorm_vec, int step, const int n) { percentileClipping_g16(g, gnorm_vec, step, n); } + void chistogram_scatter_add_2d(float *histogram, int *index1, int *index2, float *src, int maxidx1, int n) { histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } + + void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) + { + gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); + } + void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, 
int lda, int ldb, int ldc, + long strideA, long strideB, long strideC, int batchCount) + { + strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); + } + + void get_context() + { + if (CUBLAS_CONTEXT == nullptr) + CUBLAS_CONTEXT = new Context(); + } + ContextCusparse *get_cusparse() { return new ContextCusparse(); } + + int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { + return igemmlt_turing_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + //{ (cublasLtHandle_t)context->m_handle; return 0; } + //{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { + return igemmlt_turing_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + + int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { + return igemmlt_turing_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + + int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { + return igemmlt_ampere_32((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + + int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { + return igemmlt_ampere_8_rowscale((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + + int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { + return igemmlt_ampere_8((cublasLtHandle_t)context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + } + +#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ + void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ + { \ + transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t)context->m_handle, A, out, dim1, dim2); \ + } + + MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8) + MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32) + MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8) + MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) + MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) + + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float *newRowStats, float *newcolStats, half *bias, int numRows, int numCols) + { + dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); + } + void cget_col_row_stats(half *A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, 
int cols) + { + getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); + } + + void cdouble_rowcol_quant(half *A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols) + { + doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); + } + + void ctransform_row2col32(char *A, char *out, int rows, int cols) + { + transform_row2col32(A, out, rows, cols); + } + + void ctransform_row2col32T(char *A, char *out, int rows, int cols) + { + transform_row2col32T(A, out, rows, cols); + } + + void ctransform_row2turing(char *A, char *out, int rows, int cols) + { + transform_row2turing(A, out, rows, cols); + } + + void ctransform_row2turingT(char *A, char *out, int rows, int cols) + { + transform_row2turingT(A, out, rows, cols); + } + + void ctransform_row2ampere(char *A, char *out, int rows, int cols) + { + transform_row2ampere(A, out, rows, cols); + } + + void ctransform_row2ampereT(char *A, char *out, int rows, int cols) + { + transform_row2ampereT(A, out, rows, cols); + } + + void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half *C, bool transposed_B) + { + spmm_coo((cusparseHandle_t)context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); + } + + void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) + { + spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); + } + + void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) + { + spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); + } + + void cextractOutliers_turing(char *A, int *idx, char *out, int idx_size, int rows, int cols) { extractOutliers_turing(A, idx, out, idx_size, rows, cols); } + void cextractOutliers_ampere(char *A, int *idx, char *out, int idx_size, int rows, int cols) { extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + + // void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) + //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } + + void cgemm_host_fp16(int M, int N, int K, half *A, half *B, half *out, int lda, int ldb, int ldc) + { + gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); + } + + void cgemm_4bit_inference(int m, int n, int k, half *A, unsigned char *B, float *absmax, half *out, int lda, int ldb, int ldc, int blocksize) + { + gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + } + + void *cget_managed_ptr(size_t bytes) + { + void *ptr; + CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + return ptr; + } + + void cprefetch(void *ptr, size_t bytes, int device) + { + + int hasPrefetch = 0; + CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, 
cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead + if (hasPrefetch == 0) + return; + + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + +#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n) { fname##_##type_name(A, B, value, n); } + + CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) + CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) + CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) + CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + + void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half *A, unsigned char *B, float *absmax, float *datatype, half *out, int lda, int ldb, int ldc, int blocksize) + { + gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + } + + void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 *A, unsigned char *B, float *absmax, float *datatype, __nv_bfloat16 *out, int lda, int ldb, int ldc, int blocksize) + { + gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + } + + void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float *A, unsigned char *B, float *absmax, float *datatype, float *out, int lda, int ldb, int ldc, int blocksize) + { + gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + } + +#endif + + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) { quantize_cpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { dequantize_cpu(code, A, absmax, out, blocksize, n); } +} + +extern "C" +{ +#ifdef BUILD_CUDA + + int custom_cget_col_row_stats(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 8) + return 1; + + half *A = static_cast(params[0]); + float *rowStats = static_cast(params[1]); + float *colStats = static_cast(params[2]); + + void *nnz_threshold_ptr, *rows_ptr, *cols_ptr; + cudaMallocHost(&nnz_threshold_ptr, sizeof(float)); + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + + int *nnz_count_row = static_cast(params[3]); + cudaMemcpy(nnz_threshold_ptr, params[4], sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(rows_ptr, params[5], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[6], sizeof(int), cudaMemcpyDeviceToHost); + + auto nnz_threshold = *static_cast(nnz_threshold_ptr); + auto rows = *static_cast(rows_ptr); + auto cols = *static_cast(cols_ptr); + + getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); + + return 0; + } + + int custom_cdouble_rowcol_quant(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 13) + return 1; + + half *A = static_cast(params[0]); + float *rowStats = static_cast(params[1]); + float *colStats = static_cast(params[2]); + char *out_col_normed = static_cast(params[3]); + char *out_row_normed = static_cast(params[4]); + + void *threshold_ptr, *rows_ptr, *cols_ptr; + cudaMallocHost(&threshold_ptr, sizeof(float)); + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + + int *rowidx = static_cast(params[5]); + int *colidx = static_cast(params[6]); + 
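+        // params[5]..[8] are the COO outlier buffers (rowidx, colidx, val, nnz_row_ptr), while
+        // params[9]..[11] carry the scalar threshold/rows/cols; the scalars are staged through
+        // pinned host buffers below before doubleRowColQuant is called.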
half *val = static_cast(params[7]); + int *nnz_row_ptr = static_cast(params[8]); + cudaMemcpy(threshold_ptr, params[9], sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(rows_ptr, params[10], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[11], sizeof(int), cudaMemcpyDeviceToHost); + + auto threshold = *static_cast(threshold_ptr); + auto rows = *static_cast(rows_ptr); + auto cols = *static_cast(cols_ptr); + + doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); + + return 0; + } + + int custom_ctransform_row2col32T(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 5) + return 1; + + char *A = static_cast(params[0]); + char *out = static_cast(params[1]); + void *rows_ptr, *cols_ptr; + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + cudaMemcpy(rows_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + transform_row2col32T(A, out, rows, cols); + + return 0; + } + + int custom_ctransform_row2col32(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 5) + return 1; + + char *A = static_cast(params[0]); + char *out = static_cast(params[1]); + void *rows_ptr, *cols_ptr; + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + cudaMemcpy(rows_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + transform_row2col32(A, out, rows, cols); + + return 0; + } + + int custom_ctransform_row2turingT(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 5) + return 1; + + char *A = static_cast(params[0]); + char *out = static_cast(params[1]); + void *rows_ptr, *cols_ptr; + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + cudaMemcpy(rows_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + transform_row2turingT(A, out, rows, cols); + + return 0; + } + + int custom_ctransform_row2turing(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 5) + return 1; + + char *A = static_cast(params[0]); + char *out = static_cast(params[1]); + void *rows_ptr, *cols_ptr; + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + cudaMemcpy(rows_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + transform_row2turing(A, out, rows, cols); + + return 0; + } + + int custom_ctransform_row2ampereT(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 5) + return 1; + + char *A = static_cast(params[0]); + char *out = static_cast(params[1]); + void *rows_ptr, *cols_ptr; + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + 
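+        // The row/col scalars arrive as device-resident tensors; copy them into the pinned host
+        // buffers so the layout-transform wrapper can be invoked with plain ints.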
cudaMemcpy(rows_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + transform_row2ampereT(A, out, rows, cols); + + return 0; + } + + int custom_ctransform_row2ampere(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 5) + return 1; + + char *A = static_cast(params[0]); + char *out = static_cast(params[1]); + void *rows_ptr, *cols_ptr; + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + cudaMemcpy(rows_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + transform_row2ampere(A, out, rows, cols); + + return 0; + } + + + int custom_cextractOutliers_turing (int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 7) + return 1; + + char *A = static_cast(params[0]); + int *idx = static_cast(params[1]); + char *out = static_cast(params[2]); + void *idx_size_ptr, *rows_ptr, *cols_ptr; + cudaMallocHost(&idx_size_ptr, sizeof(int)); + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + cudaMemcpy(idx_size_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(rows_ptr, params[4], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[5], sizeof(int), cudaMemcpyDeviceToHost); + int idx_size = *static_cast(idx_size_ptr); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + extractOutliers_turing(A, idx, out, idx_size, rows, cols); + + return 0; + } + + int custom_cextractOutliers_ampere (int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 7) + return 1; + + char *A = static_cast(params[0]); + int *idx = static_cast(params[1]); + char *out = static_cast(params[2]); + void *idx_size_ptr, *rows_ptr, *cols_ptr; + cudaMallocHost(&idx_size_ptr, sizeof(int)); + cudaMallocHost(&rows_ptr, sizeof(int)); + cudaMallocHost(&cols_ptr, sizeof(int)); + cudaMemcpy(idx_size_ptr, params[3], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(rows_ptr, params[4], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(cols_ptr, params[5], sizeof(int), cudaMemcpyDeviceToHost); + int idx_size = *static_cast(idx_size_ptr); + int rows = *static_cast(rows_ptr); + int cols = *static_cast(cols_ptr); + + extractOutliers_ampere(A, idx, out, idx_size, rows, cols); + + return 0; + } + + int custom_cigemmlt_turing_32(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 12) + return 1; + + get_context(); + void *m_ptr, *n_ptr, *k_ptr, *lda_ptr, *ldb_ptr, *ldc_ptr; + cudaMallocHost(&m_ptr, sizeof(int)); + cudaMallocHost(&n_ptr, sizeof(int)); + cudaMallocHost(&k_ptr, sizeof(int)); + cudaMemcpy(m_ptr, params[0], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(n_ptr, params[1], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(k_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + int m = *static_cast(m_ptr); + int n = *static_cast(n_ptr); + int k = *static_cast(k_ptr); + int8_t *A = static_cast(params[3]); + int8_t *B = static_cast(params[4]); + void *C = params[5]; + float *row_scale = static_cast(params[6]); + cudaMallocHost(&lda_ptr, 
sizeof(int)); + cudaMallocHost(&ldb_ptr, sizeof(int)); + cudaMallocHost(&ldc_ptr, sizeof(int)); + cudaMemcpy(lda_ptr, params[7], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldb_ptr, params[8], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldc_ptr, params[9], sizeof(int), cudaMemcpyDeviceToHost); + int lda = *static_cast(lda_ptr); + int ldb = *static_cast(ldb_ptr); + int ldc = *static_cast(ldc_ptr); + int has_error = igemmlt_turing_32((cublasLtHandle_t)CUBLAS_CONTEXT->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + params[10] = static_cast(reinterpret_cast(static_cast(has_error))); + + return 0; + } + + int custom_cigemmlt_turing_8(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 12) + return 1; + + get_context(); + void *m_ptr, *n_ptr, *k_ptr, *lda_ptr, *ldb_ptr, *ldc_ptr; + cudaMallocHost(&m_ptr, sizeof(int)); + cudaMallocHost(&n_ptr, sizeof(int)); + cudaMallocHost(&k_ptr, sizeof(int)); + cudaMemcpy(m_ptr, params[0], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(n_ptr, params[1], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(k_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + int m = *static_cast(m_ptr); + int n = *static_cast(n_ptr); + int k = *static_cast(k_ptr); + int8_t *A = static_cast(params[3]); + int8_t *B = static_cast(params[4]); + void *C = params[5]; + float *row_scale = static_cast(params[6]); + cudaMallocHost(&lda_ptr, sizeof(int)); + cudaMallocHost(&ldb_ptr, sizeof(int)); + cudaMallocHost(&ldc_ptr, sizeof(int)); + cudaMemcpy(lda_ptr, params[7], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldb_ptr, params[8], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldc_ptr, params[9], sizeof(int), cudaMemcpyDeviceToHost); + int lda = *static_cast(lda_ptr); + int ldb = *static_cast(ldb_ptr); + int ldc = *static_cast(ldc_ptr); + int has_error = igemmlt_turing_8((cublasLtHandle_t)CUBLAS_CONTEXT->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + params[10] = static_cast(reinterpret_cast(static_cast(has_error))); + + return 0; + } + + int custom_cigemmlt_ampere_32(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 12) + return 1; + + get_context(); + void *m_ptr, *n_ptr, *k_ptr, *lda_ptr, *ldb_ptr, *ldc_ptr; + cudaMallocHost(&m_ptr, sizeof(int)); + cudaMallocHost(&n_ptr, sizeof(int)); + cudaMallocHost(&k_ptr, sizeof(int)); + cudaMemcpy(m_ptr, params[0], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(n_ptr, params[1], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(k_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + int m = *static_cast(m_ptr); + int n = *static_cast(n_ptr); + int k = *static_cast(k_ptr); + int8_t *A = static_cast(params[3]); + int8_t *B = static_cast(params[4]); + void *C = params[5]; + float *row_scale = static_cast(params[6]); + cudaMallocHost(&lda_ptr, sizeof(int)); + cudaMallocHost(&ldb_ptr, sizeof(int)); + cudaMallocHost(&ldc_ptr, sizeof(int)); + cudaMemcpy(lda_ptr, params[7], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldb_ptr, params[8], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldc_ptr, params[9], sizeof(int), cudaMemcpyDeviceToHost); + int lda = *static_cast(lda_ptr); + int ldb = *static_cast(ldb_ptr); + int ldc = *static_cast(ldc_ptr); + int has_error = igemmlt_ampere_32((cublasLtHandle_t)CUBLAS_CONTEXT->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + params[10] = static_cast(reinterpret_cast(static_cast(has_error))); + + 
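+        // The igemmlt status is handed back through params[10], encoded as a pointer-sized integer;
+        // the function's own return value only signals an argument-count mismatch.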
return 0; + } + + int custom_cigemmlt_ampere_8(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 12) + return 1; + + get_context(); + void *m_ptr, *n_ptr, *k_ptr, *lda_ptr, *ldb_ptr, *ldc_ptr; + cudaMallocHost(&m_ptr, sizeof(int)); + cudaMallocHost(&n_ptr, sizeof(int)); + cudaMallocHost(&k_ptr, sizeof(int)); + cudaMemcpy(m_ptr, params[0], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(n_ptr, params[1], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(k_ptr, params[2], sizeof(int), cudaMemcpyDeviceToHost); + int m = *static_cast(m_ptr); + int n = *static_cast(n_ptr); + int k = *static_cast(k_ptr); + int8_t *A = static_cast(params[3]); + int8_t *B = static_cast(params[4]); + void *C = params[5]; + float *row_scale = static_cast(params[6]); + cudaMallocHost(&lda_ptr, sizeof(int)); + cudaMallocHost(&ldb_ptr, sizeof(int)); + cudaMallocHost(&ldc_ptr, sizeof(int)); + cudaMemcpy(lda_ptr, params[7], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldb_ptr, params[8], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(ldc_ptr, params[9], sizeof(int), cudaMemcpyDeviceToHost); + int lda = *static_cast(lda_ptr); + int ldb = *static_cast(ldb_ptr); + int ldc = *static_cast(ldc_ptr); + int has_error = igemmlt_ampere_8((cublasLtHandle_t)CUBLAS_CONTEXT->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); + params[10] = static_cast(reinterpret_cast(static_cast(has_error))); + + return 0; + } + + int custom_cdequant_mm_int32_fp16(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 10) + return 1; + int *A = static_cast(params[0]); + float *rowStats = static_cast(params[1]); + float *colStats = static_cast(params[2]); + half *out = static_cast(params[3]); + float *newRowStats = static_cast(params[4]); + float *newcolStats = static_cast(params[5]); + void *row_ptr, *col_ptr; + cudaMallocHost(&row_ptr, sizeof(int)); + cudaMallocHost(&col_ptr, sizeof(int)); + half *bias = static_cast(params[6]); + cudaMemcpy(row_ptr, params[7], sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(col_ptr, params[8], sizeof(int), cudaMemcpyDeviceToHost); + + int numRows = *static_cast(row_ptr); + int numCols = *static_cast(col_ptr); + dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); + + return 0; + } + +#endif + + int custom_cquantize_blockwise_cpu_fp32(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, void *extra) + { + if (nparam != 6) + return 1; + float *code = static_cast(params[0]); + float *A = static_cast(params[1]); + float *absmax = static_cast(params[2]); + unsigned char *out = static_cast(params[3]); + long long blocksize = *(long long *)params[4]; + long long n = *(long long *)params[5]; + quantize_cpu(code, A, absmax, out, blocksize, n); + return 0; + } +} diff --git a/mindnlp/quant/mindbnb/include/AAlloc.h b/mindnlp/quant/mindbnb/include/AAlloc.h new file mode 100644 index 000000000..6c2ae419f --- /dev/null +++ b/mindnlp/quant/mindbnb/include/AAlloc.h @@ -0,0 +1,86 @@ +#pragma once + +#include "Portable.h" + +namespace BinSearch { +namespace Details { + +template +bool isAligned(const T *p, size_t A) +{ + return (reinterpret_cast(p) % A) == 0; +} + +template +struct AlignedVec +{ + AlignedVec() + : m_storage(0) + , m_data(0) + , m_sz(0) + { + } + + static size_t nBytes(size_t sz) + { + return sz * sizeof(T) + A; + } + + static size_t shiftAmt(char *p) + { + return A>1? 
(A - (reinterpret_cast(p) % A)) % A: 0; + } + + void setPtr(char *p, size_t sz) + { + m_sz = sz; + m_data = reinterpret_cast(p + shiftAmt(p)); + } + + //void setPtr(T *p, size_t sz) + //{ + // m_sz = sz; + // if (A>1) + // myassert(((reinterpret_cast(p) % A) == 0), "bad alignment"); + // m_data = p; + //} + + // internal allocation + void resize(size_t sz) + { + m_storage = new char[nBytes(sz)]; + setPtr(m_storage, sz); + } + + // external allocation + void set(char *storage, size_t sz) + { + setPtr(storage, sz); + } + + ~AlignedVec() + { + if (m_storage) + delete [] m_storage; + } + + size_t size() const { return m_sz; } + T& operator[](size_t i) { return m_data[i]; } + const T& operator[](size_t i) const { return m_data[i]; } + T* begin() { return m_data; } + T* end() { return m_data+m_sz; } + const T* begin() const { return m_data; } + const T* end() const { return m_data+m_sz; } + T& front() { return m_data[0]; } + T& back() { return m_data[m_sz-1]; } + const T& front() const { return m_data[0]; } + const T& back() const { return m_data[m_sz - 1]; } + +private: + char *m_storage; + T *m_data; + size_t m_sz; +}; + +} // namespace Details +} // namespace BinSearch diff --git a/mindnlp/quant/mindbnb/include/Algo-Direct-Common.h b/mindnlp/quant/mindbnb/include/Algo-Direct-Common.h new file mode 100644 index 000000000..7b40edea9 --- /dev/null +++ b/mindnlp/quant/mindbnb/include/Algo-Direct-Common.h @@ -0,0 +1,341 @@ +#pragma once + +#include +#include +#include +#include "AAlloc.h" + +namespace BinSearch { +namespace Details { + +namespace DirectAux { + +#define SAFETY_MULTI_PASS true + +template +struct HResults +{ + HResults(T h, double ratio, size_t n) : H(h), hRatio(ratio), nInc(n) {} + T H; + double hRatio; + size_t nInc; +}; + + +#ifdef USE_FMA +template struct IsDirect { static const bool value = (A == Direct) || (A == DirectFMA); }; +template struct IsDirect2 { static const bool value = (A == Direct2) || (A == Direct2FMA); }; +template struct IsDirectCache { static const bool value = (A == DirectCache) || (A == DirectCacheFMA); }; +#else +template struct IsDirect { static const bool value = (A == Direct); }; +template struct IsDirect2 { static const bool value = (A == Direct2); }; +template struct IsDirectCache { static const bool value = (A == DirectCache); }; +#endif + +// general definition +template +struct BucketElem +{ + FORCE_INLINE void set( uint32 b, const T *) + { + m_b = b; + } + + FORCE_INLINE uint32 index() const { return m_b; } + +private: + uint32 m_b; +}; + +// specialization for DirectCache methods + +template struct MatchingIntType; +template <> struct MatchingIntType { typedef uint64 type; }; +template <> struct MatchingIntType { typedef uint32 type; }; + +template +struct BucketElem::value >::type > +{ + typedef typename MatchingIntType::type I; + + void set(uint32 b, const T *xi) + { + u.u.x = xi[b]; + u.u.b = b; + } + + FORCE_INLINE I index() const { return u.u.b; } + FORCE_INLINE T x() const { return u.u.x; } + +private: + union { + double dummy; + struct + { + T x; + I b; + } u; + } u; +}; + + +template +struct DirectTraits +{ + static void checkH(T scaler, T x0, T xN) + { + T Dn = xN - x0; + T ifmax = Dn * scaler; + myassert((ifmax < std::numeric_limits::max() - (Gap - 1)), + "Problem unfeasible: index size exceeds uint32 capacity:" + << " D[N] =" << Dn + << ", H =" << scaler + << ", H D[n] =" << ifmax << "\n" + ); + } + + FORCE_INLINE static uint32 f(T scaler, T x0, T z) + { + T tmp = scaler * (z - x0); +#ifdef USE_SSE2 + return ftoi(FVec1(tmp)); +#else + 
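+        // Non-SSE2 fallback: convert the scaled offset to a bucket index with a plain truncating cast
+        // instead of the single-lane ftoi() used in the SSE path above.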
return static_cast(tmp); +#endif + } + + template + FORCE_INLINE static typename FTOITraits::vec_t f(const FVec& scaler, const FVec& x0, const FVec& z) + { + return ftoi(scaler*(z-x0)); + } + + static T cst0(T scaler, T x0) + { + return x0; + } +}; + +#ifdef USE_FMA +template +struct DirectTraits +{ + typedef FVec1 fVec1; + + static void checkH(T scaler, T H_Times_x0, T xN) + { + union { + typename FVec1::vec_t v; + T s; + } ifmax; + ifmax.v = mulSub(fVec1(scaler), fVec1(xN), fVec1(H_Times_x0)); + myassert((ifmax.s < std::numeric_limits::max() - (Gap - 1)), + "Problem unfeasible: index size exceeds uint32 capacity:" + << " H X[0] =" << H_Times_x0 + << ", H =" << scaler + << ", X[N] =" << xN + << ", H X[N] - H X[0] =" << ifmax.s << "\n" + ); + } + + FORCE_INLINE static uint32 f(T scaler, T Hx0, T xi) + { + return ftoi(mulSub(fVec1(scaler), fVec1(xi), fVec1(Hx0))); + } + + template + FORCE_INLINE static typename FTOITraits::vec_t f(const FVec& scaler, const FVec& H_Times_X0, const FVec& z) + { + return ftoi(mulSub(scaler, z, H_Times_X0)); + } + + static T cst0(T scaler, T x0) + { + return scaler*x0; + } +}; +#endif + +template +struct DirectInfo +{ + static const bool UseFMA = (A == DirectFMA) || (A == Direct2FMA) || (A == DirectCacheFMA); + typedef DirectTraits fun_t; + typedef BucketElem bucket_t; + typedef AlignedVec bucketvec_t; + + struct Data { + Data() : buckets(0), xi(0), scaler(0), cst0(0) {} + Data( const T *x // for Direct must persist if xws=NULL + , uint32 n + , T H + , bucket_t *bws // assumed to gave size nb, as computed below + , T *xws = NULL // assumed to have size (n+Gap-1). Optional for Direct, unused for DirectCache, required for DirectGap + ) + : buckets(bws) + , scaler(H) + , cst0(fun_t::cst0(H, x[0])) + { + myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned"); + + uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]); + + const uint32 npad = Gap-1; + const uint32 n_sz = n + npad; // size of padded vector + + if (xws) { + myassert(isAligned(xws,8), "x pointer not allocated or incorrectly aligned"); + std::fill_n(xws, npad, x[0]); // pad in front with x[0] + std::copy(x, x+n, xws + npad); + xi = xws; + } + else { + myassert((Gap==1), "if Gap>1 then X workspace must be provided"); + xi = x; + } + + populateIndex(bws, nb, xi, n_sz, scaler, cst0); + } + + const bucket_t *buckets; + const T *xi; + T scaler; + T cst0; // could be x0 or (scaler*x0), depending if we are using FMA or not + } data; + + static T growStep(T H) + { + T step; + T P = next(H); + while ((step = P - H) == 0) + P = next(P); + return step; + } + + static HResults computeH(const T *px, uint32 nx) + { + myassert((nx > Gap), "Array X too small"); + myassert(((Gap == 1) || (Gap == 2)), "Only tested for these values of Gap"); + + const T x0 = px[0]; + const T xN = px[nx-1]; + + const T range = xN - x0; + myassert((range < std::numeric_limits::max()), "range too large"); + + // check that D_i are strictly increasing and compute minimum value D_{i+Offset}-D_i + T deltaDMin = range; + for (uint32 i = Gap; i < nx; ++i) { + T Dnew = px[i] - x0; + T Dold = px[i - Gap] - x0; + myassert((Dnew > Dold), + "Problem unfeasible: D_i sequence not strictly increasing" + << " X[" << 0 << "]=" << x0 + << " X[" << i - Gap << "]=" << px[i - Gap] + << " X[" << i << "]=" << px[i] + << "\n" + ); + T deltaD = Dnew - Dold; + if (deltaD < deltaDMin) + deltaDMin = deltaD; + } + + // initial guess for H + const T H0 = T(1.0) / deltaDMin; + T H = H0; + + T cst0 = fun_t::cst0(H, x0); + 
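+        // H starts at H0 = 1/min(D[i+Gap]-D[i]) and is grown in the loop below until every pair of
+        // abscissae Gap apart maps to distinct bucket indices.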
fun_t::checkH(H, cst0, xN); + + // adjust H by trial and error until succeed + size_t nInc = 0; + bool modified = false; + size_t npasses = 0; + T step = growStep(H); + uint32 seg_already_checked_from = nx; + do { + myassert((npasses++ < 2), "verification failed\n"); + // if there has been an increase, then check only up to that point + uint32 last_seg_to_be_checked = seg_already_checked_from - 1; + modified = false; + uint32 inew = 0; + for (uint32 i = Gap; i <= last_seg_to_be_checked; ++i) { + uint32 iold = fun_t::f(H, cst0, px[i-Gap]); + uint32 inew = fun_t::f(H, cst0, px[i]); + while (inew == iold) { + seg_already_checked_from = i; + last_seg_to_be_checked = nx-1; // everything needs to be checked + modified = true; + H = H + step; + step *= 2; + // recalculate all constants and indices + cst0 = fun_t::cst0(H, x0); + fun_t::checkH(H, cst0, xN); + iold = fun_t::f(H, cst0, px[i - Gap]); + inew = fun_t::f(H, cst0, px[i]); + } + } + } while (SAFETY_MULTI_PASS && modified); + + return HResults(H, (((double)H) / H0) - 1.0, nInc); + } + + static void populateIndex(BucketElem *buckets, uint32 index_size, const T *px, uint32 x_size, T scaler, T cst0) + { + for (uint32 i = x_size-1, b = index_size-1, j=0; ; --i) { + uint32 idx = fun_t::f(scaler, cst0, px[i]); + while (b > idx) { // in the 1st iteration it is j=0 but this condition is always false + buckets[b].set( j, px ); + --b; + } + if (Gap==1 || b == idx) { // if Gap==1, which is known at compile time, the check b==idx is redundant + j = i - (Gap-1); // subtracting (Gap-1) points to the index of the first X-element to check + buckets[b].set(j, px); + if (b-- == 0) + break; + } + } + } + + DirectInfo(const Data& d) + : data(d) + { + } + + DirectInfo(const T* px, const uint32 n) + { + HResults res = computeH(px, n); + +#ifdef PAPER_TEST + nInc = res.nInc; + hRatio = res.hRatio; +#endif + const uint32 npad = Gap-1; + const uint32 n_sz = n + npad; // size of padded vector + + if (npad) + xi.resize(n_sz); + + T H = res.H; + T cst0 = fun_t::cst0(H, px[0]); + const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]); + buckets.resize(maxIndex + 1); + + data = Data(px, n, H, buckets.begin(), (npad? 
xi.begin(): NULL)); + } + +private: + bucketvec_t buckets; + AlignedVec xi; + +#ifdef PAPER_TEST +public: + double hRatio; + size_t nInc; +#endif +}; + + +} // namespace DirectAux +} // namespace Details +} // namespace BinSearch diff --git a/mindnlp/quant/mindbnb/include/Algo-Direct2.h b/mindnlp/quant/mindbnb/include/Algo-Direct2.h new file mode 100644 index 000000000..547ca9955 --- /dev/null +++ b/mindnlp/quant/mindbnb/include/Algo-Direct2.h @@ -0,0 +1,307 @@ +#pragma once + +#include "Algo-Direct-Common.h" + +namespace BinSearch { +namespace Details { + +template +struct AlgoScalarBase::value>::type> : DirectAux::DirectInfo<2, T, A> +{ +private: + typedef DirectAux::DirectInfo<2, T, A> base_t; + static const size_t Offset=2; + +public: + AlgoScalarBase(const T* x, const uint32 n) + : base_t(x, n) + { + } + + FORCE_INLINE uint32 scalar(T z) const + { + const T* px = base_t::data.xi; + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z); + uint32 iidx = buckets[bidx]; + px += iidx; + if (z < *px) + --iidx; + if (z < *(px+1)) + --iidx; + return iidx; + } +}; + + +template +struct AlgoVecBase::value>::type> : AlgoScalarBase +{ + static const uint32 nElem = sizeof(typename InstrFloatTraits::vec_t) / sizeof(T); + + typedef FVec fVec; + typedef IVec i128; + + struct Constants + { + fVec vscaler; + fVec vcst0; + IVec one; + }; + +private: + typedef AlgoScalarBase base_t; + +#ifdef USE_SSE2 + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + union U { + __m128i vec; + uint32 ui32[4]; + } u; + + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const float *xi = base_t::data.xi; + + // read indices t + const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); + const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); + const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); + const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); + +#if 0 + // read pairs ( X(t-1), X(t) ) + __m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3)); + __m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2)); + __m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1)); + __m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0)); + + // build: + // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } + // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } + __m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6)); + __m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6)); + __m128 u01 = _mm_unpacklo_ps(h02, h13); + __m128 u23 = _mm_unpackhi_ps(h02, h13); + __m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); + __m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); +#else + __m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2)); + __m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0)); + __m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6)); + __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6)); +#endif + IVec i(u.vec); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + i = i + vlem + vlep; + i.store(pr); + } + + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const double *xi = base_t::data.xi; + + uint32 b1 = buckets[bidx.get1()]; + uint32 b0 = buckets[bidx.get0()]; + + const double *p1 = &xi[b1]; + const double *p0 = &xi[b0]; + + // read pairs ( 
X(t-1), X(t) ) + __m128d vx1 = _mm_loadu_pd(p1); + __m128d vx0 = _mm_loadu_pd(p0); + + // build: + // { X(t(0)-1), X(t(1)-1) } + // { X(t(0)), X(t(1)) } + __m128d vxm = _mm_shuffle_pd(vx0, vx1, 0); + __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3); + + IVec i(b1, b0); + IVec vlem = (vz < vxm); + IVec vlep = (vz < vxp); + i = i + vlem + vlep; + + union { + __m128i vec; + uint32 ui32[4]; + } u; + u.vec = i; + pr[0] = u.ui32[0]; + pr[1] = u.ui32[2]; + } +#endif // USE_SSE2 + +#ifdef USE_AVX + + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const float *xi = base_t::data.xi; + +#if 0 // use gather instructions + + IVec idxm; + idxm.setidx(buckets, bidx); + __m256i z = _mm256_setzero_si256(); + IVec minusone = _mm256_cmpeq_epi32(z,z); + IVec idxp = idxm - minusone; + + FVec vxm = _mm256_i32gather_ps(xi, idxm, sizeof(float)); + FVec vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float)); + IVec ip = idxm; + +#else // do not use gather instructions + + union U { + __m256i vec; + uint32 ui32[8]; + } u; + + // read indices t + + const double *p7 = reinterpret_cast(&xi[(u.ui32[7] = buckets[bidx.get7()])]); + const double *p6 = reinterpret_cast(&xi[(u.ui32[6] = buckets[bidx.get6()])]); + const double *p5 = reinterpret_cast(&xi[(u.ui32[5] = buckets[bidx.get5()])]); + const double *p4 = reinterpret_cast(&xi[(u.ui32[4] = buckets[bidx.get4()])]); + const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); + const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); + const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); + const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); + +#if 0 // perform 8 loads in double precision + + // read pairs ( X(t-1), X(t) ) + __m128 xp7 = _mm_castpd_ps(_mm_load_sd(p7)); + __m128 xp6 = _mm_castpd_ps(_mm_load_sd(p6)); + __m128 xp5 = _mm_castpd_ps(_mm_load_sd(p5)); + __m128 xp4 = _mm_castpd_ps(_mm_load_sd(p4)); + __m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3)); + __m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2)); + __m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1)); + __m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0)); + + // build: + // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } + // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } + __m128 h57 = _mm_shuffle_ps(xp5, xp7, (1 << 2) + (1 << 6)); // F- F+ H- H+ + __m128 h46 = _mm_shuffle_ps(xp4, xp6, (1 << 2) + (1 << 6)); // E- E+ G- G+ + __m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6)); // B- B+ D- D+ + __m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6)); // A- A+ C- C+ + + __m128 u01 = _mm_unpacklo_ps(h02, h13); // A- B- A+ B+ + __m128 u23 = _mm_unpackhi_ps(h02, h13); // C- D- C+ D+ + __m128 u45 = _mm_unpacklo_ps(h46, h57); // E- F- E+ F+ + __m128 u67 = _mm_unpackhi_ps(h46, h57); // G- H- G+ H+ + + __m128 abcdm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // A- B- C- D- + __m128 abcdp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // A+ B+ C+ D+ + __m128 efghm = _mm_shuffle_ps(u45, u67, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // E- F- G- H- + __m128 efghp = _mm_shuffle_ps(u45, u67, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // E+ F+ G+ H+ + + FVec vxp = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdm), efghm, 1); + FVec vxm = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdp), efghp, 1); + + IVec ip(u.vec); + +#else // use __mm256_set_pd + + // read pairs ( X(t-1), X(t) ) + __m256 x0145 = 
_mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) } + __m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) } + + // { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) } + FVec vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6) ); + // { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) } + FVec vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6) ); + + IVec ip(u.vec); + +#endif + +#endif + + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + ip = ip + vlem + vlep; + + ip.store(pr); + } + + + + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + union { + __m256i vec; + uint64 ui64[4]; + } u; + + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const double *xi = base_t::data.xi; + + // read indices t + const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])]; + const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])]; + const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])]; + const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])]; + + // read pairs ( X(t-1), X(t) ) + __m128d xp3 = _mm_loadu_pd(p3); + __m128d xp2 = _mm_loadu_pd(p2); + __m128d xp1 = _mm_loadu_pd(p1); + __m128d xp0 = _mm_loadu_pd(p0); + + // build: + // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } + // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } + __m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1); + __m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1); + FVec vxm = _mm256_unpacklo_pd(x02,x13); + FVec vxp = _mm256_unpackhi_pd(x02,x13); + + +// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0); +// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0); +// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3); +// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3); +// FVec vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1); +// FVec vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1); + + IVec i(u.vec); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + i = i + vlem + vlep; + i.extractLo32s().store(pr); + } +#endif + +public: + + AlgoVecBase(const T* x, const uint32 n) : base_t(x, n) {} + + void initConstants(Constants& cst) const + { + cst.vscaler.setN(base_t::data.scaler); + cst.vcst0.setN(base_t::data.cst0); + cst.one.setN(uint32(1)); + } + + void vectorial(uint32 *pr, const T *pz, const Constants& cst) const + { + fVec vz(pz); + resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr); + } +}; +} // namespace Details +} // namespace BinSearch diff --git a/mindnlp/quant/mindbnb/include/AlgoXCodes.h b/mindnlp/quant/mindbnb/include/AlgoXCodes.h new file mode 100644 index 000000000..bdc9b00b6 --- /dev/null +++ b/mindnlp/quant/mindbnb/include/AlgoXCodes.h @@ -0,0 +1,23 @@ +ALGOENUM(DirectCacheFMA, 5) +ALGOENUM(DirectFMA, 15) +ALGOENUM(Direct2FMA, 25) +ALGOENUM(DirectCache, 10) +ALGOENUM(Direct, 20) +ALGOENUM(Direct2, 30) +ALGOENUM(Nonary, 40) +ALGOENUM(Pentary, 50) +ALGOENUM(Ternary, 60) +ALGOENUM(Eytzinger, 70) +ALGOENUM(BitSet, 80) +ALGOENUM(ClassicOffset, 90) +#ifdef PAPER_TEST +ALGOENUM(MorinOffset, 100) +ALGOENUM(BitSetNoPad, 110) +ALGOENUM(ClassicMod, 120) +ALGOENUM(MorinBranchy, 130) +ALGOENUM(Classic, 140) +ALGOENUM(LowerBound, 145) +#ifdef USE_MKL +ALGOENUM(MKL, 150) +#endif +#endif diff --git a/mindnlp/quant/mindbnb/include/BinAlgo.h b/mindnlp/quant/mindbnb/include/BinAlgo.h new file mode 100644 
index 000000000..aac67a0c7 --- /dev/null +++ b/mindnlp/quant/mindbnb/include/BinAlgo.h @@ -0,0 +1,77 @@ +#pragma once + +#include "Type.h" +#include + +namespace BinSearch { + +template +struct BinAlgo : Details::BinAlgoBase +{ + typedef Details::BinAlgoBase base_t; + + BinAlgo(const T* px, const uint32 n) : base_t(px, n), x0(px[0]), xN(px[n-1]), N(n) {} + BinAlgo(const T* px, const uint32 n, const typename base_t::Data& d) : base_t(d), x0(px[0]), xN(px[n-1]), N(n) {} + + FORCE_INLINE + uint32 scalar(T z) const + { + if (!L || z >= x0) + if (!R || z < xN) + return base_t::scalar(z); + else + return N; + else + return std::numeric_limits::max(); + } + + + FORCE_INLINE + void vectorial(uint32 *pr, const T *pz, uint32 n) const + { + if (!L && !R) { + Details::Loop::loop(*this, pr, pz, n); + } + else { + const uint32 nElem = base_t::nElem; + const uint32 idealbufsize = 256; + const uint32 bufsize = nElem * (idealbufsize / nElem + ((idealbufsize % nElem) ? 1 : 0)); + T databuf[bufsize]; + uint32 resbuf[bufsize]; + uint32 indexbuf[bufsize]; + + uint32 *prend = pr + n; + while(pr != prend) { + uint32 cnt = 0; + uint32 niter = std::min(bufsize, (uint32)std::distance(pr,prend)); + for (uint32 j = 0; j < niter; ++j) { + T z = pz[j]; + // FIXME: use SSE2? + if (!L || z >= x0) + if (!R || z < xN) { + databuf[cnt] = z; + indexbuf[cnt] = j; + ++cnt; + } + else + pr[j] = N; + else + pr[j] = std::numeric_limits::max(); + } + // FIXME: merge these two loops + Details::Loop::loop(*this, resbuf, databuf, cnt); + for (uint32 j = 0; j < cnt; ++j) + pr[indexbuf[j]] = resbuf[j]; + pr += niter; + pz += niter; + } + } + } + + Details::CondData x0; + Details::CondData xN; + Details::CondData N; +}; + + +} // namespace BinSearch diff --git a/mindnlp/quant/mindbnb/include/BinSearch.h b/mindnlp/quant/mindbnb/include/BinSearch.h new file mode 100644 index 000000000..336f52963 --- /dev/null +++ b/mindnlp/quant/mindbnb/include/BinSearch.h @@ -0,0 +1,11 @@ +#pragma once + +#include "AAlloc.h" +#include "BinAlgo.h" +#include "SIMD.h" + +#include +#include + + +#include "Algo-Direct2.h" diff --git a/mindnlp/quant/mindbnb/include/Portable.h b/mindnlp/quant/mindbnb/include/Portable.h new file mode 100644 index 000000000..090a25065 --- /dev/null +++ b/mindnlp/quant/mindbnb/include/Portable.h @@ -0,0 +1,182 @@ +#pragma once +#include +#include +#include +#include + +#if defined(__aarch64__) +#ifdef __CUDACC__ +#undef USE_NEON // Doesn't work with nvcc, undefined symbols +#else +#include +#undef USE_NEON // Not yet implemented +#endif +#undef USE_AVX // x86_64 only +#undef USE_AVX2 // x86_64 only +#undef USE_SSE2 // x86_64 only +#undef USE_SSE41 // x86_64 only +#undef USE_SSE42 // x86_64 only +#undef USE_FMA // x86_64 only +#ifdef USE_NEON +typedef float32x4_t __m128; +typedef int32x4_t __m128i; +typedef float64x2_t __m128d; +#else +typedef struct {float a; float b; float c; float d;} __m128; +typedef struct {int a; int b; int c; int d;} __m128i; +typedef struct {double a; double b;} __m128d; +#endif +#else +#undef USE_NEON // ARM64 only +#ifdef __FMA__ +#define USE_FMA +#endif +#if !defined(__SSE2__) && !defined(_MSC_VER) +#error Compiler must support SSE2 +#endif +#define USE_SSE2 + +#if defined(__aarch64__) +#else +#ifdef __AVX2__ +#define USE_AVX2 +#endif + +#ifdef __AVX__ +#define USE_AVX +#endif + + +#ifdef __SSE4_1__ +#define USE_SSE41 +#endif + +#ifdef __SSE4_2__ +#define USE_SSE42 +#endif +#endif +#endif + +#ifndef _MSC_VER +#include +#endif + +namespace BinSearch { + +#ifndef _MSC_VER +typedef int8_t int8; +typedef 
uint8_t uint8; +typedef int32_t int32; +typedef uint32_t uint32; +typedef int64_t int64; +typedef uint64_t uint64; +#else +typedef __int8 int8; +typedef unsigned __int8 uint8; +typedef __int32 int32; +typedef unsigned __int32 uint32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +#endif + +namespace Details { + +#define myassert(cond, msg) if (!cond){ std::ostringstream os; os << "\nassertion failed: " << #cond << ", " << msg << "\n"; throw std::invalid_argument(os.str()); } + +// log2 is not defined in VS2008 +#if defined(_MSC_VER) +inline uint32 log2 (uint32 val) { + if (val == 1) return 0; + uint32 ret = 0; + do { + ret++; + val >>= 1; + } while (val > 1); + return ret; +} +#endif + +#ifdef _DEBUG +#define DEBUG +#endif + +#ifdef _MSC_VER +# define FORCE_INLINE __forceinline +# define NO_INLINE __declspec(noinline) +#else +# define NO_INLINE __attribute__((noinline)) +# ifdef DEBUG +# define FORCE_INLINE NO_INLINE +# else +# define FORCE_INLINE __attribute__((always_inline)) inline +# endif +#endif + +#ifdef USE_AVX +#define COMISS "vcomiss" +#define COMISD "vcomisd" +#else +#define COMISS "comiss" +#define COMISD "comisd" +#endif + +// nextafter is not defined in VS2008 +#if defined(_MSC_VER) && (_MSC_VER <= 1500) +#include +inline float mynext(float x) +{ + return _nextafterf(x, std::numeric_limits::max()); +} + +inline double mynext(double x) +{ + return _nextafter(x, std::numeric_limits::max()); +} +inline float myprev(float x) +{ + return _nextafterf(x, -std::numeric_limits::max()); +} + +inline double myprev(double x) +{ + return _nextafter(x, -std::numeric_limits::max()); +} +#else +inline float mynext(float x) +{ + return std::nextafterf(x, std::numeric_limits::max()); +} + +inline double mynext(double x) +{ + return std::nextafter(x, std::numeric_limits::max()); +} +inline float myprev(float x) +{ + return std::nextafterf(x, -std::numeric_limits::max()); +} + +inline double myprev(double x) +{ + return std::nextafter(x, -std::numeric_limits::max()); +} +#endif + +template +inline T next(T x) +{ + for (int i = 0; i < 4; ++i) + x = mynext(x); + return x; +} + +template +inline T prev(T x) +{ + for (int i = 0; i < 4; ++i) + x = myprev(x); + return x; +} + +} // namespace Details +} // namespace BinSearch diff --git a/mindnlp/quant/mindbnb/include/SIMD.h b/mindnlp/quant/mindbnb/include/SIMD.h new file mode 100644 index 000000000..9d1410c73 --- /dev/null +++ b/mindnlp/quant/mindbnb/include/SIMD.h @@ -0,0 +1,589 @@ +#pragma once + +#include "Portable.h" + +#ifdef USE_SSE2 +#include +#if defined(USE_AVX) || defined(USE_AVX2) +#include +#else +#ifdef USE_SSE41 +#include +#endif +#endif +#endif + +namespace BinSearch { +namespace Details { + +template +struct FTOITraits{}; + +template +struct FVec; + +template +struct IVec; + +template +struct FVec1; + +template <> struct InstrFloatTraits +{ + typedef __m128 vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m128d vec_t; +}; + +} +} + +#if !defined(__aarch64__) +#ifdef USE_SSE42 +#ifndef _MSC_VER +#include +#define popcnt32 _mm_popcnt_u32 +#else +#include +#define popcnt32 __popcnt +#endif +#else // USE_SSE42 +namespace BinSearch { +FORCE_INLINE int popcnt32(int x32) +{ + // strictly speaking this is not correct, as it ignores higher order bits + // however, this is only used on the resuot of movemask on a 128-bit register, which is 8 at most, so it is ok + // with 256-bit registers, SSE42 is defined, and we do not use this function + uint8 x = static_cast(x32); + x = (x & 0x55) + (x >> 1 & 0x55); + x = (x & 
0x33) + (x >> 2 & 0x33); + x = (x & 0x0f) + (x >> 4 & 0x0f); + return x; +} +} // namespace +#endif + +#include "Type.h" + +namespace BinSearch { +namespace Details { + +template <> struct InstrIntTraits +{ + typedef __m128i vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m128 vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m128d vec_t; +}; + +template <> +struct FTOITraits +{ + typedef IVec vec_t; +}; + +#ifdef USE_AVX + +template <> +struct FTOITraits +{ + typedef IVec vec_t; +}; + +template <> struct InstrIntTraits +{ + typedef __m256i vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m256 vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m256d vec_t; +}; + +#endif + + +template +struct VecStorage +{ + typedef typename TR::vec_t vec_t; + + FORCE_INLINE operator vec_t&() { return vec; } + FORCE_INLINE operator const vec_t&() const { return vec; } + +protected: + FORCE_INLINE VecStorage() {} + FORCE_INLINE VecStorage(const vec_t& v) : vec( v ) {} + + vec_t vec; +}; + +template +struct IVecBase; + +template <> +struct IVecBase : VecStorage> +{ +protected: + FORCE_INLINE IVecBase() {} + FORCE_INLINE IVecBase( const vec_t& v) : VecStorage>( v ) {} +public: + FORCE_INLINE static vec_t zero() { return _mm_setzero_si128(); } + + FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32( vec ); } + + FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask ) + { +#ifdef USE_SSE41 + vec = _mm_blendv_epi8(vec, val, mask); +#else + vec = _mm_or_si128(_mm_andnot_si128(mask,vec), _mm_and_si128(mask,val)); +#endif + } + FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask) + { + vec = _mm_or_si128(vec, _mm_and_si128(val,mask)); + } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm_set1_epi32( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec( uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase( _mm_set_epi32( u3, u2, u1, u0 ) ) {} + + void setN( int32 i ) { vec = _mm_set1_epi32( i ); } + +#ifdef USE_SSE41 + FORCE_INLINE int32 get1() const { return _mm_extract_epi32(vec, 1); } + FORCE_INLINE int32 get2() const { return _mm_extract_epi32(vec, 2); } + FORCE_INLINE int32 get3() const { return _mm_extract_epi32(vec, 3); } +#else + FORCE_INLINE int32 get1() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 1 ) ); } + FORCE_INLINE int32 get2() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) ); } + FORCE_INLINE int32 get3() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 3 ) ); } +#endif + + FORCE_INLINE void store( uint32 *pi ) const { _mm_storeu_si128( reinterpret_cast(pi), vec ); } + + FORCE_INLINE int countbit() + { + return popcnt32(_mm_movemask_ps(_mm_castsi128_ps(vec))); + } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm_set1_epi64x( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec( uint64 u1, uint64 u0 ) : IVecBase( _mm_set_epi64x(u1, u0) ) {} + + void setN( int32 i ) { vec = _mm_set1_epi64x( i ); } + + FORCE_INLINE int32 get1() const + { +#ifdef USE_SSE41 + return _mm_extract_epi32(vec, 2); +#else + return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) ); +#endif + } + + // extract the 2 32 bits integers no. 
0, 2 and store them in a __m128i + FORCE_INLINE IVec extractLo32s() const + { + return _mm_shuffle_epi32(vec, ((2 << 2) | 0)); + } + + FORCE_INLINE void store( uint32 *pi ) const + { + pi[0] = get0(); + pi[1] = get1(); + } + + FORCE_INLINE int countbit() + { +#if 1 + // takes 4 cycles + __m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle + __m128i s = _mm_add_epi32(vec, hi); + int32 x = _mm_cvtsi128_si32(s); + return -x; +#else + // takes 6 cycles + return popcnt32(_mm_movemask_pd(_mm_castsi128_pd(vec))); +#endif + } +}; + +template +FORCE_INLINE IVec operator>> (const IVec& a, unsigned n) { return _mm_srli_epi32(a, n); } +template +FORCE_INLINE IVec operator<< (const IVec& a, unsigned n) { return _mm_slli_epi32(a, n); } +template +FORCE_INLINE IVec operator& (const IVec& a, const IVec& b ) { return _mm_and_si128( a, b ); } +template +FORCE_INLINE IVec operator| (const IVec& a, const IVec& b ) { return _mm_or_si128( a, b ); } +template +FORCE_INLINE IVec operator^ (const IVec& a, const IVec& b ) { return _mm_xor_si128( a, b ); } +template +FORCE_INLINE IVec operator+ (const IVec& a, const IVec& b ) { return _mm_add_epi32( a, b ); } +template +FORCE_INLINE IVec operator- (const IVec& a, const IVec& b ) { return _mm_sub_epi32( a, b ); } +#ifdef USE_SSE41 +template +FORCE_INLINE IVec min (const IVec& a, const IVec& b ) { return _mm_min_epi32( a, b ); } +#endif + +typedef VecStorage> FVec128Float; + +template <> +struct FVec1 : FVec128Float +{ + FORCE_INLINE FVec1() {} + FORCE_INLINE FVec1( float f ) : FVec128Float( _mm_load_ss( &f ) ) {} + FORCE_INLINE FVec1( const vec_t& v ): FVec128Float( v ) {} + + FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); } +}; + +template <> +struct FVec : FVec128Float +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( float f ) : FVec128Float( _mm_set1_ps( f ) ) {} + FORCE_INLINE FVec( const float *v ) : FVec128Float( _mm_loadu_ps( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec128Float(v) {} + FORCE_INLINE FVec( float f3, float f2, float f1, float f0 ) : FVec128Float( _mm_set_ps(f3, f2, f1, f0) ) {} + + void set0( float f ) { vec = _mm_load_ss( &f ); } + void setN( float f ) { vec = _mm_set1_ps( f ); } + + FORCE_INLINE void setidx( const float *xi, const IVec& idx ) + { + uint32 i0 = idx.get0(); + uint32 i1 = idx.get1(); + uint32 i2 = idx.get2(); + uint32 i3 = idx.get3(); + vec = _mm_set_ps( xi[i3], xi[i2], xi[i1], xi[i0] ); + } + + FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); } + FORCE_INLINE float get1() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 1 ) ); } + FORCE_INLINE float get2() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 2 ) ); } + FORCE_INLINE float get3() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 3 ) ); } +}; + +FORCE_INLINE FVec1 operator+ (const FVec1& a, const FVec1& b) { return _mm_add_ss( a, b ); } +FORCE_INLINE FVec1 operator- (const FVec1& a, const FVec1& b) { return _mm_sub_ss( a, b ); } +FORCE_INLINE FVec1 operator* (const FVec1& a, const FVec1& b) { return _mm_mul_ss( a, b ); } +FORCE_INLINE FVec1 operator/ (const FVec1& a, const FVec1& b) { return _mm_div_ss( a, b ); } +FORCE_INLINE int ftoi (const FVec1& a) { return _mm_cvttss_si32(a); } +FORCE_INLINE IVec operator> (const FVec1& a, const FVec1& b) { return _mm_castps_si128( _mm_cmpgt_ss( a, b ) ); } +#ifdef USE_FMA +FORCE_INLINE FVec1 mulSub(const FVec1& a, const FVec1& b, const FVec1& c) { return _mm_fmsub_ss(a, b, c); } +#endif + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm_sub_ps( a, b ); 
} +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_ps( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_ps( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttps_epi32(a); } +#ifndef __clang__ // Conflicts with builtin operator +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); } +#endif +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm_fmsub_ps(a, b, c); } +#endif + +typedef VecStorage> FVec128Double; + +template <> +struct FVec1 : FVec128Double +{ + FORCE_INLINE FVec1() {} + FORCE_INLINE FVec1( double f ) : FVec128Double( _mm_load_sd( &f ) ) {} + FORCE_INLINE FVec1( const vec_t& v ) : FVec128Double( v ) {} + + FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); } +}; + +template <> +struct FVec : FVec128Double +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( double d ) : FVec128Double( _mm_set1_pd( d ) ) {} + FORCE_INLINE FVec( const double *v ) : FVec128Double( _mm_loadu_pd( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec128Double( v ) {} + FORCE_INLINE FVec( double f1, double f0 ) : FVec128Double( _mm_set_pd(f1, f0) ) {} + + void set0( double f ) { vec = _mm_load_sd( &f ); } + void setN( double f ) { vec = _mm_set1_pd( f ); } + + FORCE_INLINE void setidx( const double *xi, const IVec& idx ) + { + vec = _mm_set_pd( xi[idx.get1()], xi[idx.get0()] ); + } + + FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); } + FORCE_INLINE double get1() const { return _mm_cvtsd_f64( _mm_shuffle_pd( vec, vec, 1 ) ); }; +}; + +FORCE_INLINE FVec1 operator+ (const FVec1& a, const FVec1& b) { return _mm_add_sd( a, b ); } +FORCE_INLINE FVec1 operator- (const FVec1& a, const FVec1& b) { return _mm_sub_sd( a, b ); } +FORCE_INLINE FVec1 operator* (const FVec1& a, const FVec1& b) { return _mm_mul_sd( a, b ); } +FORCE_INLINE FVec1 operator/ (const FVec1& a, const FVec1& b) { return _mm_div_sd( a, b ); } +FORCE_INLINE int ftoi (const FVec1& a) { return _mm_cvttsd_si32(a); } +FORCE_INLINE IVec operator> (const FVec1& a, const FVec1& b) { return _mm_castpd_si128( _mm_cmpgt_sd( a, b ) ); } +#ifdef USE_FMA +FORCE_INLINE FVec1 mulSub(const FVec1& a, const FVec1& b, const FVec1& c) { return _mm_fmsub_sd(a, b, c); } +#endif + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm_sub_pd( a, b ); } +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_pd( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_pd( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttpd_epi32(a); } +#ifndef __clang__ // Conflicts with builtin operator +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); } +#endif +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c ) { return _mm_fmsub_pd(a, b, c); } +#endif + +#ifdef USE_AVX + +template <> +struct IVecBase : VecStorage> +{ +protected: + FORCE_INLINE 
IVecBase() {} + FORCE_INLINE IVecBase( const vec_t& v) : VecStorage>( v ) {} +public: + FORCE_INLINE static vec_t zero() { return _mm256_setzero_si256(); } + + FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32(_mm256_castsi256_si128(vec)); } + + FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask ) { vec = _mm256_blendv_epi8(vec, val, mask); } + FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask) + { + vec = _mm256_blendv_epi8(vec, val, mask); + //vec = _mm256_or_si256(vec, _mm256_and_si256(val,mask)); + } + + FORCE_INLINE __m128i lo128() const { return _mm256_castsi256_si128(vec); } + FORCE_INLINE __m128i hi128() const { return _mm256_extractf128_si256(vec, 1); } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm256_set1_epi32( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec(uint32 u7, uint32 u6, uint32 u5, uint32 u4, uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase(_mm256_set_epi32(u7, u6, u5, u4, u3, u2, u1, u0)) {} + + void setN( int32 i ) { vec = _mm256_set1_epi32( i ); } + + FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 1); } + FORCE_INLINE int32 get2() const { return _mm256_extract_epi32(vec, 2); } + FORCE_INLINE int32 get3() const { return _mm256_extract_epi32(vec, 3); } + FORCE_INLINE int32 get4() const { return _mm256_extract_epi32(vec, 4); } + FORCE_INLINE int32 get5() const { return _mm256_extract_epi32(vec, 5); } + FORCE_INLINE int32 get6() const { return _mm256_extract_epi32(vec, 6); } + FORCE_INLINE int32 get7() const { return _mm256_extract_epi32(vec, 7); } + + FORCE_INLINE void setidx( const uint32 *bi, const IVec& idx ) + { + vec = _mm256_i32gather_epi32(reinterpret_cast(bi), idx, sizeof(uint32)); + } + + FORCE_INLINE void store( uint32 *pi ) const { _mm256_storeu_si256( reinterpret_cast(pi), vec ); } + + FORCE_INLINE int countbit() + { + return popcnt32(_mm256_movemask_ps(_mm256_castsi256_ps(vec))); + } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm256_set1_epi64x( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec(uint64 u3, uint64 u2, uint64 u1, uint64 u0) : IVecBase(_mm256_set_epi64x(u3, u2, u1, u0)) {} + + void setN( int32 i ) { vec = _mm256_set1_epi64x( i ); } + + // extract the 4 32 bits integers no. 
0, 2, 4, 6 and store them in a __m128i + FORCE_INLINE IVec extractLo32s() const + { + union { + uint32 u32[4]; + __m128i u; + } mask = {0,2,4,6}; + //__m256 ps256 = _mm256_castsi256_ps(vec); + //__m128 lo128 = _mm256_castps256_ps128(ps256); + //__m128 hi128 = _mm256_extractf128_ps(ps256, 1); + //__m128 blend = _mm_shuffle_ps(lo128, hi128, 0 + (2<<2) + (0<<4) + (2<<6)); + __m256i blend = _mm256_permutevar8x32_epi32(vec, _mm256_castsi128_si256(mask.u)); + return _mm256_castsi256_si128(blend); + } + + //int32 get1() const { return _mm256_cvtsi256_si32( _mm256_shuffle_epi32( vec, 2 ) ); }; + FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 2); } + + FORCE_INLINE void store( uint32 *pi ) const + { + extractLo32s().store(pi); + } + + FORCE_INLINE int countbit() + { + return popcnt32(_mm256_movemask_pd(_mm256_castsi256_pd(vec))); + } +}; + +template +FORCE_INLINE IVec operator>> (const IVec& a, unsigned n) { return _mm256_srli_epi32(a, n); } +template +FORCE_INLINE IVec operator<< (const IVec& a, unsigned n) { return _mm256_slli_epi32(a, n); } +template +FORCE_INLINE IVec operator& (const IVec& a, const IVec& b ) { return _mm256_and_si256( a, b ); } +template +FORCE_INLINE IVec operator| (const IVec& a, const IVec& b ) { return _mm256_or_si256( a, b ); } +template +FORCE_INLINE IVec operator^ (const IVec& a, const IVec& b ) { return _mm256_xor_si256( a, b ); } +template +FORCE_INLINE IVec min (const IVec& a, const IVec& b ) { return _mm256_min_epi32( a, b ); } + +FORCE_INLINE IVec operator+ (const IVec& a, const IVec& b ) { return _mm256_add_epi32( a, b ); } +FORCE_INLINE IVec operator- (const IVec& a, const IVec& b ) { return _mm256_sub_epi32( a, b ); } +FORCE_INLINE IVec operator+ (const IVec& a, const IVec& b ) { return _mm256_add_epi64( a, b ); } +FORCE_INLINE IVec operator- (const IVec& a, const IVec& b ) { return _mm256_sub_epi64( a, b ); } + + +typedef VecStorage> FVec256Float; + +template <> +struct FVec : FVec256Float +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( float f ) : FVec256Float( _mm256_set1_ps( f ) ) {} + FORCE_INLINE FVec( const float *v ) : FVec256Float( _mm256_loadu_ps( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec256Float(v) {} + FORCE_INLINE FVec(float f7, float f6, float f5, float f4, float f3, float f2, float f1, float f0) : FVec256Float(_mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0)) {} + + //void set0( float f ) { vec = _mm256_load_ss( &f ); } + void setN( float f ) { vec = _mm256_set1_ps( f ); } + + FORCE_INLINE void setidx( const float *xi, const IVec& idx ) + { +#if 1 // use gather primitives + vec = _mm256_i32gather_ps (xi, idx, 4); +#elif 0 + uint32 i0 = idx.get0(); + uint32 i1 = idx.get1(); + uint32 i2 = idx.get2(); + uint32 i3 = idx.get3(); + uint32 i4 = idx.get4(); + uint32 i5 = idx.get5(); + uint32 i6 = idx.get6(); + uint32 i7 = idx.get7(); + vec = _mm256_set_ps( xi[i7], xi[i6], xi[i5], xi[i4], xi[i3], xi[i2], xi[i1], xi[i0] ); +#else + union { + __m256i vec; + uint32 ui32[8]; + } i; + i.vec = static_cast(idx); + vec = _mm256_set_ps(xi[i.ui32[7]], xi[i.ui32[6]], xi[i.ui32[5]], xi[i.ui32[4]], xi[i.ui32[3]], xi[i.ui32[2]], xi[i.ui32[1]], xi[i.ui32[0]]); +#endif + } + + FORCE_INLINE FVec lo128() const { return _mm256_castps256_ps128(vec); } + FORCE_INLINE FVec hi128() const { return _mm256_extractf128_ps(vec, 1); } + + //FORCE_INLINE float get0() const { return _mm256_cvtss_f32( vec ); } + //FORCE_INLINE float get1() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 1 ) ); } + //FORCE_INLINE float get2() const { return 
_mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 2 ) ); } + //FORCE_INLINE float get3() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 3 ) ); } +}; + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm256_sub_ps( a, b ); } +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm256_mul_ps( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm256_div_ps( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm256_cvttps_epi32(a); } +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_LE_OS) ); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_GE_OS ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm256_castps_si256(_mm256_cmp_ps(a, b, _CMP_LT_OS )); } +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm256_fmsub_ps(a, b, c); } +#endif + +typedef VecStorage> FVec256Double; + +template <> +struct FVec : FVec256Double +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( double d ) : FVec256Double( _mm256_set1_pd( d ) ) {} + FORCE_INLINE FVec( const double *v ) : FVec256Double( _mm256_loadu_pd( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec256Double( v ) {} + FORCE_INLINE FVec(double d3, double d2, double d1, double d0) : FVec256Double(_mm256_set_pd(d3, d2, d1, d0)) {} + + //void set0( double f ) { vec = _mm256_load_sd( &f ); } + void setN( double f ) { vec = _mm256_set1_pd( f ); } + + FORCE_INLINE void setidx( const double *xi, const IVec& idx ) + { + vec = _mm256_i32gather_pd(xi, idx, 8); + } + + FORCE_INLINE void setidx( const double *xi, const IVec& idx ) + { + vec = _mm256_i64gather_pd(xi, idx, 8); + } + +// FORCE_INLINE double get0() const { return _mm256_cvtsd_f64( vec ); } +// FORCE_INLINE double get1() const { return _mm256_cvtsd_f64( _mm256_shuffle_pd( vec, vec, 1 ) ); }; +}; + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm256_sub_pd( a, b ); } +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm256_mul_pd( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm256_div_pd( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm256_cvttpd_epi32(a); } +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_LE_OS ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm256_castpd_si256(_mm256_cmp_pd(a, b, _CMP_LT_OS)); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_GE_OS ) ); } +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm256_fmsub_pd(a, b, c); } +#endif + +#endif + +} // namespace Details +} // namespace BinSearch +#endif // !defined(__aarch64__) diff --git a/mindnlp/quant/mindbnb/include/Type.h b/mindnlp/quant/mindbnb/include/Type.h new file mode 100644 index 000000000..16bf3e3ae --- /dev/null +++ b/mindnlp/quant/mindbnb/include/Type.h @@ -0,0 +1,221 @@ + #pragma once + +#include +#include +#include + +#include "Portable.h" + +using std::size_t; + +namespace BinSearch { + +enum InstrSet { Scalar, SSE, AVX, Neon }; + +#define ALGOENUM(x, b) x, +enum Algos + { +#include "AlgoXCodes.h" + }; +#undef ALGOENUM + +namespace Details { + + template + struct InstrIntTraits; + + template + struct InstrFloatTraits; + + // base 
class for algorithm supporting the method: + // uint32 scalar(T z) const + template + struct AlgoScalarBase; + + // base class for algorithm supporting the following methods, constants and definitions: + // static const uint32 nElem + // struct Constants; + // void initConstants(Constants& cst) const + // void vectorial(uint32 *pr, const T *pz, const Constants& cst) const + // The function vectorial processes nElem items + template + struct AlgoVecBase; + + template struct IntTraits; + + template <> struct IntTraits + { + typedef uint32 itype; + }; + template <> struct IntTraits + { + typedef uint64 itype; + }; + + template + struct Body + { + template + FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const typename Expr::Constants& cst) + { + e.vectorial(ri, zi, cst); + Body::template iteration(e, ri + D, zi + D, cst); + } + + }; + + template <> + struct Body<0> + { + template + FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const H&) + { + } + }; + + template + struct Loop + { + typedef Algo algo_type; + static const uint32 M = 4; + static const uint32 D = algo_type::nElem; + + FORCE_INLINE static void loop(const algo_type& e, uint32 *ri, const T* zi, uint32 n) + { + typename algo_type::Constants cst; + e.initConstants(cst); + + uint32 j = 0; + while (j + (D*M) <= n) { + Details::Body::template iteration(e, ri + j, zi + j, cst); + j += (D*M); + } + while (j + D <= n) { + e.vectorial(ri + j, zi + j, cst); + j += D; + } + while (D > 1 && j < n) { + ri[j] = e.scalar(zi[j]); + j += 1; + } + } + }; + + template + struct _Pipeliner + { + template + FORCE_INLINE static void go(const Expr& e, Data* d) + { + e.template run(d); + _Pipeliner::go(e, d); + } + }; + + template + struct _Pipeliner + { + template + FORCE_INLINE static void go(const Expr& e, Data* d) + { + } + }; + + template + struct Pipeliner + { + template + FORCE_INLINE static void go(const Expr& e, Data* d) + { + _Pipeliner::go(e, d); + } + }; + + +#if 1 + template + char is_complete_impl(char (*)[sizeof(T)]); + + template + long is_complete_impl(...); + + template + struct IsComplete + { + static const bool value = sizeof(is_complete_impl(0)) == sizeof(char); + }; +#else + template + std::true_type is_complete_impl(T *); + + std::false_type is_complete_impl(...); + + template + struct IsComplete : decltype(is_complete_impl(std::declval())) {}; +#endif + +template +struct AlgoScalarToVec : AlgoScalarBase +{ + typedef AlgoScalarBase base_t; + + AlgoScalarToVec(const typename base_t::Data& d) : base_t(d) {} + AlgoScalarToVec(const T* px, const uint32 n) : base_t(px, n) {} + + static const uint32 nElem = 1; + + struct Constants + { + }; + + void initConstants(Constants& cst) const + { + } + + FORCE_INLINE + void vectorial(uint32 *pr, const T *pz, const Constants& cst) const + { + *pr = base_t::scalar(*pz); + } +}; + +template +struct conditional { typedef T type; }; + +template +struct conditional { typedef F type; }; + +template +struct CondData +{ + FORCE_INLINE CondData(T x) : v(x) {} + FORCE_INLINE operator const T&() const { return v;} +private: + T v; +}; + +template +struct CondData +{ + FORCE_INLINE CondData(T) {} + FORCE_INLINE operator const T() const { return 0;} +}; + +template +struct BinAlgoBase : Details::conditional< Details::IsComplete>::value + , Details::AlgoVecBase + , Details::AlgoScalarToVec + >::type +{ + typedef typename Details::conditional< Details::IsComplete>::value + , Details::AlgoVecBase + , Details::AlgoScalarToVec + >::type base_t; + + 
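The Loop/Body templates above tile the input into unrolled groups of M = 4 vector iterations, then single vector iterations of width D = nElem, then a scalar tail. A rough Python sketch of that control flow follows; `algo.vectorial(out, z, j)` is a stand-in for the algorithm object's vectorial() writing D results starting at offset j, an assumption of this sketch rather than the real signature.

def loop(algo, out, z, D, M=4):
    # Sketch of Details::Loop<T, Algo>::loop; D plays the role of algo_type::nElem.
    n = len(z)
    j = 0
    while j + D * M <= n:          # unrolled body: M vector iterations per pass
        for k in range(M):
            algo.vectorial(out, z, j + k * D)
        j += D * M
    while j + D <= n:              # remaining full-width vector iterations
        algo.vectorial(out, z, j)
        j += D
    while j < n:                   # scalar tail (never entered when D == 1)
        out[j] = algo.scalar(z[j])
        j += 1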
BinAlgoBase(const T* px, const uint32 n) : base_t(px, n) {} + BinAlgoBase(const typename base_t::Data& d) : base_t(d) {} +}; + +} // namespace Details + +} // namespace BinSearch diff --git a/mindnlp/quant/mindbnb/integrations/__init__.py b/mindnlp/quant/mindbnb/integrations/__init__.py new file mode 100644 index 000000000..17824e3e4 --- /dev/null +++ b/mindnlp/quant/mindbnb/integrations/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' + integrations +''' +from .quantization_bnb_8bit import quant_8bit diff --git a/mindnlp/quant/mindbnb/integrations/quantization_bnb_8bit.py b/mindnlp/quant/mindbnb/integrations/quantization_bnb_8bit.py new file mode 100644 index 000000000..a30d5ad8e --- /dev/null +++ b/mindnlp/quant/mindbnb/integrations/quantization_bnb_8bit.py @@ -0,0 +1,29 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' + quantization bnb 8bit +''' +from .replace_modules import replace_with_bnb_linear + + +def quant_8bit(model, modules_to_not_convert=None, quantization_config=None): + + model = replace_with_bnb_linear( + model, + modules_to_not_convert=modules_to_not_convert, + quantization_config=quantization_config, + ) + + return model diff --git a/mindnlp/quant/mindbnb/integrations/replace_modules.py b/mindnlp/quant/mindbnb/integrations/replace_modules.py new file mode 100644 index 000000000..29864d4d9 --- /dev/null +++ b/mindnlp/quant/mindbnb/integrations/replace_modules.py @@ -0,0 +1,128 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
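quant_8bit() above is the single public entry point of the integration: it walks the model and swaps eligible linear layers for bitsandbytes Linear8bitLt modules. A minimal usage sketch, mirroring tests/test_model.py later in this patch (the checkpoint name is the one used there; any causal LM is handled the same way):

import mindspore
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
from integrations.quantization_bnb_8bit import quant_8bit

mindspore.set_context(device_target="GPU")

tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/falcon-rw-1b")
model = AutoModelForCausalLM.from_pretrained("Rocketknight1/falcon-rw-1b")
model.set_train(False)

# Eligible nn.Linear / nn.Conv1d modules are replaced by bnb.nn.Linear8bitLt;
# "lm_head" is kept in full precision by default (see replace_with_bnb_linear below).
quant_8bit(model)

inputs = tokenizer("My favorite food is", return_tensors="ms")
output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(output_ids)[0])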
+# ============================================================================ +""" + replace modules +""" +# pylint: disable=E0611, E0401 +import logging +import mindspore + +from bitsandbytes.nn.modules import Int8Params +import bitsandbytes as bnb + +from mindnlp.core import nn + +logger = logging.getLogger(__name__) + + +def _replace_with_bnb_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, + llm_int8_has_fp16_weight=False, + llm_int8_threshold=6.0, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if ( + isinstance(module, (nn.Conv1d, nn.Linear)) + and name not in modules_to_not_convert + ): + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) + for key in modules_to_not_convert + ): + # Initialize empty weights if necessary (replace with MindSpore equivalent if required) + if isinstance(module, nn.Conv1d): + in_features, out_features = module.weight.shape + else: + in_features = module.in_features + out_features = module.out_features + + weight = model._modules[name].weight.clone() + bias = ( + model._modules[name].bias.clone() + if model._modules[name].bias is not None + else None + ) + + # Replace with MindSpore equivalent or custom module + model._modules[name] = bnb.nn.Linear8bitLt( + in_features, + out_features, + has_fp16_weights=llm_int8_has_fp16_weight, + threshold=llm_int8_threshold, + ) + + model._modules[name].weight = Int8Params( + weight.data, + requires_grad=llm_int8_has_fp16_weight, + has_fp16_weights=llm_int8_has_fp16_weight, + ) + if bias is not None: + model._modules[name].bias = mindspore.Parameter(bias) + + model._modules[name].quant() + has_been_replaced = True + + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad = False + + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + + return model, has_been_replaced + + +def replace_with_bnb_linear( + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None +): + modules_to_not_convert = ( + ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + ) + model, has_been_replaced = _replace_with_bnb_linear( + model, modules_to_not_convert, current_key_name, quantization_config + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." 
+ ) + + return model diff --git a/mindnlp/quant/mindbnb/requirements-dev.txt b/mindnlp/quant/mindbnb/requirements-dev.txt new file mode 100644 index 000000000..2536d5979 --- /dev/null +++ b/mindnlp/quant/mindbnb/requirements-dev.txt @@ -0,0 +1,10 @@ +# Requirements used for local development +setuptools>=63 +pytest~=8.2.2 +einops~=0.8.0 +wheel~=0.43.0 +scipy~=1.13.0 +pandas~=2.2.2 +matplotlib~=3.9.1 +mindspore>=2.3 +mindnlp>=0.4 \ No newline at end of file diff --git a/mindnlp/quant/mindbnb/scripts/build.sh b/mindnlp/quant/mindbnb/scripts/build.sh new file mode 100644 index 000000000..7063ff7c5 --- /dev/null +++ b/mindnlp/quant/mindbnb/scripts/build.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# 安装构建工具和 cmake +# sudo apt-get update +# sudo apt-get install -y build-essential cmake + +# 安装 Python 依赖 +pip install -r requirements-dev.txt + +# 使用 cmake 配置并构建项目 +cmake -DCOMPUTE_BACKEND=cuda -S . + +# 构建项目 +make -j4 \ No newline at end of file diff --git a/mindnlp/quant/mindbnb/tests/test_bitsandbytes_linear.py b/mindnlp/quant/mindbnb/tests/test_bitsandbytes_linear.py new file mode 100644 index 000000000..6900229d7 --- /dev/null +++ b/mindnlp/quant/mindbnb/tests/test_bitsandbytes_linear.py @@ -0,0 +1,46 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import torch +import torch.nn as nn +import numpy as np + +from bitsandbytes.nn import Linear8bitLt + +# 设置设备 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(42) +np.random.seed(42) + +# 创建模型 +int8_model = Linear8bitLt(2, 4, has_fp16_weights=False) + +# 初始化权重和偏置 +weight = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float16) +bias = torch.tensor([1, 2], dtype=torch.float16) + +# 设置模型的权重和偏置 +int8_model.weight.data = weight +int8_model.bias.data = bias + +int8_model.to(device) + +# 输入数据 +input_data = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float16).to(device) + +# 模型输出 +int8_output = int8_model(input_data) + +print(int8_output) diff --git a/mindnlp/quant/mindbnb/tests/test_matmul.py b/mindnlp/quant/mindbnb/tests/test_matmul.py new file mode 100644 index 000000000..32b73d0c8 --- /dev/null +++ b/mindnlp/quant/mindbnb/tests/test_matmul.py @@ -0,0 +1,66 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
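The torch-based test above exercises bitsandbytes' Linear8bitLt directly as a reference. A sketch of one way to extend it with an fp16 baseline comparison (layer sizes, seed, and tolerance are illustrative; quantization of the int8 layer happens when it is moved to the GPU because has_fp16_weights=False):

import torch
from bitsandbytes.nn import Linear8bitLt

torch.manual_seed(0)
fp16_linear = torch.nn.Linear(64, 32).half().cuda()

int8_linear = Linear8bitLt(64, 32, has_fp16_weights=False)
int8_linear.load_state_dict(fp16_linear.state_dict())
int8_linear = int8_linear.cuda()   # moving to the device triggers int8 quantization

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
max_err = (fp16_linear(x) - int8_linear(x)).abs().max().item()
print(f"max abs error vs fp16 reference: {max_err:.4f}")   # small, but not zero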
+# ============================================================================ +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import mindspore +import numpy as np +import time + +from bitsandbytes import matmul +from mindspore import Tensor, ops +from mindspore._c_expression import _framework_profiler_step_start +from mindspore._c_expression import _framework_profiler_step_end + +mindspore.context.set_context(device_target="GPU") + +a = Tensor(np.random.randn(8192, 8192).astype(np.float16)) +b = Tensor(np.random.randn(8192, 8192).astype(np.float16)) +b_ops = b.t() +# for i in range(5): +# c_old = ops.matmul(a, b_ops) +c_old = ops.matmul(a, b_ops) + +start = time.time() +# # _framework_profiler_step_start() +# # profiler = mindspore.Profiler() +for i in range(10): + c_old = ops.matmul(a, b_ops) +# # _framework_profiler_step_end() +# c_old = ops.matmul(a, b.t()) +tick = time.time() +time_ops = tick - start +# for i in range(5): +# c_new = matmul(a, b) +c_new = matmul(a, b) +start = time.time() +# _framework_profiler_step_start() +for i in range(10): + c_new = matmul(a, b) +# c_new = matmul(a, b) +# _framework_profiler_step_end() +tick = time.time() +time_bnb = tick - start +# profiler.analyse() +# print(c_new) +# print(c_old) +print("ops.matmul time: ", time_ops) +print("bnb.matmul time: ", time_bnb) + +# while True: +# # c_old = matmul(a, b) +# c_old = ops.matmul(a, b_ops) diff --git a/mindnlp/quant/mindbnb/tests/test_mindbnb_linear.py b/mindnlp/quant/mindbnb/tests/test_mindbnb_linear.py new file mode 100644 index 000000000..9da729cb0 --- /dev/null +++ b/mindnlp/quant/mindbnb/tests/test_mindbnb_linear.py @@ -0,0 +1,51 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
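The benchmark above only times ops.matmul against the bitsandbytes matmul; a sketch of an accompanying correctness check follows. Sizes and the error metric are illustrative; as in the benchmark, the bitsandbytes call takes the un-transposed second operand, and the int8 path approximates the fp16 product, so a loose relative-error check is used instead of exact equality.

import numpy as np
import mindspore
from mindspore import Tensor, ops
from bitsandbytes import matmul

mindspore.context.set_context(device_target="GPU")
np.random.seed(0)

a = Tensor(np.random.randn(256, 512).astype(np.float16))
b = Tensor(np.random.randn(128, 512).astype(np.float16))

c_ref = ops.matmul(a, b.t()).asnumpy().astype(np.float32)
c_int8 = matmul(a, b).asnumpy().astype(np.float32)

rel_err = np.abs(c_ref - c_int8).mean() / (np.abs(c_ref).mean() + 1e-6)
print("mean relative error:", rel_err)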
+# ============================================================================ +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import mindspore.context +import numpy as np +import mindspore +from mindspore import Tensor +from mindnlp.core import nn +from bitsandbytes.nn import Linear8bitLt + + +np.random.seed(42) +mindspore.set_seed(42) +mindspore.context.set_context(device_target="GPU") + +int8_model = Linear8bitLt(4, 2, has_fp16_weights=False) + +weight = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=mindspore.float16) +bias = Tensor([1, 2], dtype=mindspore.float16) + +int8_model.weight.assign_value(weight) +int8_model.bias.assign_value(bias) + +int8_model.quant() +for name, param in int8_model.parameters_and_names(): + print(name) + print(param) + print(param.data.asnumpy()) + + +input_data = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=mindspore.float16) + +int8_output = int8_model(input_data) + +print(int8_output) diff --git a/mindnlp/quant/mindbnb/tests/test_model.py b/mindnlp/quant/mindbnb/tests/test_model.py new file mode 100644 index 000000000..811d18b20 --- /dev/null +++ b/mindnlp/quant/mindbnb/tests/test_model.py @@ -0,0 +1,47 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import sys +import os +import mindspore + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from integrations.quantization_bnb_8bit import quant_8bit +from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer + +mindspore.set_context(device_target="GPU") + +tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/falcon-rw-1b") +model = AutoModelForCausalLM.from_pretrained("Rocketknight1/falcon-rw-1b") +model.set_train(False) +# pdb.set_trace() +# for name, param in model.parameters_and_names(): +# print(name) +# print(param) +# print(param.data.asnumpy()) +# quantization +quant_8bit(model) +# for name, param in model.parameters_and_names(): +# print(name) +# print(param) +# print(param.data.asnumpy()) + +# pdb.set_trace() +inputs = tokenizer("My favorite food is", return_tensors="ms") +# pdb.set_trace() +output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=10) +output_str = tokenizer.batch_decode(output_ids)[0] +print(output_ids) +print(output_str)
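tests/test_model.py above quantizes a full causal LM and generates from it, but it does not check what was actually converted. A short verification sketch follows; it walks the model with named_children(), the same API the replacement code uses, and the helper name is chosen here for illustration.

import mindspore
from bitsandbytes.nn import Linear8bitLt
from mindnlp.transformers import AutoModelForCausalLM
from integrations.quantization_bnb_8bit import quant_8bit

mindspore.set_context(device_target="GPU")

model = AutoModelForCausalLM.from_pretrained("Rocketknight1/falcon-rw-1b")
model.set_train(False)
quant_8bit(model)

def iter_modules(module, prefix=""):
    # Recursive walk built on named_children(), mirroring _replace_with_bnb_linear.
    for name, child in module.named_children():
        path = f"{prefix}.{name}" if prefix else name
        yield path, child
        yield from iter_modules(child, path)

converted = [path for path, m in iter_modules(model) if isinstance(m, Linear8bitLt)]
print(f"converted {len(converted)} linear layers to Linear8bitLt")
print("lm_head converted?", any(p.endswith("lm_head") and isinstance(m, Linear8bitLt)
                                for p, m in iter_modules(model)))   # expected: False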