
Chapter 6: Quantized Forward Pass Functions

Welcome back to the TinyQ tutorial! So far, we've learned about the Quantizer that starts the process (Chapter 1), the different methods like W8A32 and W8A16 (Chapter 2), the special Custom Quantized Layers built to handle lower precision (Chapter 3), how TinyQ replaces standard layers with these custom ones (Chapter 4), and the Weight Quantization Math that converts float32 weights to int8 (Chapter 5).

Now, it's time to see the quantized model in action! Once the model's structure has been changed and the weights are stored in their new, compressed format within the custom layers, how does the model actually use this information to process new data? This is handled by the Quantized Forward Pass Functions.

What is a Forward Pass?

In any neural network, the "forward pass" is the process of taking an input (like a sentence to be translated or an image to be classified) and pushing it through the network layer by layer to produce an output (like the translated sentence or the image's class). Each layer performs a calculation (often a matrix multiplication and addition, especially in linear layers) on the input it receives and passes the result to the next layer.

For a standard nn.Linear layer with float32 weights and float32 input, the forward pass calculation is straightforward: output = input @ weights.T + bias, all done using float32 math.
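
As a tiny sketch (with made-up shapes), this is exactly what a standard nn.Linear layer computes in float32:

import torch
import torch.nn as nn

# A toy float32 linear layer: 4 input features -> 3 output features
layer = nn.Linear(in_features=4, out_features=3)

x = torch.randn(2, 4)                      # a batch of 2 float32 inputs
manual = x @ layer.weight.T + layer.bias   # the formula written out by hand
output = layer(x)                          # what nn.Linear does internally

print(torch.allclose(manual, output))      # True: same float32 math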

Why Do We Need Quantized Forward Pass Functions?

Our Custom Quantized Layers (like W8A32LinearLayer and W8A16LinearLayer) don't store weights as float32. They store them as int8! And the input activation they receive might be float32 or float16, depending on the quantization method.

Standard PyTorch operations like torch.matmul are optimized for specific data types (float32, float16). You can't directly multiply an int8 tensor by a float32 tensor and expect the hardware to magically handle it efficiently and correctly based on the scales and zero points used during quantization.

We need specific calculation logic that knows how to take the int8 weights, the scales, the zero points (if applicable), and the input activation (in its specified precision) and combine them correctly to produce the output. This specialized logic lives within the forward() method of our custom layers, and often relies on dedicated quantized forward pass functions.

These functions define the specific steps needed to perform the equivalent of input @ weights.T + bias but using the compressed int8 weight representation.

TinyQ's Quantized Forward Pass Functions

TinyQ provides utility functions in utils.py that implement the core calculation logic for the different quantization methods:

  1. w8_a32_forward: Implements the forward pass for the W8A32 method (8-bit weights, 32-bit activations).
  2. w8_a16_forward: Implements the forward pass for the W8A16 method (8-bit weights, 16-bit activations).

The forward() methods of the corresponding Custom Quantized Layers (W8A32LinearLayer and W8A16LinearLayer) simply call these functions, passing along the necessary data (input, quantized weights, scales, bias).

Let's look at how each one works.

Method 1: w8_a32_forward (W8A32)

This function is designed for layers where weights are int8 but inputs and outputs should be float32.

Its main job is to "bring back" the precision of the weights just in time for the calculation, using the stored scales. This is sometimes called "dequantizing on the fly".

Here's the conceptual process:

  1. Take the float32 input activation.
  2. Take the int8 quantized weights and their float32 scales.
  3. Dequantize the int8 weights back into approximate float32 values using the scales (and the zero point, which is 0 for symmetric quantization). The formula is dequantized_weight = scale * quantized_weight + zero_point, which simplifies to dequantized_weight = scale * quantized_weight.float(). The .float() cast is needed because multiplying by the scale requires floating-point numbers. For example, with a scale of 0.05, a stored int8 value of 100 dequantizes back to 0.05 * 100 = 5.0.
  4. Perform a standard float32 matrix multiplication between the float32 input and the newly dequantized float32 weights.
  5. Add the float32 bias (if present).
  6. Return the float32 output activation.

Why is this efficient? Even though the actual multiplication is float32, the weights are stored as int8. This means the model takes up significantly less memory, and loading and moving these int8 weights is much faster than loading and moving float32 weights. The dequantization happens element-wise and can be parallelized or accelerated on suitable hardware, and the subsequent float32 matrix multiplication can use highly optimized standard kernels.
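
To make the memory saving concrete, here is a small sketch (with a made-up 1024 x 1024 layer) comparing how many bytes the same weight matrix occupies in float32 versus int8:

import torch

out_features, in_features = 1024, 1024  # made-up layer size

fp32_w = torch.randn(out_features, in_features)                                    # 4 bytes per value
int8_w = torch.randint(-128, 128, (out_features, in_features), dtype=torch.int8)   # 1 byte per value

fp32_bytes = fp32_w.nelement() * fp32_w.element_size()
int8_bytes = int8_w.nelement() * int8_w.element_size()

print(f"float32: {fp32_bytes // 1024} KiB, int8: {int8_bytes // 1024} KiB")
# float32: 4096 KiB, int8: 1024 KiB -> roughly 4x smaller (plus a small cost for the scales)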

Let's look at the simplified code for w8_a32_forward from utils.py:

# From utils.py
import torch
import torch.nn.functional as F  # F.linear is used below

def w8_a32_forward(input, q_w, s_w, z_w=0, bias=None):
    """W8A32 forward pass implementation."""
    # Input is expected to be FP32
    assert input.dtype == torch.float32
    # Quantized weights are INT8
    assert q_w.dtype == torch.int8
    
    # Ensure scale is the right shape for broadcasting ([out_features, 1])
    s_w = s_w.view(-1, 1) 
    
    # Dequantize the INT8 weights back to FP32 using scales and zero point (which is 0)
    # Calculation: dequantized_weight = q_w * s_w + z_w
    dequantized_weight = q_w.to(torch.float32) * s_w + z_w # z_w is 0 here

    # Perform the standard FP32 linear operation (matrix multiplication + bias)
    # F.linear expects input * weight.T, so we use the dequantized weight directly
    output = F.linear(input, dequantized_weight) 
    
    # Add bias if it exists
    if bias is not None:
        output += bias.squeeze() # Ensure bias is 1D or correct shape

    return output # Output is FP32

This code directly implements the steps: cast q_w to float32, multiply by s_w, add z_w (which is 0), and then perform the F.linear operation. The input remains float32 throughout this process.
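
As a quick, hedged sanity check (toy shapes and random values, not TinyQ data), you could call w8_a32_forward directly and compare it against a plain float32 F.linear on weights that were dequantized up front:

import torch
import torch.nn.functional as F

# Toy data: 2 inputs with 4 features, a layer with 3 output features
input = torch.randn(2, 4, dtype=torch.float32)
q_w = torch.randint(-128, 128, (3, 4), dtype=torch.int8)  # pretend these came from quantization
s_w = torch.rand(3) * 0.1                                 # one float32 scale per output row
bias = torch.randn(3)

out = w8_a32_forward(input, q_w, s_w, bias=bias)

# Reference: dequantize first, then run a normal float32 linear pass
reference = F.linear(input, q_w.float() * s_w.view(-1, 1)) + bias
print(torch.allclose(out, reference))  # True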

Method 2: w8_a16_forward (W8A16)

This function is designed for layers where weights are int8 but inputs and outputs should be 16-bit floating-point (float16 or bfloat16).

This method aims for greater speed by keeping the core matrix multiplication in lower precision.

Here's the conceptual process:

  1. Take the float16 input activation.
  2. Take the int8 quantized weights and their float16 scales.
  3. Cast the int8 weights directly to float16. These are not dequantized yet; they are just the integer values represented as float16.
  4. Perform a float16 matrix multiplication between the float16 input and the casted float16 weights.
  5. Multiply the result of the matrix multiplication by the float16 scales. Because each scale belongs to a single output feature (one row of the weight matrix), applying it after the multiplication gives the same result as dequantizing the weights first (see the short check after this list).
  6. Add the float16 bias (if present).
  7. Return the float16 output activation.
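
Why is it valid to apply the scale after the matrix multiplication? Each scale belongs to one output feature, i.e. one row of the weight matrix, so scaling the corresponding output column afterwards is algebraically the same as scaling that weight row first. The small sketch below (toy shapes, plain float32 to keep rounding noise out of the comparison) checks this equivalence numerically:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4)                                     # toy input
q_w = torch.randint(-128, 128, (3, 4), dtype=torch.int8)  # toy "quantized" weights
scales = torch.rand(3) * 0.1                              # one scale per output feature

# Option A: dequantize the weights first, then multiply
dequant_first = F.linear(x, q_w.float() * scales.view(-1, 1))

# Option B: multiply with the raw integer values, scale the result afterwards
scale_after = F.linear(x, q_w.float()) * scales

print(torch.allclose(dequant_first, scale_after))         # True: both orders agree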

Why is this efficient? The critical matrix multiplication step happens using 16-bit numbers. Modern hardware, especially GPUs, has dedicated cores (like Tensor Cores) that can perform 16-bit matrix multiplications much faster than float32 ones. The scale multiplication and bias addition are simpler element-wise operations that are also fast in 16-bit.

Let's look at the simplified code for w8_a16_forward from utils.py:

# From utils.py
import torch
import torch.nn.functional as F  # Need F.linear

def w8_a16_forward(weight, input, scales, bias=None):
    """W8A16 forward pass implementation."""
    # Input is expected to be FP16 (or BFP16)
    assert input.dtype == torch.float16 or input.dtype == torch.bfloat16
    # Quantized weights are INT8
    assert weight.dtype == torch.int8
    # Scales should match input dtype
    assert scales.dtype == input.dtype

    # Cast INT8 weights to the same FP16/BFP16 dtype as the input
    # Note: This is *not* dequantization yet. It's just changing the data type.
    casted_weights = weight.to(input.dtype)
    
    # Perform the core linear operation (matrix multiplication) in FP16/BFP16
    output = F.linear(input=input, weight=casted_weights) 
    
    # Apply the scale *after* the matrix multiplication
    output = output * scales.unsqueeze(0) # Unsqueeze scale for broadcasting

    # Add bias if it exists
    if bias is not None:
        output += bias.squeeze() # Ensure bias is 1D or correct shape

    return output # Output is FP16/BFP16

Here, the int8 weights are simply cast to the input's dtype (float16). The F.linear call operates on float16 data, and the scale is applied after this multiplication. This utilizes the 16-bit computation capabilities of the hardware.
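
Here is a small, hedged usage sketch with made-up tensors. It uses bfloat16 so it should also run on a recent CPU build of PyTorch; on a GPU, float16 works the same way:

import torch

# Toy tensors: 2 inputs with 4 features, a layer with 3 output features
input = torch.randn(2, 4, dtype=torch.bfloat16)
weight = torch.randint(-128, 128, (3, 4), dtype=torch.int8)  # pretend quantized weights
scales = (torch.rand(3) * 0.1).to(torch.bfloat16)            # one scale per output feature
bias = torch.randn(3, dtype=torch.bfloat16)

output = w8_a16_forward(weight, input, scales, bias)
print(output.dtype, output.shape)  # torch.bfloat16 torch.Size([2, 3])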

How Custom Layers Use These Functions

The beauty of the Custom Quantized Layers is that their forward methods are very simple. They just call the appropriate helper function and pass the data they hold and the input they receive.

From tinyq.py:

# Inside W8A32LinearLayer class

def forward(self, input):
    # W8A32 layer calls the W8A32 forward function
    return w8_a32_forward(input=input, 
                          q_w=self.int8_weights, 
                          s_w=self.scales, 
                          z_w=0, # Zero point is 0 for symmetric
                          bias=self.bias)

# Inside W8A16LinearLayer class

def forward(self, input):
    # W8A16 layer calls the W8A16 forward function
    return w8_a16_forward(weight=self.int8_weights, 
                          input=input, # Input is FP16/BFP16
                          scales=self.scales, # Scales are FP16/BFP16
                          bias=self.bias)

When you run your quantized model on new input data, PyTorch calls the forward() method of each layer in sequence. When it reaches one of our custom layers, its forward() method calls the corresponding wX_aY_forward function, which then performs the efficient, quantized calculation using the stored int8 weights and scales.

The Forward Pass Flow

Here's a simple flow of how the quantized forward pass works for a single layer after the model has been quantized:

sequenceDiagram
    participant Input as Input Activation (FP32/FP16)
    participant CustomLayer as Custom Quantized Layer<br/>(W8A32LinearLayer / W8A16LinearLayer)
    participant QForwardFunc as Quantized Forward Function<br/>(w8_a32_forward / w8_a16_forward)
    participant Computation as Computation Logic
    participant Output as Output Activation (FP32/FP16)

    Input->CustomLayer: forward(input)
    CustomLayer->QForwardFunc: Call wX_aY_forward(<br/>input,<br/>int8_weights,<br/>scales,<br/>bias)
    QForwardFunc->Computation: Perform Calculation<br/>(Dequantize/Cast, MatMul, Scale, Bias) using:<br/>- input (FP32/FP16)<br/>- int8_weights<br/>- scales (FP32/FP16)<br/>- bias (FP32/FP16)
    Computation-->QForwardFunc: Return result
    QForwardFunc-->CustomLayer: Return result
    CustomLayer-->Output: Pass result to next layer

Comparing the Forward Passes

| Feature | w8_a32_forward (W8A32) | w8_a16_forward (W8A16) |
| --- | --- | --- |
| Input Dtype | float32 | float16 or bfloat16 |
| Weight Dtype | int8 (stored), float32 (for math) | int8 (stored), float16/bfloat16 (for math) |
| Scale Dtype | float32 | float16 or bfloat16 |
| Zero Point | 0 (used in calculation) | Implicitly 0 (not explicitly used in code) |
| Core Math Dtype | float32 (matrix multiplication) | float16 or bfloat16 (matrix multiplication) |
| Scale Applied | Before matrix multiplication (part of dequantization) | After matrix multiplication |
| Benefit | Saves weight memory; uses optimized FP32 kernels | Saves weight & activation memory; uses fast FP16/BFP16 kernels |

The choice between W8A32 and W8A16 forward passes depends on your hardware capabilities (whether it accelerates 16-bit math) and the desired trade-off between potential speedup and potential accuracy loss.

Conclusion

The Quantized Forward Pass Functions (w8_a32_forward and w8_a16_forward) are the critical piece that makes the quantized model runnable. They provide the specialized computation logic needed to perform matrix multiplications using int8 weights and their corresponding scales, while handling input activations in float32 or float16.

These functions are called internally by the forward() methods of the Custom Quantized Layers that replaced the original nn.Linear layers. This ensures that during inference, the model correctly utilizes the compressed weight format and leverages lower-precision computations where appropriate, leading to reduced memory usage and potentially faster execution.

Now that we've covered how the model is quantized and how it performs calculations, the final chapter will look at how to handle these quantized models, including saving, loading, and using them in practice.

Next Chapter: Model Handling & Utilities