# Recovered from the patch "[not for land] towards QAT with exact forward pass"
# (commit ff2153c4aa267441854cb3db74565d044d2ef3d1), which its author describes
# as an exploration that is WIP and not ready for review.
#
# File: torchao/prototype/qat_exact/main.py
"""
Prototype of QAT with exact (instead of emulated) forward pass using
integer matrix multiply.

Quant spec:
* int4 symmetric weights w/ group size 32 or 256,
* int8 asymmetric per-token dynamic activations

"""

import copy

import fire
import torch
import torch.nn as nn

from torchao.float8.float8_utils import compute_error
from torchao.prototype.qat_exact.reference_gemm import (
    cpu_x_token_assym_fp8_w_group_sym_int4_gemm,
    naive_x_token_assym_fp8_w_group_sym_int4_gemm,
)
from torchao.prototype.qat_exact.triton_gemm import int8_matmul_triton
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    MappingType,
)
from torchao.quantization.utils import (
    _get_per_token_block_size,
)

torch.manual_seed(0)


def quantize_x(x_fp32):
    """Dynamically quantize activations to int8 with asymmetric per-token qparams.

    Args:
        x_fp32: activation tensor to quantize.

    Returns:
        A tuple ``(x_i8, x_scale, x_zero_point)``: the quantized values
        (requested as ``torch.int8``), the scales (requested as float32) and
        the zero points (requested as int32).
    """
    # Dynamic quantization of activation
    x_mapping_type = MappingType.ASYMMETRIC
    per_token_block_size = _get_per_token_block_size(x_fp32)
    x_quant_min, x_quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
    x_eps = torch.finfo(torch.float32).eps
    x_scales_type = torch.float32
    x_zero_points_type = torch.int32
    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
        x_fp32,
        x_mapping_type.name,
        per_token_block_size,
        torch.int8,
        x_quant_min,
        x_quant_max,
        x_eps,
        x_scales_type,
        x_zero_points_type,
    )
    x_i8 = torch.ops.torchao.quantize_affine(
        x_fp32,
        per_token_block_size,
        x_scale,
        x_zero_point,
        torch.int8,
        x_quant_min,
        x_quant_max,
    )
    return x_i8, x_scale, x_zero_point


class Int8PerTokenActivationInt4PerGroupWeightLinear(torch.nn.Linear):
    """Linear layer whose forward pass runs a real integer gemm.

    Activations are quantized with `quantize_x`; weights are quantized to the
    int4 value range (stored in an int8 tensor) using the scale / zero point
    held by a `FakeQuantizer`. The quantized operands are then handed to one
    of three gemm implementations selected by ``gemm_mode``.

    Keyword Args:
        gemm_mode: required; one of ``"int8_naive_reference"``,
            ``"int8_cpu_reference"`` or ``"int8_triton"``.
        group_size: optional weight quantization group size, default 32.

    All other positional / keyword arguments are forwarded to
    ``torch.nn.Linear``.
    """

    _GEMM_MODES = (
        "int8_naive_reference",
        "int8_cpu_reference",
        "int8_triton",
    )

    def __init__(self, *args, **kwargs):
        gemm_mode = kwargs.pop("gemm_mode")
        # previously hard-coded to 32; kept as the default for compatibility
        group_size = kwargs.pop("group_size", 32)
        assert gemm_mode in self._GEMM_MODES, (
            f"unsupported gemm_mode {gemm_mode!r}, expected one of {self._GEMM_MODES}"
        )
        super().__init__(*args, **kwargs)
        # manually create fake quantizer configs
        activation_config = FakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False
        )
        weight_config = FakeQuantizeConfig(torch.int4, group_size=group_size)

        # manually create fake quantizers
        # reference: `FakeQuantizedLinear` (https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/qat/linear.py)
        # NOTE: `activation_fq` is constructed but `forward` does not call it;
        # activations go through `quantize_x` instead.
        self.activation_fq = FakeQuantizer(activation_config)
        self.weight_fq = FakeQuantizer(weight_config)
        self.gemm_mode = gemm_mode

    def forward(self, input):
        """Quantize ``input`` and the weight, then run the selected integer gemm.

        Returns:
            The gemm output, with ``self.bias`` added when it is not ``None``.
        """
        # quantize x
        input_i8, input_scale, input_zp = quantize_x(input)

        # quantize w
        # The fake quantizer's output is discarded; the call is kept because
        # `scale` / `zero_point` are read from `self.weight_fq` right after it
        # (presumably the call is what populates them - TODO confirm).
        _ = self.weight_fq(self.weight)
        w_qmin, w_qmax = _DTYPE_TO_QVALUE_BOUNDS[torch.int4]
        w_group_size = self.weight_fq.config.granularity.group_size
        w_block_size = (1, w_group_size)
        w_scale = self.weight_fq.scale
        # int4 value range, stored in an int8 tensor
        weight_int4 = torch.ops.torchao.quantize_affine(
            self.weight,
            w_block_size,
            w_scale,
            self.weight_fq.zero_point,
            torch.int8,
            w_qmin,
            w_qmax,
        )

        if self.gemm_mode == "int8_naive_reference":
            # original reference
            q_output = naive_x_token_assym_fp8_w_group_sym_int4_gemm(
                input_i8.to(torch.int32),
                input_scale,
                input_zp,
                weight_int4.to(torch.int32),
                w_scale,
                w_group_size,
            )
            # the naive gemm takes no bias argument, so apply it here
            if self.bias is not None:
                q_output = q_output + self.bias
        elif self.gemm_mode == "int8_cpu_reference":
            # now also check Kimish's implementation
            # every operand, including the bias, has to live on the CPU
            bias_cpu = None if self.bias is None else self.bias.cpu()
            q_output = cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
                input_i8.cpu(),
                input_scale.cpu(),
                input_zp.cpu(),
                weight_int4.cpu(),
                w_scale.cpu(),
                self.weight_fq.zero_point.cpu(),
                bias_cpu,
                w_group_size,
            ).to(input.device)  # back to wherever the input came from
        elif self.gemm_mode == "int8_triton":
            # finally, check vs triton gemm
            q_output = int8_matmul_triton(
                input_i8,
                weight_int4.t(),
                input_scale,
                input_zp,
                w_scale.t(),
                w_group_size,
            )
            # the triton gemm takes no bias argument, so apply it here
            if self.bias is not None:
                q_output = q_output + self.bias
        else:
            # only reachable if `gemm_mode` was changed after construction or
            # asserts are disabled; fail loudly instead of an unbound local
            raise ValueError(f"unsupported gemm_mode {self.gemm_mode!r}")

        return q_output

    @classmethod
    def from_float(cls, mod: torch.nn.Linear, gemm_mode: str, group_size: int = 32):
        """Build an instance that shares ``mod``'s weight and bias.

        Args:
            mod: the floating point linear layer to convert.
            gemm_mode: gemm implementation to use, see the class docstring.
            group_size: weight quantization group size, default 32.
        """
        new_mod = cls(
            mod.in_features,
            mod.out_features,
            # do not allocate a bias that is dropped again two lines below
            bias=mod.bias is not None,
            gemm_mode=gemm_mode,
            group_size=group_size,
        )
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


def run():
    """Compare fake-quant QAT against the three exact integer gemm paths.

    Builds one high precision linear layer on CUDA, a fake-quantized baseline
    and three exact-gemm variants, runs all of them on the same random input
    and prints the pairwise results of `compute_error`.
    """
    M, K, N = 32, 64, 128

    # TODO(before land): also exercise bias=True
    m_hp = nn.Sequential(nn.Linear(K, N, bias=False)).cuda()
    mq_ref = copy.deepcopy(m_hp)
    mq_naive = copy.deepcopy(m_hp)
    mq_cpu = copy.deepcopy(m_hp)
    mq_triton = copy.deepcopy(m_hp)

    # create a baseline: QAT with fake quants. Our exact QAT's output should
    # be close to this
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    quantize_(
        mq_ref,
        IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
    )

    # create the experiment: forward pass with an integer gemm
    mq_naive[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
        mq_naive[0], "int8_naive_reference"
    )
    mq_cpu[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
        mq_cpu[0], "int8_cpu_reference"
    )
    mq_triton[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(
        mq_triton[0], "int8_triton"
    )

    x_hp = torch.randn(M, K, device="cuda")
    xq_ref = copy.deepcopy(x_hp)
    xq = copy.deepcopy(x_hp)

    with torch.no_grad():
        y_hp = m_hp(x_hp)
        yq_ref = mq_ref(xq_ref)
        yq_naive = mq_naive(xq)
        yq_cpu = mq_cpu(xq)
        yq_triton = mq_triton(xq)

    sqnr_hp_qref = compute_error(y_hp, yq_ref)
    sqnr_hp_qnaive = compute_error(y_hp, yq_naive)
    sqnr_qref_qnaive = compute_error(yq_ref, yq_naive)
    sqnr_qcpu_qnaive = compute_error(yq_cpu, yq_naive)
    sqnr_qcpu_qtriton = compute_error(yq_cpu, yq_triton)
    sqnr_qnaive_qtriton = compute_error(yq_naive, yq_triton)
    print("sqnr_hp_qref", sqnr_hp_qref)
    print("sqnr_hp_qnaive", sqnr_hp_qnaive)
    print("sqnr_qref_qnaive", sqnr_qref_qnaive)
    print("sqnr_qcpu_qnaive", sqnr_qcpu_qnaive)
    # label now matches the variable name and the sibling labels
    print("sqnr_qcpu_qtriton", sqnr_qcpu_qtriton)
    print("sqnr_qnaive_qtriton", sqnr_qnaive_qtriton)


if __name__ == "__main__":
    fire.Fire(run)
# ---------------------------------------------------------------------------
# File: torchao/prototype/qat_exact/reference_gemm.py
# ---------------------------------------------------------------------------
import torch
from torch._higher_order_ops.out_dtype import out_dtype


def cpu_x_token_assym_fp8_w_group_sym_int4_gemm(
    x_i8,
    x_scale,
    x_zero_point,
    weight_int4,
    weight_scale,
    weight_zero_point,
    bias_fp32,
    group_size,
):
    """Group-wise integer linear: asymmetric per-token x, symmetric per-group w.

    Per group, an int32 matmul of the activation and weight slices is scaled
    by the group's weight scale and corrected for the activation zero point
    (``(X - zp) @ W == X @ W - zp * colsum(W)``). The per-token activation
    scale is applied once at the end, followed by the optional bias.

    Args:
        x_i8: quantized activations, 2D ``[batch_size, in_features]``.
        x_scale: per-token activation scales, one value per row of ``x_i8``.
        x_zero_point: per-token activation zero points, one per row.
        weight_int4: quantized weights ``[out_features, in_features]``.
        weight_scale: weight scales ``[out_features, in_features // group_size]``.
        weight_zero_point: accepted for interface parity but never read here
            (the weights are treated as symmetric).
        bias_fp32: optional bias added to the result, or ``None``.
        group_size: number of ``in_features`` elements sharing one weight scale.

    Returns:
        float32 tensor ``[batch_size, out_features]``.
    """
    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"

    # For groupwise quantization, we need to handle the computation differently
    # weight_i4 shape: [out_features, in_features]
    # weight_scale shape: [out_features, in_features // group_size]
    # weight_zero_point shape: [out_features, in_features // group_size]
    out_features, in_features = weight_int4.shape
    assert in_features % group_size == 0, (
        f"in_features ({in_features}) must be divisible by group_size ({group_size})"
    )
    num_groups = in_features // group_size

    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)

    # Reshape for group-wise processing and widen to int32 for computation.
    # `reshape` (not `view`) so non-contiguous inputs are accepted as well.
    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
    batch_size = x_i8.shape[0]
    x_i32_grouped = x_i8.reshape(batch_size, num_groups, group_size).to(torch.int32)
    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
    weight_i32_grouped = weight_int4.reshape(out_features, num_groups, group_size).to(
        torch.int32
    )

    # Loop invariants: per-group weight sums [out_features, num_groups] and the
    # zero points as a column so they broadcast over out_features.
    weight_col_sums = weight_i32_grouped.sum(dim=-1).to(torch.float32)
    x_zero_point_col = x_zero_point.reshape(-1, 1)

    # Perform groupwise integer linear operation
    acc_fp32 = torch.zeros(
        batch_size, out_features, dtype=torch.float32, device=x_i8.device
    )

    for group_idx in range(num_groups):
        # Extract current group
        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]

        # Scale for this group, shaped [1, out_features] for broadcasting
        weight_scale_row = weight_scale[:, group_idx].reshape(1, -1)

        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
        group_acc = out_dtype(
            torch.ops.aten.linear.default,
            torch.int32,
            x_group,
            weight_group,
            None,
        )

        # Output has to be scaled by x_scale * weight_scale_group
        # However we will first scale by weight_scale_group, that is accounting
        # only for scale of weight, and then scale by x_scale at the end because
        # x_scale applies to all groups
        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_row

        # we must also subtract x_zero_point * weight_group_sum
        # since (X - x_zero_point) * W = X * W - x_zero_point * W
        weights_col_sum_adjusted = (
            weight_col_sums[:, group_idx].reshape(1, -1)
            * x_zero_point_col
            * weight_scale_row
        )
        acc_fp32 = acc_fp32 - weights_col_sum_adjusted

    out_fp32 = acc_fp32 * x_scale.reshape(-1, 1)
    if bias_fp32 is not None:
        out_fp32 = out_fp32 + bias_fp32

    return out_fp32


def naive_x_token_assym_fp8_w_group_sym_int4_gemm(
    act_q,
    act_scale,
    act_zp,
    w_q,
    w_scale,
    w_group_size,
) -> torch.Tensor:
    """Slow element-by-element reference gemm built from eager PyTorch ops.

    For every output element and every weight group it computes
    ``sum((act_q - act_zp) * w_q) * act_scale * w_scale`` and accumulates the
    result in float32.

    Args:
        act_q: quantized activations, indexed as ``[m, k]``.
        act_scale: activation scales, indexed by row ``m``.
        act_zp: activation zero points, indexed by row ``m``.
        w_q: quantized weights, indexed as ``[n, k]``.
        w_scale: weight scales, indexed as ``[n, group]``.
        w_group_size: number of ``k`` elements per weight group. Trailing
            elements that do not fill a whole group are ignored.

    Returns:
        float32 tensor of shape ``[act_q.shape[0], w_q.shape[0]]``.
    """
    #
    # now we have the scales/zero_points/quant values for both gemm operands
    # below is a manual slow gemm with integer operands and float rescaling,
    # implemented using eager PyTorch ops. This should be slow but closely
    # (but not exactly) matching a real int8,int8->int32 gemm with
    # rescaling. The per-group dot product stays in an integer dtype
    # (`torch.sum` of an integer tensor does not switch to float); floating
    # point only enters when the scales are applied and when groups are
    # accumulated.
    #
    num_rows = act_q.shape[0]
    num_cols = w_q.shape[0]
    num_groups = w_q.shape[1] // w_group_size

    # f32 -> bf16 -> f32 rounding is elementwise, so do it once up front
    w_scale_rounded = w_scale.bfloat16().float()

    q_output = torch.zeros(
        num_rows,
        num_cols,
        dtype=torch.float32,
        device=act_q.device,
    )
    for m_idx in range(num_rows):
        # (act_q - act_zp) only depends on the row, so it is hoisted here
        act_row_centered = act_q[m_idx] - act_zp[m_idx]
        act_scale_tmp = act_scale[m_idx].squeeze(-1)

        for n_idx in range(num_cols):
            w_row = w_q[n_idx]
            for g_idx in range(num_groups):
                k_start = g_idx * w_group_size
                k_end = k_start + w_group_size

                # (act_q - act_zp) * w_q, still an integer tensor
                elem_int = act_row_centered[k_start:k_end] * w_row[k_start:k_end]

                # sum((act_q - act_zp) * w_q), integer dot product for the group
                group_dot = torch.sum(elem_int)

                # scale
                w_scale_tmp = w_scale_rounded[n_idx][g_idx].squeeze(-1)
                sum_scaled = group_dot * act_scale_tmp * w_scale_tmp

                # accumulate
                q_output[m_idx, n_idx] += sum_scaled

    return q_output


# ---------------------------------------------------------------------------
# File: torchao/prototype/qat_exact/triton_gemm.py
# ---------------------------------------------------------------------------
import torch
import triton
import triton.language as tl


@triton.jit
def int8_matmul_kernel_precise(
    # Input pointers
    a_ptr,
    b_ptr,
    c_ptr,
    # Quantization parameters
    a_scale_ptr,
    a_zero_ptr,  # Per-token asymmetric scaling for activations
    b_scale_ptr,  # Per-group symmetric scaling for weights (no zero point)
    # Matrix dimensions
    M,
    N,
    K,
    # Strides
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Block sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
):
    """
    More precise version that handles per-group scaling correctly.
    This version processes one group at a time to apply correct per-group scales.

    Each program computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C. For every
    group along K it accumulates an int32 dot product plus the zero-point
    correction, converts to float32 and scales by a_scale * b_scale.
    a_scale_ptr / a_zero_ptr are indexed with unit stride by row;
    b_scale_ptr is indexed as group_idx * N + column.
    """

    # Get program ID
    pid = tl.program_id(axis=0)

    # Calculate 2D grid coordinates (row-major over the tile grid)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Calculate block offsets; the modulo wraps out-of-range rows / columns
    # back into the matrix so loads stay in bounds (the store is masked).
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N

    # Load per-token quantization parameters for A
    a_scale = tl.load(a_scale_ptr + offs_am)  # [BLOCK_SIZE_M]
    a_zero = tl.load(a_zero_ptr + offs_am)  # [BLOCK_SIZE_M]
    a_scale_expanded = a_scale[:, None]
    a_zero_expanded = a_zero[:, None].to(tl.int32)

    # Initialize final accumulator
    final_accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Process each group separately to apply correct per-group scaling
    num_groups = tl.cdiv(K, GROUP_SIZE)

    for group_idx in range(num_groups):
        # Calculate K range for this group
        k_start = group_idx * GROUP_SIZE
        k_end = tl.minimum(k_start + GROUP_SIZE, K)
        actual_group_size = k_end - k_start

        # Load B scale for this group
        b_scale_offs = group_idx * N + offs_bn
        b_scale_mask = offs_bn < N
        b_scale = tl.load(b_scale_ptr + b_scale_offs, mask=b_scale_mask, other=1.0)
        # f32 -> bf16 -> f32 to match xnnpack
        b_scale = b_scale.to(tl.bfloat16).to(tl.float32)

        # Process this group in sub-blocks
        group_accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
        group_zero_correction = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

        # Process the group in BLOCK_SIZE_K chunks
        for k_sub in range(0, actual_group_size, BLOCK_SIZE_K):
            k_offset = k_start + k_sub

            # Calculate offsets for this sub-block
            offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
            # Mask against the end of the *group*, not just K: when GROUP_SIZE
            # is not a multiple of BLOCK_SIZE_K the last chunk would otherwise
            # pull in elements of the next group, which would be scaled with
            # this group's scale and counted a second time later.
            k_mask = offs_k < k_end

            # Load A block
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            a_mask = (offs_am[:, None] < M) & k_mask[None, :]
            a_block = tl.load(a_ptrs, mask=a_mask, other=0)

            # Load B block
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )
            b_mask = k_mask[:, None] & (offs_bn[None, :] < N)
            b_block = tl.load(b_ptrs, mask=b_mask, other=0)

            # Integer dot product using tensor cores
            group_accumulator = tl.dot(a_block, b_block, acc=group_accumulator)

            # Zero-point correction for this sub-block:
            # (A - zp) @ B == A @ B - zp * colsum(B)
            b_sum = tl.sum(b_block.to(tl.int32), axis=0, keep_dims=True)
            group_zero_correction += a_zero_expanded * b_sum

        # Apply zero-point correction for this group
        corrected_group_result = group_accumulator - group_zero_correction

        # Convert to float and apply scaling for this group
        group_result_float = corrected_group_result.to(tl.float32)

        # Apply combined scaling: A_scale * B_scale
        combined_scale = a_scale_expanded * b_scale[None, :]

        # Add this group's contribution to final result
        final_accumulator += group_result_float * combined_scale

    # Store final result (unwrapped offsets + mask so padding is not written)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, final_accumulator, mask=c_mask)


def int8_matmul_triton(
    a_int8: torch.Tensor,  # [M, K] int8 activations
    b_int8: torch.Tensor,  # [K, N] int8 weights
    a_scale: torch.Tensor,  # [M] per-token scales
    a_zero: torch.Tensor,  # [M] integer per-token zero points
    b_scale: torch.Tensor,  # [K // group_size, N] per-group scales
    group_size: int = 32,
) -> torch.Tensor:
    """
    Wrapper function for the Triton int8 matmul kernel using integer tensor cores.

    Args:
        a_int8: Quantized activations [M, K] (int8)
        b_int8: Quantized weights [K, N] (int8)
        a_scale: Per-token scale factors for activations [M]
        a_zero: Per-token integer zero points for activations [M]
            (converted to int32 inside the kernel)
        b_scale: Per-group scale factors for weights [K//group_size, N]
            (rounded through bfloat16 inside the kernel)
        group_size: Group size for weight quantization (default: 32)

    Returns:
        Output tensor [M, N] (float32)

    Raises:
        AssertionError: if the operand shapes are inconsistent or K is not
            divisible by ``group_size``.
    """
    M, K = a_int8.shape
    K_b, N = b_int8.shape

    assert K == K_b, f"K dimensions must match: {K} != {K_b}"
    assert K % group_size == 0, (
        f"K ({K}) must be divisible by group_size ({group_size})"
    )
    assert a_scale.shape == (M,), f"a_scale shape mismatch: {a_scale.shape} != {(M,)}"
    assert a_zero.shape == (M,), f"a_zero shape mismatch: {a_zero.shape} != {(M,)}"
    assert b_scale.shape == (K // group_size, N), (
        f"b_scale shape mismatch: {b_scale.shape} != {(K // group_size, N)}"
    )

    # Output tensor
    c = torch.empty((M, N), device=a_int8.device, dtype=torch.float32)

    # The kernel indexes the per-token parameters with unit stride, so strided
    # views (e.g. a column slice) must be made contiguous first.
    a_scale = a_scale.contiguous()
    a_zero = a_zero.contiguous()

    # Flatten b_scale for easier indexing in kernel; `reshape` also turns a
    # transposed view into the row-major [group, N] layout the kernel expects.
    b_scale_flat = b_scale.reshape(-1)

    # Launch configuration - optimized for integer tensor cores
    # Use smaller blocks for better tensor core utilization
    BLOCK_SIZE_M = 64
    BLOCK_SIZE_N = 64
    BLOCK_SIZE_K = 32  # Should be multiple of 16 for tensor cores

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )

    # Launch kernel
    int8_matmul_kernel_precise[grid](
        a_int8,
        b_int8,
        c,
        a_scale,
        a_zero,
        b_scale_flat,
        M,
        N,
        K,
        a_int8.stride(0),
        a_int8.stride(1),
        b_int8.stride(0),
        b_int8.stride(1),
        c.stride(0),
        c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE=group_size,
    )

    return c


def _random_gemm_inputs(M, N, K, group_size, device):
    """Create random operands for `int8_matmul_triton` covering the full int8 range."""
    # `torch.randint`'s upper bound is exclusive, hence 128 to include 127
    a_int8 = torch.randint(-128, 128, (M, K), dtype=torch.int8, device=device)
    b_int8 = torch.randint(-128, 128, (K, N), dtype=torch.int8, device=device)
    a_scale = torch.rand(M, device=device) * 0.1
    # zero points live in the int8 value range; int32 storage mirrors what
    # `quantize_x` in main.py requests for its zero points
    a_zero = torch.randint(-128, 128, (M,), dtype=torch.int32, device=device)
    b_scale = torch.rand(K // group_size, N, device=device) * 0.1
    return a_int8, b_int8, a_scale, a_zero, b_scale


# Utility function to benchmark tensor core utilization
def benchmark_tensor_cores(group_size: int = 32):
    """
    Benchmark function to demonstrate tensor core utilization.

    Times `int8_matmul_triton` on a few square problem sizes and prints the
    average latency and the derived TOPS figure.

    Args:
        group_size: weight quantization group size used for the test data and
            passed to the kernel wrapper (default: 32).

    Raises:
        RuntimeError: if CUDA is not available; the Triton kernel and the CUDA
            timing events cannot run on the CPU.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("benchmark_tensor_cores requires a CUDA device")
    device = torch.device("cuda")

    num_warmup = 5
    num_iters = 10

    # Test different matrix sizes
    test_sizes = [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
    ]

    for M, N, K in test_sizes:
        print(f"\nTesting {M}x{N}x{K} matrix multiplication")

        # Generate test data
        a_int8, b_int8, a_scale, a_zero, b_scale = _random_gemm_inputs(
            M, N, K, group_size, device
        )

        # Warmup (also triggers the Triton compilation)
        for _ in range(num_warmup):
            _ = int8_matmul_triton(a_int8, b_int8, a_scale, a_zero, b_scale, group_size)

        # Benchmark
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(num_iters):
            result = int8_matmul_triton(
                a_int8, b_int8, a_scale, a_zero, b_scale, group_size
            )
        end.record()
        torch.cuda.synchronize()

        # `elapsed_time` is in milliseconds; average per iteration
        elapsed_time = start.elapsed_time(end) / num_iters

        # Calculate TOPS (Tera Operations Per Second)
        ops = 2 * M * N * K  # Multiply-add operations
        tops = (ops / (elapsed_time * 1e-3)) / 1e12

        print(f"  Time: {elapsed_time:.2f} ms")
        print(f"  TOPS: {tops:.2f}")
        print(f"  Result shape: {result.shape}")


def _demo():
    """Run the kernel once on random data, print shapes, then benchmark."""
    if not torch.cuda.is_available():
        # the Triton kernel cannot run on CPU tensors, so stop with a clear message
        raise SystemExit("CUDA is required to run the Triton int8 matmul example")
    device = torch.device("cuda")

    # Matrix dimensions
    M, N, K = 512, 512, 1024
    group_size = 32

    # Create test data
    torch.manual_seed(42)
    a_int8, b_int8, a_scale, a_zero, b_scale = _random_gemm_inputs(
        M, N, K, group_size, device
    )

    print("Input shapes:")
    print(f"  a_int8: {a_int8.shape} ({a_int8.dtype})")
    print(f"  b_int8: {b_int8.shape} ({b_int8.dtype})")
    print(f"  a_scale: {a_scale.shape} ({a_scale.dtype})")
    print(f"  a_zero: {a_zero.shape} ({a_zero.dtype})")
    print(f"  b_scale: {b_scale.shape} ({b_scale.dtype})")

    print("Testing precise kernel...")
    result_precise = int8_matmul_triton(
        a_int8, b_int8, a_scale, a_zero, b_scale, group_size
    )

    print("\nOutput shapes:")
    print(f"  Precise kernel result: {result_precise.shape} ({result_precise.dtype})")

    # Run benchmark
    print("\nRunning tensor core benchmark...")
    benchmark_tensor_cores()


# Example usage
if __name__ == "__main__":
    _demo()