Skip to content

Commit 75a347d

Browse files
authored
Merge pull request #38 from roryeckel/fix/backend-argument-parsing
Fix backend argument parsing by adding proper enum type converter
2 parents e7c2b4b + 6d1bab9 commit 75a347d

File tree

3 files changed

+134
-7
lines changed

3 files changed

+134
-7
lines changed

src/wyoming_openai/__main__.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from .const import __version__
2020
from .handler import OpenAIEventHandler
21+
from .utilities import create_enum_parser
2122

2223

2324
def configure_logging(level):
@@ -29,10 +30,27 @@ def configure_logging(level):
2930

3031
async def main():
3132
"""Main entry point for the Wyoming OpenAI server."""
32-
env_stt_backend = os.getenv("STT_BACKEND")
33-
env_tts_backend = os.getenv("TTS_BACKEND")
3433
parser = argparse.ArgumentParser()
3534

35+
# Create reusable enum parser for backend arguments
36+
backend_parser = create_enum_parser(OpenAIBackend)
37+
38+
stt_backend_env = os.getenv("STT_BACKEND")
39+
stt_backend_default = None
40+
if stt_backend_env:
41+
try:
42+
stt_backend_default = backend_parser(stt_backend_env)
43+
except argparse.ArgumentTypeError as exc:
44+
parser.error(str(exc))
45+
46+
tts_backend_env = os.getenv("TTS_BACKEND")
47+
tts_backend_default = None
48+
if tts_backend_env:
49+
try:
50+
tts_backend_default = backend_parser(tts_backend_env)
51+
except argparse.ArgumentTypeError as exc:
52+
parser.error(str(exc))
53+
3654
# General configuration
3755
parser.add_argument(
3856
"--uri",
@@ -71,10 +89,10 @@ async def main():
7189
)
7290
parser.add_argument(
7391
"--stt-backend",
74-
type=OpenAIBackend,
92+
type=backend_parser,
7593
required=False,
7694
choices=list(OpenAIBackend),
77-
default=OpenAIBackend[env_stt_backend] if env_stt_backend else None,
95+
default=stt_backend_default,
7896
help="Backend for speech-to-text (OPENAI, SPEACHES, KOKORO_FASTAPI, LOCALAI, or None)"
7997
)
8098
parser.add_argument(
@@ -122,10 +140,10 @@ async def main():
122140
)
123141
parser.add_argument(
124142
"--tts-backend",
125-
type=OpenAIBackend,
143+
type=backend_parser,
126144
required=False,
127145
choices=list(OpenAIBackend),
128-
default=OpenAIBackend[env_tts_backend] if env_tts_backend else None,
146+
default=tts_backend_default,
129147
help="Backend for text-to-speech (OPENAI, SPEACHES, KOKORO_FASTAPI, LOCALAI, or None)"
130148
)
131149
parser.add_argument(

src/wyoming_openai/utilities.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,54 @@
1+
import argparse
2+
from collections.abc import Callable
3+
from enum import Enum
14
from io import BytesIO
5+
from typing import TypeVar
6+
7+
E = TypeVar('E', bound=Enum)
8+
9+
10+
def create_enum_parser(enum_class: type[E], case_insensitive: bool = True) -> Callable[[str], E]:
11+
"""
12+
Create a type-safe parser function for argparse that converts strings to enum members.
13+
14+
This function generates a parser that:
15+
- Handles case-insensitive matching (optional)
16+
- Provides clear error messages listing all valid options
17+
- Raises argparse.ArgumentTypeError for invalid inputs
18+
19+
Args:
20+
enum_class: The Enum class to parse into
21+
case_insensitive: Whether to allow case-insensitive matching (default: True)
22+
23+
Returns:
24+
A callable that takes a string and returns the corresponding enum member
25+
26+
Raises:
27+
argparse.ArgumentTypeError: When the input string doesn't match any enum member
28+
29+
Example:
30+
>>> from enum import Enum
31+
>>> class Color(Enum):
32+
... RED = 1
33+
... BLUE = 2
34+
>>> parser = argparse.ArgumentParser()
35+
>>> parser.add_argument('--color', type=create_enum_parser(Color))
36+
>>> args = parser.parse_args(['--color', 'red'])
37+
>>> args.color == Color.RED
38+
True
39+
"""
40+
def parse_enum(value: str) -> E:
41+
lookup_value = value.upper() if case_insensitive else value
42+
try:
43+
return enum_class[lookup_value]
44+
except KeyError as exc:
45+
valid_options = ', '.join(member.name for member in enum_class)
46+
raise argparse.ArgumentTypeError(
47+
f"Invalid {enum_class.__name__}: '{value}'. "
48+
f"Valid options are: {valid_options}"
49+
) from exc
50+
51+
return parse_enum
252

353

454
class NamedBytesIO(BytesIO):

tests/test_utilities.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import argparse
2+
from enum import Enum
13
from io import BytesIO
24

3-
from wyoming_openai.utilities import NamedBytesIO
5+
import pytest
6+
7+
from wyoming_openai.utilities import NamedBytesIO, create_enum_parser
48

59

610
def test_named_bytes_io_name_property():
@@ -16,3 +20,58 @@ def test_named_bytes_io_inherits_bytesio():
1620
buf = NamedBytesIO(b"xyz", name="foo.wav")
1721
assert isinstance(buf, BytesIO)
1822
assert buf.read() == b"xyz"
23+
24+
25+
# Test enum for create_enum_parser tests
26+
class TestBackend(Enum):
27+
OPENAI = 1
28+
LOCAL = 2
29+
CUSTOM = 3
30+
31+
32+
def test_create_enum_parser_valid_input():
33+
"""Test that create_enum_parser successfully parses valid enum values."""
34+
parser = create_enum_parser(TestBackend)
35+
36+
assert parser("openai") == TestBackend.OPENAI
37+
assert parser("OPENAI") == TestBackend.OPENAI
38+
assert parser("local") == TestBackend.LOCAL
39+
assert parser("custom") == TestBackend.CUSTOM
40+
41+
42+
def test_create_enum_parser_invalid_input():
43+
"""Test that create_enum_parser raises ArgumentTypeError for invalid values."""
44+
parser = create_enum_parser(TestBackend)
45+
46+
with pytest.raises(argparse.ArgumentTypeError) as exc_info:
47+
parser("invalid")
48+
49+
error_msg = str(exc_info.value)
50+
assert "Invalid TestBackend" in error_msg
51+
assert "invalid" in error_msg
52+
assert "OPENAI, LOCAL, CUSTOM" in error_msg
53+
54+
55+
def test_create_enum_parser_case_sensitive():
56+
"""Test that create_enum_parser respects case_insensitive parameter."""
57+
parser = create_enum_parser(TestBackend, case_insensitive=False)
58+
59+
# Should work with exact case
60+
assert parser("OPENAI") == TestBackend.OPENAI
61+
62+
# Should fail with wrong case
63+
with pytest.raises(argparse.ArgumentTypeError):
64+
parser("openai")
65+
66+
67+
def test_create_enum_parser_with_argparse():
68+
"""Test that create_enum_parser works correctly with argparse."""
69+
parser = argparse.ArgumentParser()
70+
parser.add_argument("--backend", type=create_enum_parser(TestBackend))
71+
72+
args = parser.parse_args(["--backend", "openai"])
73+
assert args.backend == TestBackend.OPENAI
74+
75+
# Test that invalid values are caught by argparse
76+
with pytest.raises(SystemExit):
77+
parser.parse_args(["--backend", "invalid"])

0 commit comments

Comments
 (0)