205 changes: 205 additions & 0 deletions backends/vulkan/_passes/fuse_quantized_ops.py
@@ -210,19 +210,224 @@ def fuse_into_linear_qcnw_node(
graph_module.graph.erase_node(dq_weight_node)


#########################
## linear_qta8a_qga4w ##
#########################


def matches_linear_qta8a_qga4w_pattern(
program: ExportedProgram, node: torch.fx.Node
) -> Optional[Tuple[int, int]]:
"""
Checks whether the nodes surrounding a linear node match the pattern for dynamic
activation + grouped weight quantized linear (QTA8A_QGA4W).

This pattern involves:
1. Dynamic quantization of input activations (8-bit)
2. Grouped quantization of weights (4-bit with group size)

The expected pattern from Int8DynActInt4WeightQuantizer is:
scale, zero_point = choose_qparams_affine(input)
quantized_input = quantize_affine(input, scale, zero_point)
dequantized_input = dequantize_affine(quantized_input, ...)
dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
output = linear(dequantized_input, dequantized_weight)

If the pattern matches, returns (group_size, weight_bits); otherwise returns None.
"""
if not utils.is_linear_node(node):
return None

input_node = node.args[0]
weight_node = node.args[1]

# Type checking - ensure we have torch.fx.Node objects
if not isinstance(weight_node, torch.fx.Node):
return None
if not isinstance(input_node, torch.fx.Node):
return None

# Check if input is dequantized with dequantize_affine (from dynamic quantization)
if not (
input_node.op == "call_function"
and input_node.target is not None
and hasattr(input_node.target, "__name__")
and "dequantize_affine" in getattr(input_node.target, "__name__", "")
):
return None

# Check if weight is dequantized with dequantize_affine
if not (
weight_node.op == "call_function"
and weight_node.target is not None
and hasattr(weight_node.target, "__name__")
and "dequantize_affine" in getattr(weight_node.target, "__name__", "")
):
return None

# Get the original quantized weight and quantization parameters
if len(weight_node.args) < 4:
return None

orig_weight = weight_node.args[0]
weight_scales = weight_node.args[2]
weight_zeros = weight_node.args[3]

# Type checking
if not isinstance(orig_weight, torch.fx.Node):
return None
if not is_param_node(program, orig_weight):
return None
if not isinstance(weight_scales, torch.fx.Node):
return None
if not is_param_node(program, weight_scales):
return None
if not isinstance(weight_zeros, torch.fx.Node):
return None
if not is_param_node(program, weight_zeros):
return None

# Get tensors to analyze the quantization scheme
orig_weight_tensor = get_param_tensor(program, orig_weight)
weight_scales_tensor = get_param_tensor(program, weight_scales)
weight_zeros_tensor = get_param_tensor(program, weight_zeros)

if not isinstance(orig_weight_tensor, torch.Tensor):
return None
if not isinstance(weight_scales_tensor, torch.Tensor):
return None
if not isinstance(weight_zeros_tensor, torch.Tensor):
return None

# Check if weight is quantized to 4 bits (values should be in [-8, 7] range)
quant_min = orig_weight_tensor.min().item()
quant_max = orig_weight_tensor.max().item()

if not (quant_min >= -8 and quant_max <= 7):
return None

# Determine group size from the scales tensor shape
# For grouped quantization, scales shape should be [out_features, in_features // group_size]
out_features, in_features = orig_weight_tensor.shape

if len(weight_scales_tensor.shape) != 2:
return None

scales_out_features, num_groups = weight_scales_tensor.shape

if scales_out_features != out_features:
return None

if num_groups == 0 or in_features % num_groups != 0:
return None
group_size = in_features // num_groups

# The [-8, 7] range check above confirms 4-bit quantization
weight_bits = 4

return group_size, weight_bits
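
# Illustrative sketch (not part of this diff): the shape arithmetic the matcher
# above relies on, using made-up shapes. A [32, 64] 4-bit weight quantized with
# group_size=16 carries a [32, 4] scales tensor, so the group size is recovered
# as in_features // num_groups.
def _example_qga4w_shape_arithmetic() -> None:
    out_features, in_features, group_size = 32, 64, 16
    weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
    scales = torch.rand(out_features, in_features // group_size)  # shape [32, 4]
    num_groups = scales.shape[1]  # 4
    assert in_features // num_groups == group_size  # 64 // 4 == 16
    # Values drawn from [-8, 7] pass the 4-bit range check above
    assert weight.min().item() >= -8 and weight.max().item() <= 7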


def fuse_into_linear_qta8a_qga4w_node(
program: ExportedProgram,
graph_module: torch.fx.GraphModule,
linear_node: torch.fx.Node,
group_size: int,
weight_bits: int,
) -> None:
"""
Fuse the dynamic activation + grouped weight quantized linear pattern into
a single linear_qta8a_qga4w operator.

The pattern:
dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
output = linear(dequantized_input, dequantized_weight)

Becomes:
output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
weight, group_size, weight_scales, weight_zeros)
"""
dq_input_node = linear_node.args[0]
dq_weight_node = linear_node.args[1]

assert isinstance(dq_input_node, torch.fx.Node)
assert isinstance(dq_weight_node, torch.fx.Node)

# Get the quantized input and quantization parameters from the input dequantize_affine node
# Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
quantized_input = dq_input_node.args[0]
input_scale = dq_input_node.args[2] # scale is the 3rd argument
input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None

# Get the weight and its quantization parameters from dequantize_affine
# Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
orig_weight = dq_weight_node.args[0]
weight_scales = dq_weight_node.args[2]
weight_zeros = dq_weight_node.args[3]

# Pack the 4-bit weight tensor for efficient storage
assert isinstance(orig_weight, torch.fx.Node)
orig_weight_tensor = get_param_tensor(program, orig_weight)
assert isinstance(orig_weight_tensor, torch.Tensor)
packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
utils.update_program_state_dict(
program,
orig_weight.name,
packed_weight_tensor,
)
# Update the metadata to reflect the new packed shape
orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)

# Create the linear_qta8a_qga4w node
with graph_module.graph.inserting_before(linear_node):
linear_qta8a_qga4w_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
(
quantized_input, # quantized input (int8)
input_scale, # mat1_scale
input_zero_point, # mat1_zero_point
orig_weight, # mat2_data (packed 4-bit weights)
group_size, # group_size (int)
weight_scales, # weight_scales
weight_zeros, # weight_zeros
),
)

# Replace the linear node with the new fused node
linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)

# Erase nodes in the correct order (users first, then dependencies)
graph_module.graph.erase_node(linear_node)
graph_module.graph.erase_node(dq_weight_node)
graph_module.graph.erase_node(dq_input_node)
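
# Illustrative sketch (an assumption, not part of this diff): a minimal packing
# routine consistent with the [:, ::2] metadata update above and with the
# unpacking in custom_ops_lib.py (even columns land in the high nibble, odd
# columns in the low nibble). The actual pack_4bit_weight_tensor helper is
# defined elsewhere in the backend.
def _example_pack_4bit_weight_tensor(w: torch.Tensor) -> torch.Tensor:
    # w: int8 weights in [-8, 7], shape [out_features, in_features], even in_features
    assert w.dtype == torch.int8 and w.shape[-1] % 2 == 0
    nibbles = w.to(torch.uint8) & 0x0F  # two's-complement nibble, e.g. -8 -> 0x8
    high = nibbles[:, ::2]  # even columns -> high nibble
    low = nibbles[:, 1::2]  # odd columns -> low nibble
    return (high << 4) | low  # shape [out_features, in_features // 2], dtype uint8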


class FuseQuantizedOpsTransform(ExportPass):
def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.program = exported_program

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
# Check for linear_qcnw pattern (weight-only quantization)
qcnw_details = matches_linear_qcnw_pattern(self.program, node)
if qcnw_details is not None:
qcnw_method, qcnw_nbits = qcnw_details
fuse_into_linear_qcnw_node(
self.program, graph_module, node, qcnw_method, qcnw_nbits
)
continue

# Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
if qta8a_qga4w_details is not None:
group_size, weight_bits = qta8a_qga4w_details
fuse_into_linear_qta8a_qga4w_node(
self.program, graph_module, node, group_size, weight_bits
)
continue

graph_module.recompile()
dead_code_elimination_pass(graph_module)
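
# Illustrative sketch (an assumption, not shown in this diff): how a transform
# like this is typically applied to an exported program's graph module. The
# exact pipeline wiring lives elsewhere in the Vulkan backend.
def _example_apply_fuse_pass(exported_program: ExportedProgram) -> torch.fx.GraphModule:
    gm = exported_program.graph_module
    result = FuseQuantizedOpsTransform(exported_program)(gm)  # passes are callable
    return result.graph_module if result is not None else gm
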
90 changes: 90 additions & 0 deletions backends/vulkan/custom_ops_lib.py
@@ -231,6 +231,96 @@ def linear_qcs4w(
lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

########################
## linear_qta8a_qga4w ##
########################


def linear_qta8a_qga4w(
x_quantized: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weights_4bit: torch.Tensor,
group_size: int,
weight_scales: torch.Tensor,
weight_zeros: torch.Tensor,
):
"""
Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).

Args:
x_quantized: Already quantized input tensor (int8, per-token quantized)
input_scale: Per-token quantization scales for the input (one scale per token/row)
input_zero_point: Per-token quantization zero points for the input (one per token/row)
weights_4bit: Packed 4-bit quantized weights
group_size: Group size for weight quantization (int)
weight_scales: Per-group scales for weights
weight_zeros: Per-group zero points for weights
"""
original_x_shape = x_quantized.shape
batch_size = original_x_shape[0]
feature_dim = original_x_shape[-1]

# Reshape for processing
x_quantized_2d = x_quantized.reshape(-1, feature_dim)

# Unpack 4-bit weights (each packed uint8 byte holds two 4-bit values)
packed_weights_shape = weights_4bit.shape
out_features = packed_weights_shape[0]
in_features = packed_weights_shape[1]

weights_unpacked = torch.empty(
(out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device
)

weights_unpacked[:, ::2] = weights_4bit >> 4
weights_unpacked[:, 1::2] = weights_4bit & 0x0F

# Convert to signed 4-bit range [-8, 7]
weights_unpacked = torch.where(
weights_unpacked > 7, weights_unpacked - 16, weights_unpacked
)

# Dequantize weights using grouped quantization
actual_in_features = in_features * 2
num_groups = actual_in_features // group_size

# Reshape weights for grouped dequantization
weights_grouped = weights_unpacked.view(out_features, num_groups, group_size)

# Expand scales and zeros to match grouped weights
scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size)
zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size)

# Dequantize: (quantized - zero_point) * scale
dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded
dq_weights = dq_weights_grouped.view(out_features, actual_in_features)

# Dequantize input (per-token)
# For per-token quantization, each token (row) has its own scale and zero_point
x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token(
x_quantized_2d,
input_scale,
input_zero_point,
-128,
127,
torch.int8,
torch.float32,
)

# Perform linear operation
out = torch.nn.functional.linear(x_dequantized, dq_weights)
out_shape = original_x_shape[:-1] + (out_features,)
return out.reshape(out_shape)


name = "linear_qta8a_qga4w"
lib.define(
f"{name}(Tensor self, Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor"
)
lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)
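
# Illustrative usage sketch (hypothetical shapes and values, not part of this
# diff): calling the composite reference op defined above with a per-token
# quantized activation and weights packed high-nibble-first, matching the
# unpacking logic in linear_qta8a_qga4w. Assumes the quantized_decomposed ops
# are registered, as they are in the ExecuTorch export flow.
def _example_linear_qta8a_qga4w() -> torch.Tensor:
    out_features, in_features, group_size = 32, 64, 16
    xq = torch.randint(-128, 128, (2, in_features), dtype=torch.int8)  # 2 tokens
    x_scale = torch.full((2, 1), 0.02)  # one scale per token (row)
    x_zp = torch.zeros(2, 1, dtype=torch.int64)  # one zero point per token
    wq = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
    nibbles = wq.to(torch.uint8) & 0x0F
    packed = (nibbles[:, ::2] << 4) | nibbles[:, 1::2]  # [32, 32] packed uint8
    w_scales = torch.rand(out_features, in_features // group_size)  # [32, 4]
    w_zeros = torch.zeros(out_features, in_features // group_size)  # [32, 4]
    return linear_qta8a_qga4w_op(
        xq, x_scale, x_zp, packed, group_size, w_scales, w_zeros
    )  # -> float32 tensor of shape [2, 32]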

######################
## apply_rotary_emb ##
######################
7 changes: 6 additions & 1 deletion backends/vulkan/op_registry.py
@@ -487,7 +487,12 @@ def register_int8_mm_op(features: OpFeatures):
return features


@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
@update_features(
[
exir_ops.edge.et_vk.linear_weight_int4.default,
exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
]
)
def register_int4_mm_op(features: OpFeatures):
features.buffer_impl = True
features.texture_impl = TextureImplFeatures(
1 change: 1 addition & 0 deletions backends/vulkan/test/TARGETS
@@ -34,6 +34,7 @@ python_unittest(
"//executorch/backends/vulkan/_passes:vulkan_passes",
"//executorch/backends/vulkan/quantizer:vulkan_quantizer",
"//executorch/backends/vulkan:vulkan_preprocess",
"//pytorch/ao:torchao", # @manual
]
)
