Skip to content

Commit fd677ac

Browse files
authored
Qualcomm AI Engine Direct - GA Whisper (#12102)
Summary: - Add the unit test for Whisper - Support multi-method for Whisper - Add qnn_whisper_runner to run whisper encoder-decoder model Command: ``` python examples/qualcomm/oss_scripts/whisper/whisper.py -b build-android -s <serial> -H <host> -m SM8750 --max_seq_len 1024 ``` Performance: - SM8750: - avg encoding time: 0.037 s - avg decoding time: 0.004 s Accuracy: - Word Error Rate: 0.1941964328289032 cc: @haowhsu-quic, @cccclai , @winskuo-quic
1 parent 4ddf049 commit fd677ac

File tree

16 files changed

+1391
-19
lines changed

16 files changed

+1391
-19
lines changed

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class TensorOpInfo:
5959

6060
SKIP_LIFT_OPS = {
6161
aten.full_like.default,
62+
aten.full.default,
6263
aten.arange.start_step,
6364
aten.arange.default,
6465
aten.scalar_tensor.default,

backends/qualcomm/builders/op_index_put.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ def define_node(
8888

8989
# Need to reconstruct the index tensor.
9090
# E.g., based on ScatterND Op Def in QNN Docs.
91-
# Given that
92-
# shape of input: [1, 12, 1024, 64]
93-
# indicies_node: [None, None, aten__to_copy_default_1]
94-
# shape of aten__to_copy_default_1: [1]
95-
# The shape of index tensor should be [1, 12, 1, 3]
91+
# Torch:
92+
# Given that
93+
# shape of input: [1, 12, 1024, 64]
94+
# indicies_node: [None, None, aten__to_copy_default_1]
95+
# shape of aten__to_copy_default_1: [1]
96+
# QNN:
97+
# Index tensor:
98+
# Shape: [1, 12, 1, 3]
99+
# Value: [[[0,0,x]],[[0,1,x]],...,[[0,11,x]]]
96100
# The index tensor is treated as 4-dimensional tensor of 3-tuples,
97101
# where each 3-tuple is a partial-index into input
98102
# Reference code for QNN ScatterNd:

backends/qualcomm/builders/op_linear.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import warnings
87
from typing import Dict
98

109
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -70,13 +69,6 @@ def define_node(
7069
if len(node.args) >= 3:
7170
bias_node = self.get_node(node.args[2])
7271

73-
# TODO remove this when qnn sdk support
74-
if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
75-
warnings.warn(
76-
f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. per channel bias quantization is not support yet.",
77-
stacklevel=1,
78-
)
79-
8072
bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
8173
bias_tensor = get_parameter(bias_node, self.edge_program)
8274
# if bias_node is getitem

backends/qualcomm/tests/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -910,9 +910,10 @@ def forward(self, x):
910910

911911

912912
class IndexCopy(torch.nn.Module):
913-
def __init__(self, skip_mutable_buffer=False):
913+
def __init__(self, copy_dim=1, skip_mutable_buffer=False):
914914
super().__init__()
915915
self.skip_mutable_buffer = skip_mutable_buffer
916+
self.copy_dim = copy_dim
916917
self.register_buffer(
917918
"k_cache",
918919
torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
@@ -921,7 +922,7 @@ def __init__(self, skip_mutable_buffer=False):
921922

922923
def forward(self, input_pos, k_val):
923924
k_out = self.k_cache
924-
k_out.index_copy_(1, input_pos, k_val)
925+
k_out.index_copy_(self.copy_dim, input_pos, k_val)
925926
return k_out + 0
926927

927928

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,19 +622,59 @@ def test_qnn_backend_index(self):
622622
def test_qnn_backend_index_copy(self):
623623
test_comb = [
624624
{
625-
QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405
625+
QCOM_MODULE: IndexCopy( # noqa: F405
626+
copy_dim=1, skip_mutable_buffer=False
627+
),
626628
QCOM_SAMPLE_INPUTS: (
627629
torch.tensor([2], dtype=torch.int64),
628630
torch.randn([1, 1, 12, 64]),
629631
),
630632
},
631633
{
632-
QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405
634+
QCOM_MODULE: IndexCopy( # noqa: F405
635+
copy_dim=2, skip_mutable_buffer=False
636+
),
637+
QCOM_SAMPLE_INPUTS: (
638+
torch.tensor([2], dtype=torch.int64),
639+
torch.randn([1, 1024, 1, 64]),
640+
),
641+
},
642+
{
643+
QCOM_MODULE: IndexCopy( # noqa: F405
644+
copy_dim=2, skip_mutable_buffer=False
645+
),
646+
QCOM_SAMPLE_INPUTS: (
647+
torch.tensor([2, 5], dtype=torch.int64),
648+
torch.randn([1, 1024, 2, 64]),
649+
),
650+
},
651+
{
652+
QCOM_MODULE: IndexCopy( # noqa: F405
653+
copy_dim=1, skip_mutable_buffer=True
654+
),
633655
QCOM_SAMPLE_INPUTS: (
634656
torch.tensor([2], dtype=torch.int64),
635657
torch.randn([1, 1, 12, 64]),
636658
),
637659
},
660+
{
661+
QCOM_MODULE: IndexCopy( # noqa: F405
662+
copy_dim=2, skip_mutable_buffer=True
663+
),
664+
QCOM_SAMPLE_INPUTS: (
665+
torch.tensor([2], dtype=torch.int64),
666+
torch.randn([1, 1024, 1, 64]),
667+
),
668+
},
669+
{
670+
QCOM_MODULE: IndexCopy( # noqa: F405
671+
copy_dim=2, skip_mutable_buffer=True
672+
),
673+
QCOM_SAMPLE_INPUTS: (
674+
torch.tensor([2, 5], dtype=torch.int64),
675+
torch.randn([1, 1024, 2, 64]),
676+
),
677+
},
638678
]
639679
for i, test in enumerate(test_comb):
640680
with self.subTest(i=i):
@@ -1907,19 +1947,59 @@ def test_qnn_backend_index(self):
19071947
def test_qnn_backend_index_copy(self):
19081948
test_comb = [
19091949
{
1910-
QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405
1950+
QCOM_MODULE: IndexCopy( # noqa: F405
1951+
copy_dim=1, skip_mutable_buffer=False
1952+
),
19111953
QCOM_SAMPLE_INPUTS: (
19121954
torch.tensor([2], dtype=torch.int64),
19131955
torch.randn([1, 1, 12, 64]),
19141956
),
19151957
},
19161958
{
1917-
QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405
1959+
QCOM_MODULE: IndexCopy( # noqa: F405
1960+
copy_dim=2, skip_mutable_buffer=False
1961+
),
1962+
QCOM_SAMPLE_INPUTS: (
1963+
torch.tensor([2], dtype=torch.int64),
1964+
torch.randn([1, 1024, 1, 64]),
1965+
),
1966+
},
1967+
{
1968+
QCOM_MODULE: IndexCopy( # noqa: F405
1969+
copy_dim=2, skip_mutable_buffer=False
1970+
),
1971+
QCOM_SAMPLE_INPUTS: (
1972+
torch.tensor([2, 5], dtype=torch.int64),
1973+
torch.randn([1, 1024, 2, 64]),
1974+
),
1975+
},
1976+
{
1977+
QCOM_MODULE: IndexCopy( # noqa: F405
1978+
copy_dim=1, skip_mutable_buffer=True
1979+
),
19181980
QCOM_SAMPLE_INPUTS: (
19191981
torch.tensor([2], dtype=torch.int64),
19201982
torch.randn([1, 1, 12, 64]),
19211983
),
19221984
},
1985+
{
1986+
QCOM_MODULE: IndexCopy( # noqa: F405
1987+
copy_dim=2, skip_mutable_buffer=True
1988+
),
1989+
QCOM_SAMPLE_INPUTS: (
1990+
torch.tensor([2], dtype=torch.int64),
1991+
torch.randn([1, 1024, 1, 64]),
1992+
),
1993+
},
1994+
{
1995+
QCOM_MODULE: IndexCopy( # noqa: F405
1996+
copy_dim=2, skip_mutable_buffer=True
1997+
),
1998+
QCOM_SAMPLE_INPUTS: (
1999+
torch.tensor([2, 5], dtype=torch.int64),
2000+
torch.randn([1, 1024, 2, 64]),
2001+
),
2002+
},
19232003
]
19242004
for i, test in enumerate(test_comb):
19252005
with self.subTest(i=i):
@@ -4909,6 +4989,39 @@ def test_swin_transformer(self):
49094989
self.assertGreaterEqual(msg["top_1"], 60)
49104990
self.assertGreaterEqual(msg["top_5"], 80)
49114991

4992+
def test_whisper(self):
4993+
if not self.required_envs():
4994+
self.skipTest("missing required envs")
4995+
4996+
cmds = [
4997+
"python",
4998+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/whisper/whisper.py",
4999+
"--artifact",
5000+
self.artifact_dir,
5001+
"--build_folder",
5002+
self.build_folder,
5003+
"--device",
5004+
self.device,
5005+
"--model",
5006+
self.model,
5007+
"--ip",
5008+
self.ip,
5009+
"--port",
5010+
str(self.port),
5011+
]
5012+
if self.host:
5013+
cmds.extend(["--host", self.host])
5014+
5015+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
5016+
with Listener((self.ip, self.port)) as listener:
5017+
conn = listener.accept()
5018+
p.communicate()
5019+
msg = json.loads(conn.recv())
5020+
if "Error" in msg:
5021+
self.fail(msg["Error"])
5022+
else:
5023+
self.assertLessEqual(msg["wer"], 0.25)
5024+
49125025

49135026
class TestExampleQaihubScript(TestQNN):
49145027
def test_utils_export(self):

examples/qualcomm/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)
9090
# build qnn_mimi_decoder_runner
9191
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/moshi)
9292

93+
# build qnn_whisper_runner for whisper
94+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/whisper)
95+
9396
# build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
9497
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)
9598

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
# preprocess qnn runner src files for whisper
9+
set(_qnn_whisper_runner__srcs
10+
${CMAKE_CURRENT_LIST_DIR}/qnn_whisper_runner.cpp
11+
${CMAKE_CURRENT_LIST_DIR}/runner/decoder.cpp
12+
${CMAKE_CURRENT_LIST_DIR}/runner/decoder.h
13+
${CMAKE_CURRENT_LIST_DIR}/runner/encoder.cpp
14+
${CMAKE_CURRENT_LIST_DIR}/runner/encoder.h
15+
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
16+
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
17+
${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp
18+
)
19+
20+
# build qnn whisper runner
21+
add_executable(qnn_whisper_runner ${_qnn_whisper_runner__srcs})
22+
target_include_directories(
23+
qnn_whisper_runner PUBLIC ${_common_include_directories}
24+
${EXECUTORCH_ROOT}/extension/llm/tokenizers/include
25+
)
26+
27+
28+
target_link_libraries(
29+
qnn_whisper_runner
30+
qnn_executorch_backend
31+
executorch_core
32+
extension_data_loader
33+
extension_flat_tensor
34+
extension_module
35+
extension_tensor
36+
full_portable_ops_lib
37+
gflags
38+
tokenizers
39+
)
40+
41+
target_compile_options(
42+
qnn_whisper_runner PUBLIC ${_common_compile_options}
43+
)
44+
set_target_properties(
45+
qnn_whisper_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
46+
)

0 commit comments

Comments
 (0)