diff --git a/setup.py b/setup.py index 88669e7b3b..8fc6a2771f 100644 --- a/setup.py +++ b/setup.py @@ -429,6 +429,7 @@ def get_extensions(): # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT found_col16 = False found_vec_ext = False + found_outer_vec = False print("ROCM_HOME", ROCM_HOME) hipblaslt_headers = list( glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h")) @@ -441,12 +442,17 @@ def get_extensions(): found_col16 = True if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text: found_vec_ext = True + if "HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F" in text: + found_outer_vec = True if found_col16: extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16") print("hipblaslt found extended col order enums") else: print("hipblaslt does not have extended col order enums") - if found_vec_ext: + if found_outer_vec: + extra_compile_args["cxx"].append("-DHIPBLASLT_OUTER_VEC") + print("hipblaslt found outer vec") + elif found_vec_ext: extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT") print("hipblaslt found vec ext") else: diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp index bfaf6bf466..feff97f56a 100644 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -362,7 +362,7 @@ ScalingType get_scaling_type( // Check for RowWise scaling if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { -#if defined(HIPBLASLT_VEC_EXT) +#if defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC) TORCH_CHECK( scale_a.is_contiguous() && scale_b.is_contiguous(), "Both scale_a and scale_b must be contiguous for RowWise scaling."); @@ -619,17 +619,25 @@ void _scaled_gemm( computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); hipblasLtMatmulDescAttributes_t matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; hipblasLtMatmulDescAttributes_t matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; -#if defined(HIPBLASLT_VEC_EXT) +#if defined(HIPBLASLT_OUTER_VEC) + // this case is handled later with HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F +#elif defined(HIPBLASLT_VEC_EXT) if (use_rowwise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } #else - // rowwise isn't supported using cublaslt or older hipblaslt + // rowwise isn't supported using older hipblaslt TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt"); #endif computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); +#if defined(HIPBLASLT_OUTER_VEC) + if (use_rowwise) { + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + } +#endif if (result_scale_ptr != nullptr) { computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); }