Quantized Forward Pass Functions
Welcome back to the TinyQ tutorial! So far, we've learned about the Quantizer that starts the process (Chapter 1), the different methods like W8A32 and W8A16 (Chapter 2), the special Custom Quantized Layers built to handle lower precision, how TinyQ replaces standard layers with these custom ones (Chapter 4), and the Weight Quantization Math that converts `float32` weights to `int8`.
Now, it's time to see the quantized model in action! Once the model's structure has been changed and the weights are stored in their new, compressed format within the custom layers, how does the model actually use this information to process new data? This is handled by the Quantized Forward Pass Functions.
In any neural network, the "forward pass" is the process of taking an input (like a sentence to be translated or an image to be classified) and pushing it through the network layer by layer to produce an output (like the translated sentence or the image's class). Each layer performs a calculation (often a matrix multiplication and addition, especially in linear layers) on the input it receives and passes the result to the next layer.
For a standard `nn.Linear` layer with `float32` weights and `float32` input, the forward pass calculation is straightforward: `output = input @ weights.T + bias`, all done using `float32` math.
Our Custom Quantized Layers (like `W8A32LinearLayer` and `W8A16LinearLayer`) don't store weights as `float32`. They store them as `int8`! And the input activation they receive might be `float32` or `float16`, depending on the quantization method.
Standard PyTorch operations like `torch.matmul` are optimized for specific data types (`float32`, `float16`). You can't directly multiply an `int8` tensor by a `float32` tensor and expect the hardware to magically handle it efficiently and correctly based on the scales and zero points used during quantization.
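A quick illustration with hypothetical tensors: PyTorch will happily multiply the weights once they are cast to `float32`, but it rejects a direct `float32` × `int8` matrix multiplication (the exact error message depends on your PyTorch version):

```python
import torch

q_w = torch.randint(-128, 127, (4, 8), dtype=torch.int8)  # hypothetical int8 weights
x = torch.randn(2, 8, dtype=torch.float32)                 # float32 input activation

print((x @ q_w.T.to(torch.float32)).shape)  # works once the weights are cast: torch.Size([2, 4])

try:
    x @ q_w.T  # mixing float32 and int8 directly is rejected
except RuntimeError as e:
    print("Mixed-dtype matmul failed:", e)
```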
We need specific calculation logic that knows how to take the `int8` weights, the scales, the zero points (if applicable), and the input activation (in its specified precision) and combine them correctly to produce the output. This specialized logic lives within the `forward()` method of our custom layers, and relies on dedicated quantized forward pass functions.
These functions define the specific steps needed to perform the equivalent of `input @ weights.T + bias`, but using the compressed `int8` weight representation.
TinyQ provides utility functions in `utils.py` that implement the core calculation logic for the different quantization methods:

- `w8_a32_forward`: Implements the forward pass for the W8A32 method (8-bit weights, 32-bit activations).
- `w8_a16_forward`: Implements the forward pass for the W8A16 method (8-bit weights, 16-bit activations).
The `forward()` method of each corresponding Custom Quantized Layer (`W8A32LinearLayer` and `W8A16LinearLayer`) simply calls the matching function, passing the necessary data (input, quantized weights, scales, bias).
Let's look at how each one works.
The `w8_a32_forward` function is designed for layers where weights are `int8` but inputs and outputs should be `float32`.
Its main job is to "bring back" the precision of the weights just in time for the calculation, using the stored scales. This is sometimes called "dequantizing on the fly".
Here's the conceptual process:
- Take the `float32` input activation.
- Take the `int8` quantized weights and their `float32` scales.
- Dequantize the `int8` weights back into approximate `float32` values using the scales (and zero point, though it's 0 for symmetric quantization). The formula is `dequantized_weight = scale * quantized_weight + zero_point`. Since the zero point is 0, this simplifies to `dequantized_weight = scale * quantized_weight.float()`. Note the `.float()` cast is needed because the scale multiplication requires floating-point numbers. (A small numeric sketch follows this list.)
- Perform a standard `float32` matrix multiplication between the `float32` input and the newly dequantized `float32` weights.
- Add the `float32` bias (if present).
- Return the `float32` output activation.
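For example, here is a tiny, made-up dequantization of a 2×2 `int8` weight matrix with one `float32` scale per output row (per-channel symmetric quantization, so the zero point is 0):

```python
import torch

q_w = torch.tensor([[127, -64],
                    [ 32,   8]], dtype=torch.int8)   # hypothetical int8 weights
s_w = torch.tensor([0.01, 0.05]).view(-1, 1)         # one float32 scale per output row

dequantized_weight = s_w * q_w.float()               # zero point is 0, so no offset
print(dequantized_weight)
# tensor([[ 1.2700, -0.6400],
#         [ 1.6000,  0.4000]])
```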
Why is this efficient? Even though the actual multiplication is `float32`, the weights are stored as `int8`. This means the model takes up significantly less memory, and loading and moving these `int8` weights is much faster than `float32` weights. The dequantization happens element-wise and can often be parallelized or even accelerated on specific hardware, and the subsequent `float32` matrix multiplication can use highly optimized standard kernels.
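To get a feel for the memory saving, this small sketch compares the storage of a hypothetical 4096×4096 weight matrix in `float32` versus `int8` (the per-row scales add only a few extra kilobytes):

```python
import torch

fp32_w = torch.randn(4096, 4096, dtype=torch.float32)              # original weights
int8_w = torch.randint(-128, 127, (4096, 4096), dtype=torch.int8)  # quantized weights

print(fp32_w.nelement() * fp32_w.element_size() / 2**20, "MiB")  # 64.0 MiB
print(int8_w.nelement() * int8_w.element_size() / 2**20, "MiB")  # 16.0 MiB
```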
Let's look at the simplified code for `w8_a32_forward` from `utils.py`:
```python
# From utils.py
import torch
import torch.nn.functional as F


def w8_a32_forward(input, q_w, s_w, z_w=0, bias=None):
    """W8A32 forward pass implementation."""
    # Input is expected to be FP32
    assert input.dtype == torch.float32
    # Quantized weights are INT8
    assert q_w.dtype == torch.int8

    # Ensure scale is the right shape for broadcasting ([out_features, 1])
    s_w = s_w.view(-1, 1)

    # Dequantize the INT8 weights back to FP32 using scales and zero point (which is 0)
    # Calculation: dequantized_weight = q_w * s_w + z_w
    dequantized_weight = q_w.to(torch.float32) * s_w + z_w  # z_w is 0 here

    # Perform the standard FP32 linear operation (matrix multiplication + bias)
    # F.linear computes input @ weight.T, so we use the dequantized weight directly
    output = F.linear(input, dequantized_weight)

    # Add bias if it exists
    if bias is not None:
        output += bias.squeeze()  # Ensure bias is 1D or correct shape

    return output  # Output is FP32
```
This code directly implements the steps: cast `q_w` to `float32`, multiply by `s_w`, add `z_w` (which is 0), and then perform the `F.linear` operation. The input remains `float32` throughout this process.
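As a quick sanity check, here is a hypothetical call to `w8_a32_forward` with random tensors; the shapes (a 2×8 input and a 4×8 weight matrix) are made up purely for illustration, and the function is assumed to be imported from `utils.py` as shown above:

```python
import torch

# from utils import w8_a32_forward  # assuming the definition shown above

q_w = torch.randint(-128, 127, (4, 8), dtype=torch.int8)  # stored int8 weights
s_w = torch.rand(4, dtype=torch.float32)                   # one scale per output row
x = torch.randn(2, 8, dtype=torch.float32)                 # FP32 input activation

out = w8_a32_forward(input=x, q_w=q_w, s_w=s_w, z_w=0)
print(out.shape, out.dtype)  # torch.Size([2, 4]) torch.float32
```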
The `w8_a16_forward` function is designed for layers where weights are `int8` but inputs and outputs should be 16-bit floating-point (`float16` or `bfloat16`).
This method aims for greater speed by keeping the core matrix multiplication in lower precision.
Here's the conceptual process:
- Take the `float16` input activation.
- Take the `int8` quantized weights and their `float16` scales.
- Cast the `int8` weights directly to `float16`. These are not dequantized yet; they are just the integer values represented as `float16`.
- Perform a `float16` matrix multiplication between the `float16` input and the casted `float16` weights.
- Multiply the result of the matrix multiplication by the `float16` scales. This effectively applies the scale after the multiplication (the short check after this list shows why this gives the same result).
- Add the `float16` bias (if present).
- Return the `float16` output activation.
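If you are wondering why applying the scale after the matrix multiplication is legitimate: with one scale per output row, `x @ (s * W).T` equals `(x @ W.T) * s`. The small check below verifies this algebra with made-up shapes, using `float32` just to keep the sketch simple:

```python
import torch

x = torch.randn(2, 8)                                      # input activation
w = torch.randint(-128, 127, (4, 8), dtype=torch.int8)     # int8 weights
s = torch.rand(4)                                          # one scale per output row

scale_before = x @ (s.view(-1, 1) * w.float()).T           # dequantize first, then matmul
scale_after = (x @ w.float().T) * s                        # matmul first, then scale
print(torch.allclose(scale_before, scale_after))           # True (up to float rounding)
```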
Why is this efficient? The critical matrix multiplication step happens using 16-bit numbers. Modern hardware, especially GPUs, has dedicated cores (like Tensor Cores) that can perform 16-bit matrix multiplications much faster than `float32` ones. The scale multiplication and bias addition are simpler element-wise operations that are also fast in 16-bit.
Let's look at the simplified code for `w8_a16_forward` from `utils.py`:
```python
# From utils.py
import torch
import torch.nn.functional as F  # Need F.linear


def w8_a16_forward(weight, input, scales, bias=None):
    """W8A16 forward pass implementation."""
    # Input is expected to be FP16 (or BF16)
    assert input.dtype == torch.float16 or input.dtype == torch.bfloat16
    # Quantized weights are INT8
    assert weight.dtype == torch.int8
    # Scales should match input dtype
    assert scales.dtype == input.dtype

    # Cast INT8 weights to the same FP16/BF16 dtype as the input
    # Note: This is *not* dequantization yet. It's just changing the data type.
    casted_weights = weight.to(input.dtype)

    # Perform the core linear operation (matrix multiplication) in FP16/BF16
    output = F.linear(input=input, weight=casted_weights)

    # Apply the scale *after* the matrix multiplication
    output = output * scales.unsqueeze(0)  # Unsqueeze scale for broadcasting

    # Add bias if it exists
    if bias is not None:
        output += bias.squeeze()  # Ensure bias is 1D or correct shape

    return output  # Output is FP16/BF16
```
Here, the `int8` weights are simply cast to the input's dtype (`float16` or `bfloat16`). The `F.linear` call operates on 16-bit data, and the scale is applied after this multiplication. This utilizes the 16-bit computation capabilities of the hardware.
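And here is a hypothetical call to `w8_a16_forward` with random tensors; `bfloat16` is used so the sketch also runs on CPU, but `float16` behaves the same way (typically on a GPU). The function is assumed to be imported from `utils.py` as shown above:

```python
import torch

# from utils import w8_a16_forward  # assuming the definition shown above

weight = torch.randint(-128, 127, (4, 8), dtype=torch.int8)  # stored int8 weights
scales = torch.rand(4, dtype=torch.bfloat16)                  # one scale per output row
x = torch.randn(2, 8, dtype=torch.bfloat16)                   # 16-bit input activation

out = w8_a16_forward(weight=weight, input=x, scales=scales)
print(out.shape, out.dtype)  # torch.Size([2, 4]) torch.bfloat16
```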
The beauty of the Custom Quantized Layers is that their `forward` methods are very simple. They just call the appropriate helper function, passing the data they hold along with the input they receive.
From `tinyq.py`:
```python
# Inside W8A32LinearLayer class
def forward(self, input):
    # W8A32 layer calls the W8A32 forward function
    return w8_a32_forward(input=input,
                          q_w=self.int8_weights,
                          s_w=self.scales,
                          z_w=0,  # Zero point is 0 for symmetric
                          bias=self.bias)


# Inside W8A16LinearLayer class
def forward(self, input):
    # W8A16 layer calls the W8A16 forward function
    return w8_a16_forward(weight=self.int8_weights,
                          input=input,         # Input is FP16/BF16
                          scales=self.scales,  # Scales are FP16/BF16
                          bias=self.bias)
```
When you run your quantized model on new input data, PyTorch calls the `forward()` method of each layer in sequence. When it reaches one of our custom layers, its `forward()` method calls the corresponding `wX_aY_forward` function, which then performs the efficient, quantized calculation using the stored `int8` weights and scales.
Here's a simple flow of how the quantized forward pass works for a single layer after the model has been quantized:
```mermaid
sequenceDiagram
    participant Input as Input Activation (FP32/FP16)
    participant CustomLayer as Custom Quantized Layer<br/>(W8A32LinearLayer / W8A16LinearLayer)
    participant QForwardFunc as Quantized Forward Function<br/>(w8_a32_forward / w8_a16_forward)
    participant Computation as Computation Logic
    participant Output as Output Activation (FP32/FP16)

    Input->CustomLayer: forward(input)
    CustomLayer->QForwardFunc: Call wX_aY_forward(<br/>input,<br/>int8_weights,<br/>scales,<br/>bias)
    QForwardFunc->Computation: Perform Calculation<br/>(Dequantize/Cast, MatMul, Scale, Bias) using:<br/>- input (FP32/FP16)<br/>- int8_weights<br/>- scales (FP32/FP16)<br/>- bias (FP32/FP16)
    Computation-->QForwardFunc: Return result
    QForwardFunc-->CustomLayer: Return result
    CustomLayer-->Output: Pass result to next layer
```
Here's a side-by-side comparison of the two forward pass functions:

| Feature | `w8_a32_forward` (W8A32) | `w8_a16_forward` (W8A16) |
|---|---|---|
| Input Dtype | `float32` | `float16` or `bfloat16` |
| Weight Dtype | `int8` (stored), `float32` (for math) | `int8` (stored), `float16`/`bfloat16` (for math) |
| Scale Dtype | `float32` | `float16` or `bfloat16` |
| Zero Point | 0 (used in calculation) | Implicitly 0 (not explicitly used in code) |
| Core Math Dtype | `float32` (matrix multiplication) | `float16` or `bfloat16` (matrix multiplication) |
| Scale Applied | Before matrix multiplication (part of dequantization) | After matrix multiplication |
| Benefit | Saves weight memory; uses optimized FP32 kernels | Saves weight & activation memory; uses fast FP16/BF16 kernels |
The choice between W8A32 and W8A16 forward passes depends on your hardware capabilities (whether it accelerates 16-bit math) and the desired trade-off between potential speedup and potential accuracy loss.
The Quantized Forward Pass Functions (`w8_a32_forward` and `w8_a16_forward`) are the critical piece that makes the quantized model runnable. They provide the specialized computation logic needed to perform matrix multiplications using `int8` weights and their corresponding scales, while handling input activations in `float32` or `float16`.
These functions are called internally by the `forward()` methods of the Custom Quantized Layers that replaced the original `nn.Linear` layers. This ensures that during inference, the model correctly utilizes the compressed weight format and leverages lower-precision computations where appropriate, leading to reduced memory usage and potentially faster execution.
Now that we've covered how the model is quantized and how it performs calculations, the final chapter will look at how to handle these quantized models, including saving, loading, and using them in practice.