Skip to content

Commit 5999f10

Browse files
authored
[ET-VK][ez] Add support for buffer backed qparams in int4 linear + add checks for physical limits when allocating (#10233)
## Context Currently, the groupwise quantized int4 linear op implementation forces the scales and zeros tensor to be a `Texture3D`. However, for some models, e.g. transformer models that have a logit linear layer, the required image extents may exceed the maximum image extents available on the device. ## Changes * Add support for the scales and zeros tensor being a `Buffer` instead of a `Texture3D` * When allocating buffers or images for tensors, add checks that the requested resource fits within the physical device limits Differential Revision: [D72662176](https://our.internmc.facebook.com/intern/diff/D72662176/)
1 parent 8fb9209 commit 5999f10

File tree

7 files changed

+66
-29
lines changed

7 files changed

+66
-29
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,26 @@ vkapi::VulkanImage allocate_image(
260260
return vkapi::VulkanImage();
261261
}
262262

263+
// TODO(ssjia): change to always check that the image extents do not exceed
264+
// physical limits. Adding the check now based on `maxImageDimension3D` will
265+
// cause some existing models to break. Anecdotally, on Adreno and
266+
// SwiftShader devices, using 3D textures that exceed `maxImageDimension3D`
267+
// appears to be ok. So we need to figure out if it is undefined behaviour
268+
// or if there's a better way to figure out what the limit is. For now, only
269+
// check during debug build so that we can detect when exceeding physical
270+
// limits could be a potential cause for model outputs to be wrong. In the
271+
// meantime, the threshold for using texture storage can be configured at
272+
// export time.
273+
#ifdef VULKAN_DEBUG
274+
uint32_t max_extent = storage_type == utils::kTexture3D
275+
? adapter_ptr->max_texture3d_dim()
276+
: adapter_ptr->max_texture2d_dim();
277+
278+
VK_CHECK_COND(
279+
image_extents[0] <= max_extent && image_extents[1] <= max_extent &&
280+
image_extents[2] <= max_extent);
281+
#endif
282+
263283
VkSampler sampler = adapter_ptr->sampler_cache().retrieve(sampler_props);
264284

265285
return adapter_ptr->vma().create_image(
@@ -291,6 +311,8 @@ vkapi::VulkanBuffer allocate_buffer(
291311
return vkapi::VulkanBuffer();
292312
}
293313

314+
VK_CHECK_COND(numel <= context_ptr->adapter_ptr()->max_buffer_numel());
315+
294316
return adapter_ptr->vma().create_storage_buffer(
295317
element_size(dtype) * numel, allocate_memory);
296318
}

backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ void main() {
109109
in_vals[r][0] = get_first(in_val_packed);
110110
in_vals[r][1] = get_second(in_val_packed);
111111
} else {
112-
in_vals[r][0] = uint8_t(254);
113-
in_vals[r][1] = uint8_t(254);
112+
in_vals[r][0] = uint8_t(0);
113+
in_vals[r][1] = uint8_t(0);
114114
}
115115
}
116116

@@ -131,6 +131,6 @@ void main() {
131131
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
132132
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
133133
$else:
134-
imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
135-
imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
134+
imageStore(t_qmat2, packed_pos.xy, out_tex_1);
135+
imageStore(t_qmat2, ivec2(packed_pos.x, packed_pos.y + 1), out_tex_2);
136136
}

backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
pack_int4_linear_weight_transposed_interleaved:
88
parameter_names_with_default_values:
9-
STORAGE: texture3d
9+
STORAGE: texture2d
10+
generate_variant_forall:
11+
STORAGE:
12+
- VALUE: texture2d
13+
- VALUE: buffer
1014
shader_variants:
11-
- NAME: pack_int4_linear_weight_transposed_interleaved_texture3d
12-
- NAME: pack_int4_linear_weight_transposed_interleaved_buffer
13-
STORAGE: buffer
15+
- NAME: pack_int4_linear_weight_transposed_interleaved

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ layout(std430) buffer;
2121
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
2222
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
2323
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
24-
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}
24+
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)}
2525

2626
layout(push_constant) uniform restrict Block {
2727
ivec4 out_sizes;
@@ -79,13 +79,23 @@ void main() {
7979

8080
$if WEIGHT_STORAGE == "buffer":
8181
const int qmat2_stride = qmat2_sizes.x >> 2;
82+
$if PARAMS_STORAGE == "buffer":
83+
const int qparams_y_stride = out_sizes.x >> 2;
84+
const int qparams_z_stride = qparams_y_stride * 2;
8285

8386
for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
84-
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
85-
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
87+
$if PARAMS_STORAGE == "buffer":
88+
scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
89+
zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];
8690

87-
scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
88-
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
91+
scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
92+
zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
93+
$else:
94+
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
95+
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
96+
97+
scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
98+
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
8999

90100
for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
91101
const int k = block_idx * group_size + g_idx;
@@ -101,7 +111,7 @@ void main() {
101111
$else:
102112
const uvec4 packed_weight_tex = texelFetch(
103113
t_qmat2,
104-
ivec3(gl_GlobalInvocationID.x, k + comp, 0),
114+
ivec2(gl_GlobalInvocationID.x, k + comp),
105115
0);
106116

107117
const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4;

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@ q_4w_linear:
99
DTYPE: float
1010
OUT_STORAGE: texture3d
1111
IN_STORAGE: texture3d
12-
WEIGHT_STORAGE: texture3d
12+
WEIGHT_STORAGE: texture2d
13+
PARAMS_STORAGE: buffer
1314
shader_variants:
14-
- NAME: q_4w_linear_texture3d_texture3d_texture3d_float
15-
- NAME: q_4w_linear_texture3d_buffer_texture3d_float
16-
IN_STORAGE: buffer
17-
- NAME: q_4w_linear_buffer_buffer_texture3d_float
15+
- NAME: q_4w_linear_texture3d_texture3d_texture2d_float
16+
- NAME: q_4w_linear_buffer_buffer_texture2d_float
1817
OUT_STORAGE: buffer
1918
IN_STORAGE: buffer
2019
- NAME: q_4w_linear_buffer_buffer_buffer_float

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
8383
const int64_t N = qmat2_orig_sizes.at(ndim - 2);
8484
const int64_t N_div2 = N / int64_t(2);
8585

86-
utils::StorageType storage_type = utils::kTexture3D;
87-
utils::uvec3 max_extents =
88-
graph.context()->adapter_ptr()->max_texture_extents();
89-
if (N_div2 > max_extents[0] * 4 || K > max_extents[1]) {
86+
utils::StorageType storage_type = utils::kTexture2D;
87+
uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
88+
if (N_div2 > max_extent * 4 || K > max_extent) {
9089
storage_type = utils::kBuffer;
9190
}
9291

@@ -133,7 +132,7 @@ void add_q_4w_linear_node(
133132
prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data);
134133

135134
ValueRef scales_and_zeros = prepack_standard_hw_transposed(
136-
graph, scales_and_zeros_data, utils::kTexture3D, utils::kWidthPacked);
135+
graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked);
137136

138137
std::string kernel_name = "q_4w_linear";
139138
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));

backends/vulkan/runtime/vk_api/Adapter.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,16 @@ class Adapter final {
211211
return physical_device_.min_ubo_alignment;
212212
}
213213

214-
inline utils::uvec3 max_texture_extents() const {
215-
return {
216-
physical_device_.properties.limits.maxImageDimension1D,
217-
physical_device_.properties.limits.maxImageDimension2D,
218-
physical_device_.properties.limits.maxImageDimension3D};
214+
inline uint32_t max_texture2d_dim() const {
215+
return physical_device_.properties.limits.maxImageDimension2D;
216+
}
217+
218+
inline uint32_t max_texture3d_dim() const {
219+
return physical_device_.properties.limits.maxImageDimension3D;
220+
}
221+
222+
inline uint32_t max_buffer_numel() const {
223+
return physical_device_.properties.limits.maxStorageBufferRange;
219224
}
220225

221226
// Command Buffer Submission

0 commit comments

Comments
 (0)