Skip to content

[ET-VK][Ops] common test utils for converting aten types to vulkan types #11671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/ahmtox/7/orig
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <cassert>

//
Expand Down Expand Up @@ -201,26 +203,6 @@ void test_reference_linear_qcs4w(
ASSERT_TRUE(at::allclose(out, out_ref));
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_vulkan_linear_qga4w_impl(
const int B,
const int M,
Expand Down
22 changes: 2 additions & 20 deletions backends/vulkan/test/op_tests/rotary_embedding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include "test_utils.h"

#include <cassert>

//
Expand Down Expand Up @@ -55,26 +57,6 @@ std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
// Test functions
//

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_reference(
const int n_heads = 4,
const int n_kv_heads = 2,
Expand Down
20 changes: 2 additions & 18 deletions backends/vulkan/test/op_tests/sdpa_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_sdpa.h>

#include "test_utils.h"

#include <cassert>
#include <iostream>

Expand Down Expand Up @@ -261,24 +263,6 @@ void test_reference_sdpa(
}
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kFloat:
return vkapi::kFloat;
case c10::kHalf:
return vkapi::kHalf;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kInt;
case c10::kChar:
return vkapi::kChar;
default:
VK_THROW("Unsupported at::ScalarType!");
}
}

void test_vulkan_sdpa(
const int start_input_pos,
const int base_sequence_len,
Expand Down
37 changes: 35 additions & 2 deletions backends/vulkan/test/op_tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False):
platforms = get_platforms(),
)

runtime.cxx_library(
name = "test_utils",
srcs = [
"test_utils.cpp",
],
headers = [
"test_utils.h",
],
exported_headers = [
"test_utils.h",
],
deps = [
"//executorch/backends/vulkan:vulkan_graph_runtime",
"//executorch/runtime/core/exec_aten:lib",
runtime.external_dep_location("libtorch"),
],
visibility = [
"//executorch/backends/vulkan/test/op_tests/...",
"@EXECUTORCH_CLIENTS",
],
)

define_test_targets(
"compute_graph_op_tests",
src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]"
Expand All @@ -150,9 +172,20 @@ def define_common_targets(is_fbcode = False):
define_test_targets(
"sdpa_test",
extra_deps = [
":test_utils",
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/extension/tensor:tensor",
]
)
define_test_targets("linear_weight_int4_test")
define_test_targets("rotary_embedding_test")
define_test_targets(
"linear_weight_int4_test",
extra_deps = [
":test_utils",
]
)
define_test_targets(
"rotary_embedding_test",
extra_deps = [
":test_utils",
]
)
114 changes: 114 additions & 0 deletions backends/vulkan/test/op_tests/test_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "test_utils.h"

#include <stdexcept>

executorch::aten::ScalarType at_scalartype_to_et_scalartype(
at::ScalarType dtype) {
using ScalarType = executorch::aten::ScalarType;
switch (dtype) {
case at::kByte:
return ScalarType::Byte;
case at::kChar:
return ScalarType::Char;
case at::kShort:
return ScalarType::Short;
case at::kInt:
return ScalarType::Int;
case at::kLong:
return ScalarType::Long;
case at::kHalf:
return ScalarType::Half;
case at::kFloat:
return ScalarType::Float;
case at::kDouble:
return ScalarType::Double;
default:
throw std::runtime_error("Unsupported dtype");
}
}

std::string scalar_type_name(c10::ScalarType dtype) {
switch (dtype) {
case c10::kLong:
return "c10::kLong";
case c10::kShort:
return "c10::kShort";
case c10::kComplexHalf:
return "c10::kComplexHalf";
case c10::kComplexFloat:
return "c10::kComplexFloat";
case c10::kComplexDouble:
return "c10::kComplexDouble";
case c10::kBool:
return "c10::kBool";
case c10::kQInt8:
return "c10::kQInt8";
case c10::kQUInt8:
return "c10::kQUInt8";
case c10::kQInt32:
return "c10::kQInt32";
case c10::kBFloat16:
return "c10::kBFloat16";
case c10::kQUInt4x2:
return "c10::kQUInt4x2";
case c10::kQUInt2x4:
return "c10::kQUInt2x4";
case c10::kFloat:
return "c10::kFloat";
case c10::kHalf:
return "c10::kHalf";
case c10::kInt:
return "c10::kInt";
case c10::kChar:
return "c10::kChar";
case c10::kByte:
return "c10::kByte";
case c10::kDouble:
return "c10::kDouble";
case c10::kUInt16:
return "c10::kUInt16";
case c10::kBits16:
return "c10::kBits16";
default:
return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
}
}

vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
using namespace vkcompute;
switch (at_scalartype) {
case c10::kHalf:
return vkapi::kHalf;
case c10::kFloat:
return vkapi::kFloat;
case c10::kDouble:
return vkapi::kDouble;
case c10::kInt:
return vkapi::kInt;
case c10::kLong:
return vkapi::kLong;
case c10::kChar:
return vkapi::kChar;
case c10::kByte:
return vkapi::kByte;
case c10::kShort:
return vkapi::kShort;
case c10::kUInt16:
return vkapi::kUInt16;
default:
VK_THROW(
"Unsupported at::ScalarType: ",
scalar_type_name(at_scalartype),
" (",
static_cast<int>(at_scalartype),
")");
}
}
32 changes: 32 additions & 0 deletions backends/vulkan/test/op_tests/test_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <string>

#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include <executorch/backends/vulkan/runtime/api/api.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

/**
* Convert at::ScalarType to executorch::ScalarType
*/
executorch::aten::ScalarType at_scalartype_to_et_scalartype(
at::ScalarType dtype);

/**
* Get the string name of a c10::ScalarType for better error messages
*/
std::string scalar_type_name(c10::ScalarType dtype);

/**
* Convert c10::ScalarType to vkcompute::vkapi::ScalarType
*/
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype);
Loading