Skip to content

Commit 85bd1e7

Browse files
committed
Add auto for model/beam size
1 parent bfe6430 commit 85bd1e7

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 2.4.0
4+
5+
- Add "auto" for model and beam size (0) to select values based on CPU
6+
37
## 2.3.0
48

59
- Bump faster-whisper package to 1.1.0

wyoming_faster_whisper/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.3.0
1+
2.4.0

wyoming_faster_whisper/__main__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import argparse
33
import asyncio
44
import logging
5+
import platform
56
import re
67
from functools import partial
78

@@ -21,7 +22,7 @@ async def main() -> None:
2122
parser.add_argument(
2223
"--model",
2324
required=True,
24-
help="Name of faster-whisper model to use",
25+
help="Name of faster-whisper model to use (or auto)",
2526
)
2627
parser.add_argument("--uri", required=True, help="unix:// or tcp://")
2728
parser.add_argument(
@@ -52,6 +53,7 @@ async def main() -> None:
5253
"--beam-size",
5354
type=int,
5455
default=5,
56+
help="Size of beam during decoding (0 for auto)",
5557
)
5658
parser.add_argument(
5759
"--initial-prompt",
@@ -79,6 +81,18 @@ async def main() -> None:
7981
)
8082
_LOGGER.debug(args)
8183

84+
# Automatic configuration for ARM
85+
machine = platform.machine().lower()
86+
is_arm = ("arm" in machine) or ("aarch" in machine)
87+
if args.model == "auto":
88+
args.model = "tiny-int8" if is_arm else "base-int8"
89+
_LOGGER.debug("Model automatically selected: %s", args.model)
90+
91+
if args.beam_size <= 0:
92+
args.beam_size = 1 if is_arm else 5
93+
_LOGGER.debug("Beam size automatically selected: %s", args.beam_size)
94+
95+
# Resolve model name
8296
model_name = args.model
8397
match = re.match(r"^(tiny|base|small|medium)[.-]int8$", args.model)
8498
if match:

0 commit comments

Comments
 (0)