283 changes: 283 additions & 0 deletions backends/vulkan/_passes/fuse_quantized_ops.py
@@ -210,19 +210,302 @@ def fuse_into_linear_qcnw_node(
graph_module.graph.erase_node(dq_weight_node)


#########################
## linear_qta8a_qga4w ##
#########################


def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
"""Check if a node is a dequantize_affine operation."""
return (
node.op == "call_function"
and node.target is not None
and hasattr(node.target, "__name__")
and "dequantize_affine" in getattr(node.target, "__name__", "")
)


def _is_view_copy_node(node: torch.fx.Node) -> bool:
"""Check if a node is a view_copy operation."""
return (
node.op == "call_function"
and node.target is not None
and hasattr(node.target, "__name__")
and "view_copy" in getattr(node.target, "__name__", "")
)


def _validate_qta8a_qga4w_nodes(
input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument
) -> Optional[torch.fx.Node]:
"""
Validate input and weight nodes for QTA8A_QGA4W pattern.
Returns the actual input node (after handling view operations) or None if invalid.
"""
# Type checking - ensure we have torch.fx.Node objects
if not isinstance(weight_node, torch.fx.Node) or not isinstance(
input_node, torch.fx.Node
):
return None

# Input may be preprocessed with a view node
actual_input_node = input_node
if _is_view_copy_node(input_node):
actual_input_node = input_node.args[0]
if not isinstance(actual_input_node, torch.fx.Node):
return None

# Check if input is dequantized with dequantize_affine (from dynamic quantization)
if not _is_dequantize_affine_node(actual_input_node):
return None

# Check if weight is dequantized with dequantize_affine
if not _is_dequantize_affine_node(weight_node):
return None

return actual_input_node


def _extract_weight_params(
program: ExportedProgram, weight_node: torch.fx.Node
) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
"""Extract and validate weight parameters from dequantize_affine node."""
# Get the original quantized weight and quantization parameters
if len(weight_node.args) < 4:
return None

orig_weight = weight_node.args[0]
weight_scales = weight_node.args[2]
weight_zeros = weight_node.args[3]

# Type checking
if not isinstance(orig_weight, torch.fx.Node) or not is_param_node(
program, orig_weight
):
return None
if not isinstance(weight_scales, torch.fx.Node) or not is_param_node(
program, weight_scales
):
return None
if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node(
program, weight_zeros
):
return None

return orig_weight, weight_scales, weight_zeros


def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool:
"""Check if weight tensor is quantized to 4 bits (values in [-8, 7] range)."""
quant_min = weight_tensor.min().item()
quant_max = weight_tensor.max().item()
return quant_min >= -8 and quant_max <= 7


def _calculate_group_size(
orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
) -> Optional[int]:
"""Calculate and validate group size from weight and scales tensors."""
out_features, in_features = orig_weight_tensor.shape

if len(weight_scales_tensor.shape) != 2:
return None

scales_out_features, num_groups = weight_scales_tensor.shape

if scales_out_features != out_features:
return None

if num_groups == 0 or in_features % num_groups != 0:
    return None

group_size = in_features // num_groups

return group_size


def matches_linear_qta8a_qga4w_pattern(
program: ExportedProgram, node: torch.fx.Node
) -> Optional[Tuple[int, int]]:
"""
Checks if the nodes surrounding a linear node match the pattern for dynamic
activation + grouped weight quantized linear (QTA8A_QGA4W).
This pattern involves:
1. Dynamic quantization of input activations (8-bit)
2. Grouped quantization of weights (4-bit with group size)
The expected pattern from Int8DynActInt4WeightQuantizer is:
scale, zero_point = choose_qparams_affine(input)
quantized_input = quantize_affine(input, scale, zero_point)
dequantized_input = dequantize_affine(quantized_input, ...)
dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
output = linear(dequantized_input, dequantized_weight)
If the pattern matches, return (group_size, weight_bits), otherwise None.
"""
if not utils.is_linear_node(node):
return None

input_node = node.args[0]
weight_node = node.args[1]

# Validate nodes and get actual input node
actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node)
if actual_input_node is None:
return None

# Extract weight parameters
if not isinstance(weight_node, torch.fx.Node):
return None
weight_params = _extract_weight_params(program, weight_node)
if weight_params is None:
return None

orig_weight, weight_scales, weight_zeros = weight_params

# Get tensors to analyze the quantization scheme
orig_weight_tensor = get_param_tensor(program, orig_weight)
weight_scales_tensor = get_param_tensor(program, weight_scales)
weight_zeros_tensor = get_param_tensor(program, weight_zeros)

if not isinstance(orig_weight_tensor, torch.Tensor):
return None
if not isinstance(weight_scales_tensor, torch.Tensor):
return None
if not isinstance(weight_zeros_tensor, torch.Tensor):
return None

# Check if weight is quantized to 4 bits
if not _validate_4bit_quantization(orig_weight_tensor):
return None

# Calculate group size
group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
if group_size is None:
return None

# 4-bit grouped quantization was already validated above; record the bit width
weight_bits = 4

return group_size, weight_bits
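For reference, a minimal worked example of the group-size derivation above, relying on the module's existing imports and using hypothetical shapes (a 64x128 4-bit weight with per-group scales of shape 64x32, so each group spans 128 // 32 = 4 input features):

# Hypothetical shapes only: 64 output features, 128 input features,
# scales laid out as (out_features, num_groups) = (64, 32).
example_weight = torch.randint(-8, 8, (64, 128), dtype=torch.int8)
example_scales = torch.rand(64, 32)

assert _validate_4bit_quantization(example_weight)
assert _calculate_group_size(example_weight, example_scales) == 4  # 128 // 32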


def fuse_into_linear_qta8a_qga4w_node(
program: ExportedProgram,
graph_module: torch.fx.GraphModule,
linear_node: torch.fx.Node,
group_size: int,
weight_bits: int,
) -> None:
"""
Fuse the dynamic activation + grouped weight quantized linear pattern into
a single linear_qta8a_qga4w operator.
The pattern:
dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
output = linear(dequantized_input, dequantized_weight)
Becomes:
output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
weight, group_size, weight_scales, weight_zeros)
"""
dq_input_node = linear_node.args[0]
dq_weight_node = linear_node.args[1]

assert isinstance(dq_input_node, torch.fx.Node)

input_view_node = None
# Input may be preprocessed with a view node
if (
dq_input_node.op == "call_function"
and dq_input_node.target is not None
and hasattr(dq_input_node.target, "__name__")
and "view_copy" in getattr(dq_input_node.target, "__name__", "")
):
input_view_node = dq_input_node
dq_input_node = dq_input_node.args[0]
assert isinstance(dq_input_node, torch.fx.Node)

assert isinstance(dq_input_node, torch.fx.Node)
assert isinstance(dq_weight_node, torch.fx.Node)

# Get the quantized input and quantization parameters from the input dequantize_affine node
# Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
quantized_input = dq_input_node.args[0]
input_scale = dq_input_node.args[2] # scale is the 3rd argument
input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None

# Get the weight and its quantization parameters from dequantize_affine
# Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
orig_weight = dq_weight_node.args[0]
weight_scales = dq_weight_node.args[2]
weight_zeros = dq_weight_node.args[3]

# Pack the 4-bit weight tensor for efficient storage
assert isinstance(orig_weight, torch.fx.Node)
orig_weight_tensor = get_param_tensor(program, orig_weight)
assert isinstance(orig_weight_tensor, torch.Tensor)
packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
utils.update_program_state_dict(
program,
orig_weight.name,
packed_weight_tensor,
)
# Update the metadata to reflect the new packed shape
orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)

# Create the linear_qta8a_qga4w node
with graph_module.graph.inserting_before(linear_node):
linear_qta8a_qga4w_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
(
quantized_input, # quantized input (int8)
input_scale, # mat1_scale
input_zero_point, # mat1_zero_point
orig_weight, # mat2_data (packed 4-bit weights)
group_size, # group_size (int)
weight_scales, # weight_scales
weight_zeros, # weight_zeros
),
)

# Replace the linear node with the new fused node
linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)

# Erase nodes in the correct order (users first, then dependencies)
graph_module.graph.erase_node(linear_node)
if input_view_node is not None:
graph_module.graph.erase_node(input_view_node)
graph_module.graph.erase_node(dq_weight_node)
graph_module.graph.erase_node(dq_input_node)
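The packing helper pack_4bit_weight_tensor is defined elsewhere in the backend. As an illustration only, a layout consistent with the unpacking logic in the reference implementation in custom_ops_lib.py (high nibble from even columns, low nibble from odd columns) could look like the sketch below; the _sketch name is hypothetical and the assumption is that weights are signed int4 values in [-8, 7].

def pack_4bit_weight_tensor_sketch(w: torch.Tensor) -> torch.Tensor:
    # Assumes w is int8 with values in [-8, 7] and an even number of columns.
    w_u4 = (w & 0x0F).to(torch.uint8)  # two's-complement low nibble of each value
    # Column 2k goes into the high nibble, column 2k + 1 into the low nibble,
    # halving the second dimension to (out_features, in_features // 2).
    return (w_u4[:, ::2] << 4) | w_u4[:, 1::2]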


class FuseQuantizedOpsTransform(ExportPass):
def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.program = exported_program

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
# Check for linear_qcnw pattern (weight-only quantization)
qcnw_details = matches_linear_qcnw_pattern(self.program, node)
if qcnw_details is not None:
qcnw_method, qcnw_nbits = qcnw_details
fuse_into_linear_qcnw_node(
self.program, graph_module, node, qcnw_method, qcnw_nbits
)
continue

# Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
if qta8a_qga4w_details is not None:
group_size, weight_bits = qta8a_qga4w_details
fuse_into_linear_qta8a_qga4w_node(
self.program, graph_module, node, group_size, weight_bits
)
continue

graph_module.recompile()
dead_code_elimination_pass(graph_module)
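As a rough usage sketch (the edge_program name and surrounding setup are assumptions, not part of this diff), the transform can be invoked like any other ExportPass, returning a PassResult whose graph module contains the fused et_vk nodes:

# Hypothetical usage, assuming `edge_program` is an ExportedProgram whose graph
# was produced with Int8DynActInt4WeightQuantizer-style quantization.
fuse_pass = FuseQuantizedOpsTransform(edge_program)
result = fuse_pass(edge_program.graph_module)  # ExportPass instances are callable
fused_graph_module = result.graph_module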
89 changes: 89 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -231,6 +231,95 @@ def linear_qcs4w(
lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

########################
## linear_qta8a_qga4w ##
########################


def linear_qta8a_qga4w(
x_quantized: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weights_4bit: torch.Tensor,
group_size: int,
weight_scales: torch.Tensor,
weight_zeros: torch.Tensor,
):
"""
Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).
Args:
x_quantized: Already quantized input tensor (int8, per-token quantized)
input_scale: Scale for per-token quantization of input (shape: [batch_size])
input_zero_point: Zero point for per-token quantization of input (shape: [batch_size])
weights_4bit: Packed 4-bit quantized weights
group_size: Group size for weight quantization (int)
weight_scales: Per-group scales for weights
weight_zeros: Per-group zero points for weights
"""
original_x_shape = x_quantized.shape
feature_dim = original_x_shape[-1]

# Reshape for processing
x_quantized_2d = x_quantized.reshape(-1, feature_dim)

# Unpack 4-bit weights
unpacked_weights_shape = weights_4bit.shape
out_features = unpacked_weights_shape[0]
in_features = unpacked_weights_shape[1]

weights_unpacked = torch.empty(
(out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device
)

weights_unpacked[:, ::2] = weights_4bit >> 4
weights_unpacked[:, 1::2] = weights_4bit & 0x0F

# Convert to signed 4-bit range [-8, 7]
weights_unpacked = torch.where(
weights_unpacked > 7, weights_unpacked - 16, weights_unpacked
)

# Dequantize weights using grouped quantization
actual_in_features = in_features * 2
num_groups = actual_in_features // group_size

# Reshape weights for grouped dequantization
weights_grouped = weights_unpacked.view(out_features, num_groups, group_size)

# Expand scales and zeros to match grouped weights
scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size)
zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size)

# Dequantize: (quantized - zero_point) * scale
dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded
dq_weights = dq_weights_grouped.view(out_features, actual_in_features)

# Dequantize input (per-token)
# For per-token quantization, each token (row) has its own scale and zero_point
x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token(
x_quantized_2d,
input_scale,
input_zero_point,
-128,
127,
torch.int8,
torch.float32,
)

# Perform linear operation
out = torch.nn.functional.linear(x_dequantized, dq_weights)
out_shape = original_x_shape[:-1] + (out_features,)
return out.reshape(out_shape)


name = "linear_qta8a_qga4w"
lib.define(
f"{name}(Tensor self, Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor"
)
lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)
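A small usage example calling the Python reference implementation above directly. The shapes and the inline packing expression are assumptions chosen to mirror the op's own unpacking logic, and the per-token input quantization uses standard quantized_decomposed ops:

# Hypothetical shapes: 2 tokens, 8 input features, 16 output features, group_size = 4.
x = torch.randn(2, 8)
in_scale, in_zp = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
    x, torch.int8
)
x_q = torch.ops.quantized_decomposed.quantize_per_token(
    x, in_scale, in_zp, -128, 127, torch.int8
)

w = torch.randint(-8, 8, (16, 8), dtype=torch.int8)
# Pack two signed 4-bit values per byte (even column -> high nibble).
w_packed = ((w[:, ::2] & 0x0F).to(torch.uint8) << 4) | (w[:, 1::2] & 0x0F).to(torch.uint8)
w_scales = torch.rand(16, 2)   # (out_features, num_groups); 8 // 4 = 2 groups
w_zeros = torch.zeros(16, 2)

out = linear_qta8a_qga4w(x_q, in_scale, in_zp, w_packed, 4, w_scales, w_zeros)
assert out.shape == (2, 16)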

######################
## apply_rotary_emb ##
######################