Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 89 additions & 63 deletions backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ $if MODE == "per_channel":
int quant_min;
int quant_max;
};
$if MODE == "block_wise":
// Per-block quantization parameters. dequantize_block_wise() reads one scale and one
// zero_point per element, both indexed by that element's linear block id.
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

// Block-grid description used by dequantize_block_wise() to map a tensor index to a block id:
//   block coordinate = tensor index / blockSize          (component-wise)
//   block id         = sum of (block coordinate * blockStride) over the four components
layout(push_constant) uniform restrict Block {
ivec4 blockSize; // bW, bH, bC, bN
ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN
ivec4 blockStride; // pre-computed linear strides for the block grid
// NOTE(review): quant_min / quant_max are not referenced by dequantize_block_wise() as
// visible here; presumably kept so the push-constant layout matches the host side — confirm.
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "int", "out_numel")}
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
Expand All @@ -71,68 +82,60 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
const lowp ivec4 in_dim_order = unhash_dim_order(in_layout);

/*
* DEQUANTIZATION SHADER (BUFFER STORAGE)
*
* This shader converts n-bit integer tensor values back to floating-point representations
* using pre-computed quantization parameters (scale and zero_point). The dequantization
* reconstructs the original floating-point values from their discrete integer representations
* with minimal precision loss.
*
* ALGORITHM:
* 1. Load quantized integer value from buffer
* 2. Apply dequantization formula: value = (qvalue - zero_point) * scale
* 3. Store reconstructed floating-point value to output buffer
*
* WORKGROUP CONFIGURATION:
* - Per-Tensor Mode:
* - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
* - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
* - Per-Token Mode:
* - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
* - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
*
* SUPPORTED CONFIGURATIONS:
* - Buffer Storage: Uses linear buffer indexing with stride-based tensor access
* - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering
* - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping
* - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0)
*
* DEQUANTIZATION FORMULA VISUALIZATION:
* For integer range [quant_min, quant_max] mapped back to [min_val, max_val]:
*
* Integer Domain: Floating Point Domain:
* quant_min ──────────────► min_val
* │ │
* │ scale = (max_val - min_val) / (quant_max - quant_min)
* │ zero_point = quant_min - round(min_val / scale)
* │ │
* quant_max ──────────────► max_val
*
* Dequantization Process:
* Input: -103 (int8)
* Step 1: qvalue - zero_point = -103 - (-128) = 25
* Step 2: result * scale = 25 * 0.1 = 2.5
* Output: 2.5 (float)
*
* PER-TENSOR DEQUANTIZATION:
* - Single scale and zero_point values for entire tensor
* - All elements use same dequantization parameters
* - Parameters passed as push constants for efficiency
* - Formula: value = (qvalue - zero_point) * scale
*
* PER-TOKEN DEQUANTIZATION:
* - Separate scale and zero_point for each token
* - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements)
* - Parameters stored in buffer arrays indexed by token_id
* - Each thread calculates its token_id from tensor coordinates
* - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id]
*
* Token ID calculation for element at tensor index (w, z, y, x):
* - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y
* - 3D tensor: token_id = z * sizes.y + y
* - 2D tensor: token_id = y
* - 1D tensor: token_id = 0
*/
Dequantization Shader (Buffer Storage)
This shader converts n-bit integer tensor values back to floating-point representations
using pre-computed quantization parameters (scale and zero_point). The dequantization
reconstructs the original floating-point values from their discrete integer representations
with minimal precision loss.
Important Considerations:
(+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension)
(+) The axis map layout is assumed to be a standard layout for scales and zero_points
(++) The scale and zero_point tensors must be implemented as buffers
Workgroup Configuration:
- dequantize_per_tensor
This mode reverses the uniform quantization applied across the entire tensor by using the
single scale and zero_point values to convert quantized integer values back to their original
floating-point representation.
(*) global_wg_size: default
(*) local_wg_size: default
- dequantize_per_token
This mode reverses the quantization applied individually to each token in the input by using
separate scale and zero_point values for each token. A token is the group of elements that
share every index except the last dimension: for a tensor of shape [B, S, H] there are B*S
tokens of H elements each, and each token is converted back to its floating-point
representation independently, using its own scale and zero_point.
(*) global_wg_size: default
(*) local_wg_size: default
- dequantize_per_channel
This mode reverses the quantization applied separately to each channel of the input tensor
by using distinct scale and zero_point values for each channel. For a tensor of shape
[B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C
channels, converting quantized values back to their original floating-point representation
independently for each channel.
(*) global_wg_size: default
(*) local_wg_size: default
- dequantize_block_wise
This mode reverses the block-wise quantization applied to groups of elements by using separate
scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the
inverse affine transformation per block to convert quantized values back to their original
floating-point representation. For example, if the tensor shape is [6, 9, 4] and
blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements,
and dequantization is performed independently on each block.
(*) global_wg_size: default
(*) local_wg_size: default
Dequantization Formula:
value = (qvalue - zero_point) * scale
*/

#ifdef per_tensor

Expand Down Expand Up @@ -187,7 +190,7 @@ void dequantize_per_token() {
t_out[out_bufi] = value;
}

#else // per_channel
#elif defined(per_channel)

void dequantize_per_channel() {
const int out_bufi = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -226,6 +229,29 @@ void dequantize_per_channel() {
t_out[out_bufi] = value;
}

#else // block_wise

/*
 * Block-wise dequantization for buffer-backed tensors.
 *
 * Each invocation handles the single output element whose buffer offset equals
 * gl_GlobalInvocationID.x. The element's tensor index is divided component-wise by
 * blockSize to obtain its block coordinate, which is then linearized with blockStride
 * to select the scale / zero_point pair passed to dequantize_val().
 */
void dequantize_block_wise() {
  const int elem_bufi = int(gl_GlobalInvocationID.x);

  // Invocations past the end of the output buffer have nothing to do.
  if (elem_bufi >= out_numel) {
    return;
  }

  // Recover the tensor index from the output offset, then read the quantized value
  // stored at the same tensor index in the input buffer.
  const ivec4 tidx = bufi_to_tidx(elem_bufi, t_out_strides, out_dim_order);
  IN_T qvalue = t_in[tidx_to_bufi(tidx, t_in_strides)];

  // Per-dimension contribution of the block coordinate to the linear block id.
  const ivec4 block_terms = (tidx / blockSize) * blockStride;
  const int block_id = block_terms.x + block_terms.y + block_terms.z + block_terms.w;

  const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]);

  t_out[elem_bufi] = value;
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ dequantize_buffer:
MODE: per_token
- NAME: dequantize_per_channel_buffer
MODE: per_channel
- NAME: dequantize_block_wise_buffer
MODE: block_wise
46 changes: 45 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ $if MODE == "per_channel":
int quant_min;
int quant_max;
};
$if MODE == "block_wise":
// Per-block quantization parameters. dequantize_block_wise() reads one scale and one
// zero_point per texel lane, both indexed by that lane's linear block id.
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

// Block-grid description used by dequantize_block_wise():
//   block coordinate = tensor index / blockSize          (component-wise)
//   block id         = sum of (block coordinate * blockStride) over the four components
// numBlocks.z * blockSize.z is also used there as the channel count when splitting the
// texture z coordinate into channel and batch indices.
layout(push_constant) uniform restrict Block {
ivec4 blockSize; // bW, bH, bC, bN
ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN
ivec4 blockStride; // pre-computed linear strides for the block grid
// NOTE(review): quant_min / quant_max are not referenced by dequantize_block_wise() as
// visible here; presumably kept so the push-constant layout matches the host side — confirm.
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "ivec3", "t_in_limits")}
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
Expand Down Expand Up @@ -201,7 +212,7 @@ void dequantize_per_token() {
write_texel(t_out, pos, outtex);
}

#else // per_channel
#elif defined(per_channel)

void dequantize_per_channel() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
Expand Down Expand Up @@ -292,6 +303,39 @@ void dequantize_per_channel() {
write_texel(t_out, pos, outtex);
}

#else // block_wise

/*
 * Block-wise dequantization for texture-backed tensors.
 *
 * Each invocation handles one texel at gl_GlobalInvocationID. The four lanes of the texel
 * are treated as four consecutive elements along the width (x = pos.x * 4 + i). The
 * texture z coordinate is split into channel (z % C_total) and batch (z / C_total), where
 * C_total = numBlocks.z * blockSize.z. For every lane the tensor index is divided
 * component-wise by blockSize to obtain the block coordinate, which is linearized with
 * blockStride to select the scale / zero_point pair passed to dequantize_val().
 *
 * The block coordinate is clamped to the last block of the grid. For lanes that map to
 * real tensor elements the clamp is a no-op (assuming numBlocks * blockSize equals the
 * tensor extent, the same relation C_total relies on). It only matters for padding lanes
 * of the last texel in a row when the width is not a multiple of 4: without it those lanes
 * would compute a block id outside the block grid and could read past the end of t_scale /
 * t_zero_point.
 */
void dequantize_block_wise() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations outside the input texture extents have nothing to do.
  if (any(greaterThanEqual(pos, t_in_limits))) {
    return;
  }

  IVEC4_T intex = load_texel(t_in, pos);
  FVEC4_T outtex;

  // Channel count, used to unfold the texture z coordinate into (channel, batch).
  const int C_total = numBlocks.z * blockSize.z;

  // Largest valid block coordinate along each dimension.
  const ivec4 max_bcoord = numBlocks - 1;

  [[unroll]] for (int i = 0; i < 4; ++i) {
    const ivec4 tidx = ivec4(pos.x * 4 + i, pos.y, pos.z % C_total, pos.z / C_total);

    // Clamp so that padding lanes reuse the last block's parameters instead of indexing
    // outside the block grid.
    const ivec4 bcoord = min(tidx / blockSize, max_bcoord);
    const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w;

    IN_T qvalue = IN_T(intex[i]);
    OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]);
    $if OUT_DTYPE == "double":
      outtex[i] = float(value);
    $else:
      outtex[i] = value;
  }

  write_texel(t_out, pos, outtex);
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ dequantize_texture:
MODE: per_token
- NAME: dequantize_per_channel_texture3d
MODE: per_channel
- NAME: dequantize_block_wise_texture3d
MODE: block_wise
Loading
Loading