diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp
index fe9b82f91d..2d12197fb5 100644
--- a/backends/vulkan/test/op_tests/dequantize_test.cpp
+++ b/backends/vulkan/test/op_tests/dequantize_test.cpp
@@ -21,6 +21,7 @@
 
 #include <cassert>
 #include <iostream>
+#include <limits>
 
 namespace torch {
 namespace executor {
@@ -180,3 +181,387 @@ void check_dequantize_args(
       ")");
   }
 }
+
+//
+// Reference Implementation
+//
+
+/*
+ * Reference implementation of dequantize_per_tensor
+ */
+at::Tensor dequantize_per_tensor_reference_impl(
+    const at::Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  // Create output tensor with the target dtype
+  at::Tensor out = at::empty_like(input, out_dtype);
+
+  // Dequantize the input tensor
+  at::Tensor flat_input = input.flatten();
+  at::Tensor flat_out = out.flatten();
+
+  // Store casted values to avoid repeated casting
+  const int32_t zero_point_int32 = static_cast<int32_t>(zero_point);
+  const float scale_float = static_cast<float>(scale);
+
+  for (int i = 0; i < flat_input.numel(); i++) {
+    double dequantized_value = 0.0;
+
+    // Extract quantized value and dequantize based on input dtype
+    // Following the CPU implementation pattern: (input - zero_point) * scale
+    if (dtype == at::kByte) {
+      uint8_t qvalue = flat_input[i].item<uint8_t>();
+      dequantized_value = (qvalue - zero_point_int32) * scale_float;
+    } else if (dtype == at::kChar) {
+      int8_t qvalue = flat_input[i].item<int8_t>();
+      dequantized_value = (qvalue - zero_point_int32) * scale_float;
+    } else if (dtype == at::kShort) {
+      int16_t qvalue = flat_input[i].item<int16_t>();
+      dequantized_value = (qvalue - zero_point_int32) * scale_float;
+    } else if (dtype == at::kInt) {
+      int32_t qvalue = flat_input[i].item<int32_t>();
+      dequantized_value = (qvalue - zero_point_int32) * scale_float;
+    } else if (dtype == at::kLong) {
+      int64_t qvalue = flat_input[i].item<int64_t>();
+      dequantized_value = (qvalue - zero_point_int32) * scale_float;
+    }
+
+    // Store result based on output dtype
+    if (out_dtype == at::kFloat) {
+      flat_out[i] = static_cast<float>(dequantized_value);
+    } else if (out_dtype == at::kDouble) {
+      flat_out[i] = dequantized_value;
+    } else if (out_dtype == at::kHalf) {
+      flat_out[i] = static_cast<c10::Half>(dequantized_value);
+    }
+  }
+
+  return out.reshape(input.sizes());
+}
+
+// Forward declaration of implementation functions
+void test_vulkan_dequantize_per_tensor_impl(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage);
+
+// Wrapper function to test both buffer and texture storage types
+void test_vulkan_dequantize_per_tensor(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  // Test with buffer storage
+  test_vulkan_dequantize_per_tensor_impl(
+      input_sizes,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      vkcompute::utils::kBuffer,
+      vkcompute::utils::kBuffer);
+
+  // Test with texture storage
+  test_vulkan_dequantize_per_tensor_impl(
+      input_sizes,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      vkcompute::utils::kTexture3D,
+      vkcompute::utils::kTexture3D);
+}
+
+void test_reference_dequantize_per_tensor(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+
+  // Create a quantized input tensor with values from quant_min to quant_max
+  at::Tensor input;
+  if (dtype == at::kByte) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
+  } else if (dtype == at::kChar) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
+  } else if (dtype == at::kShort) {
+    input =
+        at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
+  } else if (dtype == at::kInt) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
+  } else {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
+  }
+
+  // Fill with a simple pattern: values from quant_min to quant_max in steps
+  float step = 1.0f;
+  if (input.numel() > 1) {
+    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
+  }
+
+  auto flat_input = input.flatten();
+  for (int i = 0; i < flat_input.numel(); i++) {
+    int64_t qvalue = quant_min + i * step;
+    if (dtype == at::kByte) {
+      flat_input[i] = static_cast<uint8_t>(qvalue);
+    } else if (dtype == at::kChar) {
+      flat_input[i] = static_cast<int8_t>(qvalue);
+    } else if (dtype == at::kShort) {
+      flat_input[i] = static_cast<int16_t>(qvalue);
+    } else if (dtype == at::kInt) {
+      flat_input[i] = static_cast<int32_t>(qvalue);
+    } else if (dtype == at::kLong) {
+      flat_input[i] = static_cast<int64_t>(qvalue);
+    }
+  }
+
+  // Reshape back to original dimensions
+  input = flat_input.reshape(input_sizes_int64);
+
+  // Get reference output
+  at::Tensor reference_out = dequantize_per_tensor_reference_impl(
+      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
+
+  // Get implementation output
+  at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten(
+      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
+
+  // Compare outputs
+  const bool output_correct = at::allclose(reference_out, impl_out);
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  scale: " << scale << std::endl;
+    std::cout << "  zero_point: " << zero_point << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+
+    std::cout << "input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << "reference:" << std::endl;
+    std::cout << reference_out << std::endl;
+    std::cout << "implementation:" << std::endl;
+    std::cout << impl_out << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
+void test_vulkan_dequantize_per_tensor_impl(
+    const std::vector<int>& input_sizes,
+    float scale,
+    int zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage) {
+  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+
+  // Create a quantized input tensor with values from quant_min to quant_max
+  at::Tensor input;
+  if (dtype == at::kByte) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
+  } else if (dtype == at::kChar) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
+  } else if (dtype == at::kShort) {
+    input =
+        at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
+  } else if (dtype == at::kInt) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
+  } else {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
+  }
+
+  // Fill with a simple pattern: values from quant_min to quant_max in steps
+  float step = 1.0f;
+  if (input.numel() > 1) {
+    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
+  }
+
+  auto flat_input = input.flatten();
+  for (int i = 0; i < flat_input.numel(); i++) {
+    int64_t qvalue = quant_min + i * step;
+    if (dtype == at::kByte) {
+      flat_input[i] = static_cast<uint8_t>(qvalue);
+    } else if (dtype == at::kChar) {
+      flat_input[i] = static_cast<int8_t>(qvalue);
+    } else if (dtype == at::kShort) {
+      flat_input[i] = static_cast<int16_t>(qvalue);
+    } else if (dtype == at::kInt) {
+      flat_input[i] = static_cast<int32_t>(qvalue);
+    } else if (dtype == at::kLong) {
+      flat_input[i] = static_cast<int64_t>(qvalue);
+    }
+  }
+
+  // Reshape back to original dimensions
+  input = flat_input.reshape(input_sizes_int64);
+
+  // Get reference output
+  at::Tensor reference_out =
+      torch::executor::native::dequantize_per_tensor_aten(
+          input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
+
+  // Build Vulkan dequantize_per_tensor graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(in_storage);
+  ComputeGraph graph(config);
+
+  IOValueRef r_input = graph.add_input_tensor(
+      input.sizes().vec(), from_at_scalartype(dtype), in_storage);
+
+  const ValueRef r_scale = graph.add_scalar<double>(scale);
+  const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point);
+  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+  const ValueRef r_out = graph.add_tensor(
+      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
+
+  VK_GET_OP_FN("dequantize_per_tensor.default")
+  (graph,
+   {
+       r_input.value,
+       r_scale,
+       r_zero_point,
+       r_quant_min,
+       r_quant_max,
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Run Vulkan dequantize_per_tensor
+  graph.copy_into_staging(
+      r_input.staging, input.const_data_ptr(), input.numel());
+
+  graph.execute();
+
+  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare outputs
+  const bool output_correct = at::allclose(reference_out, vk_out);
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  scale: " << scale << std::endl;
+    std::cout << "  zero_point: " << zero_point << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+    std::cout << "  storage type: "
+              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
+                                                          : "texture")
+              << std::endl;
+
+    std::cout << "input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << "reference:" << std::endl;
+    std::cout << reference_out << std::endl;
+    std::cout << "vulkan:" << std::endl;
+    std::cout << vk_out << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
+// Test cases for dequantize_per_tensor
+TEST(
+    VulkanDequantizePerTensorTest,
+    test_reference_dequantize_per_tensor_uint8_to_float) {
+  test_reference_dequantize_per_tensor(
+      {2, 3, 4}, // input sizes
+      0.1, // scale
+      5, // zero_point
+      0, // quant_min
+      255, // quant_max
+      at::kByte, // input dtype
+      at::kFloat); // output dtype
+}
+
+TEST(
+    VulkanDequantizePerTensorTest,
+    test_reference_dequantize_per_tensor_int8_to_float) {
+  test_reference_dequantize_per_tensor(
+      {3, 4, 5}, // input sizes
+      0.05, // scale
+      0, // zero_point
+      -128, // quant_min
+      127, // quant_max
+      at::kChar, // input dtype
+      at::kFloat); // output dtype
+}
+
+TEST(
+    VulkanDequantizePerTensorTest,
+    test_reference_dequantize_per_tensor_int32_to_float) {
+  test_reference_dequantize_per_tensor(
+      {4, 6, 2}, // input sizes
+      0.2, // scale
+      2, // zero_point
+      std::numeric_limits<int32_t>::min(), // quant_min
+      std::numeric_limits<int32_t>::max(), // quant_max
+      at::kInt, // input dtype
+      at::kFloat); // output dtype
+}
+
+TEST(
+    VulkanDequantizePerTensorTest,
+    test_reference_dequantize_per_tensor_uint8_to_half) {
+  test_reference_dequantize_per_tensor(
+      {7, 4}, // input sizes
+      0.1, // scale
+      10, // zero_point
+      0, // quant_min
+      255, // quant_max
+      at::kByte, // input dtype (uint8)
+      at::kHalf); // output dtype
+}
+
+TEST(
+    VulkanDequantizePerTensorTest,
+    test_reference_dequantize_per_tensor_int32_to_half) {
+  test_reference_dequantize_per_tensor(
+      {2, 6, 5}, // input sizes
+      0.3, // scale
+      -10, // zero_point
+      std::numeric_limits<int32_t>::min(), // quant_min
+      std::numeric_limits<int32_t>::max(), // quant_max
+      at::kInt, // input dtype
+      at::kHalf); // output dtype
+}