Skip to content

Commit 35a8a25

Browse files
authored
feat: Add profanity filter (#69)
* refactor: use Pydantic model for Microsoft STT configuration
* feat: add profanity filter configuration to Microsoft STT
1 parent 38ba1bf commit 35a8a25

File tree

6 files changed

+107
-18
lines changed

6 files changed

+107
-18
lines changed

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
wyoming==1.6.0
22
azure-cognitiveservices-speech==1.42.0
3-
ruff
3+
ruff
4+
pydantic>=2,<3

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""Fixtures for tests."""
22

3-
from types import SimpleNamespace
3+
from wyoming_microsoft_stt import SpeechConfig
44
import pytest
55
import os
66

77

88
@pytest.fixture
99
def microsoft_stt_args():
1010
"""Return MicrosoftSTT instance."""
11-
args = SimpleNamespace(
11+
args = SpeechConfig(
1212
subscription_key=os.environ.get("SPEECH_KEY"),
1313
service_region=os.environ.get("SPEECH_REGION"),
1414
)

tests/test_microsoft_stt.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,13 @@ def test_transcribe(microsoft_stt_args):
1818

1919
result = microsoft_stt.transcribe(filename, language)
2020
assert "hello world" in result.lower()
21+
22+
23+
def test_set_profanity(microsoft_stt_args):
    """Check that a profanity level can be applied to a new MicrosoftSTT."""
    stt = MicrosoftSTT(microsoft_stt_args)
    assert stt.speech_config is not None

    # Only verifies that the call completes:
    # There is currently no way to check the set profanity level
    stt.set_profanity("masked")

wyoming_microsoft_stt/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,13 @@
11
"""Wyoming server for Microsoft STT."""
2+
3+
from typing import Literal
4+
from pydantic import BaseModel
5+
6+
7+
class SpeechConfig(BaseModel):
    """Validated settings handed to the Microsoft STT wrapper.

    Both credential fields are required; the remaining fields fall back to
    their defaults when omitted.
    """

    # Azure credentials: subscription key and the region it belongs to.
    subscription_key: str
    service_region: str
    # Profanity handling; restricted to exactly these three values.
    profanity: Literal["off", "masked", "removed"] = "masked"
    # Language used when the caller does not specify one.
    language: str = "en-US"

wyoming_microsoft_stt/__main__.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,83 @@
1414
from .microsoft_stt import MicrosoftSTT
1515
from .handler import MicrosoftEventHandler
1616
from .version import __version__
17+
from . import SpeechConfig
1718

1819
_LOGGER = logging.getLogger(__name__)
1920

2021
stop_event = asyncio.Event()
2122

23+
2224
def handle_stop_signal(*args):
    """Log the shutdown request and trip the module-level stop event.

    Any positional arguments passed by the caller are accepted and ignored.
    """
    _LOGGER.info("Received stop signal. Shutting down...")
    stop_event.set()
2628

29+
2730
def parse_arguments():
    """Parse command-line arguments.

    The Azure region and subscription key default to the
    ``AZURE_SERVICE_REGION`` and ``AZURE_SUBSCRIPTION_KEY`` environment
    variables when the corresponding flags are not given.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``service_region``, ``subscription_key``, ``uri``, ``download_dir``,
        ``language``, ``update_languages``, ``profanity`` and ``debug``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--service-region",
        default=os.getenv("AZURE_SERVICE_REGION"),
        help="Microsoft Azure region (e.g., westus2)",
    )
    parser.add_argument(
        "--subscription-key",
        default=os.getenv("AZURE_SUBSCRIPTION_KEY"),
        help="Microsoft Azure subscription key",
    )
    parser.add_argument(
        "--uri", default="tcp://0.0.0.0:10300", help="unix:// or tcp://"
    )
    parser.add_argument(
        "--download-dir",
        default="/tmp/",
        help="Directory to download languages.json into (default: /tmp/)",
    )
    parser.add_argument(
        "--language", default="en-US", help="Default language to set for transcription"
    )
    parser.add_argument(
        "--update-languages",
        action="store_true",
        help="Download latest languages.json during startup",
    )
    parser.add_argument(
        "--profanity",
        default="masked",
        # These values must match what the speech configuration model and
        # MicrosoftSTT.set_profanity accept ("off" selects the raw,
        # unfiltered output). The previous "raw" choice passed argparse but
        # was rejected later, and "off" could not be selected at all.
        choices=["off", "masked", "removed"],
        help="Profanity setting for speech recognition",
    )
    parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
    return parser.parse_args()
3867

68+
3969
def validate_args(args):
    """Check that the parsed arguments contain usable Azure credentials.

    Raises:
        ValueError: If the service region or the subscription key is missing.

    A subscription key with an unexpected format only produces a warning;
    initialization is still attempted.
    """
    if not (args.service_region and args.subscription_key):
        raise ValueError(
            "Both --service-region and --subscription-key must be provided either as command-line arguments or environment variables."
        )

    # Lenient format check: at least 40 letters, digits, dashes or underscores.
    key_match = re.match(r"^[A-Za-z0-9\-_]{40,}$", args.subscription_key)
    if key_match is None:
        _LOGGER.warning(
            "The subscription key does not match the expected format but will attempt to initialize."
        )
80+
4681

4782
async def main() -> None:
4883
"""Start Wyoming Microsoft STT server."""
4984
args = parse_arguments()
5085
validate_args(args)
5186

87+
speech_config = SpeechConfig(
88+
subscription_key=args.subscription_key,
89+
service_region=args.service_region,
90+
profanity=args.profanity,
91+
language=args.language,
92+
)
93+
5294
# Set up logging
5395
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
5496
_LOGGER.debug("Arguments parsed successfully.")
@@ -98,7 +140,7 @@ async def main() -> None:
98140
# Load Microsoft STT model
99141
try:
100142
_LOGGER.debug("Loading Microsoft STT")
101-
stt_model = MicrosoftSTT(args)
143+
stt_model = MicrosoftSTT(speech_config)
102144
_LOGGER.info("Microsoft STT model loaded successfully.")
103145
except Exception as e:
104146
_LOGGER.error(f"Failed to load Microsoft STT model: {e}")
@@ -121,6 +163,7 @@ async def main() -> None:
121163
except Exception as e:
122164
_LOGGER.error(f"An error occurred while running the server: {e}")
123165

166+
124167
if __name__ == "__main__":
125168
# Set up signal handling for graceful shutdown
126169
signal.signal(signal.SIGTERM, handle_stop_signal)

wyoming_microsoft_stt/microsoft_stt.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
import azure.cognitiveservices.speech as speechsdk # noqa: D100
22
import logging
3+
from . import SpeechConfig
34

45
_LOGGER = logging.getLogger(__name__)
56

67

78
class MicrosoftSTT:
89
"""Class to handle Microsoft STT."""
910

10-
def __init__(self, speechconfig: SpeechConfig) -> None:
    """Build the Azure speech configuration from the given settings.

    Stores the settings on ``self.args``, creates ``self.speech_config``
    from the subscription key and region, then applies the configured
    profanity level. Any failure while creating the SDK configuration is
    logged and re-raised.
    """
    self.args = speechconfig

    try:
        settings = self.args
        # Create the SDK configuration from the key and region.
        self.speech_config = speechsdk.SpeechConfig(
            subscription=settings.subscription_key,
            region=settings.service_region,
        )
        _LOGGER.info("Microsoft SpeechConfig initialized successfully.")
    except Exception as err:
        _LOGGER.error(f"Failed to initialize Microsoft SpeechConfig: {err}")
        raise

    self.set_profanity(self.args.profanity)
26+
2327
def transcribe(self, filename: str, language=None):
2428
"""Transcribe a file."""
2529
# Use the default language from args if no language is provided
@@ -48,10 +52,29 @@ def transcribe(self, filename: str, language=None):
4852
return ""
4953
elif result.reason == speechsdk.ResultReason.Canceled:
5054
cancellation_details = result.cancellation_details
51-
_LOGGER.warning(f"Speech Recognition canceled: {cancellation_details.reason}")
55+
_LOGGER.warning(
56+
f"Speech Recognition canceled: {cancellation_details.reason}"
57+
)
5258
if cancellation_details.reason == speechsdk.CancellationReason.Error:
53-
_LOGGER.error(f"Error details: {cancellation_details.error_details}")
59+
_LOGGER.error(
60+
f"Error details: {cancellation_details.error_details}"
61+
)
5462
return ""
5563
except Exception as e:
5664
_LOGGER.error(f"Failed to transcribe audio file {filename}: {e}")
5765
return ""
66+
67+
def set_profanity(self, profanity: str):
    """Apply a profanity filter level to the speech configuration.

    ``"off"``, ``"masked"`` and ``"removed"`` select the SDK's ``Raw``,
    ``Masked`` and ``Removed`` options respectively. Any other value is
    logged as an error and the configuration is left untouched.
    """
    # Pairs of accepted level name and ProfanityOption member name.
    known_levels = (("off", "Raw"), ("masked", "Masked"), ("removed", "Removed"))

    for level_name, member_name in known_levels:
        if profanity == level_name:
            selected = getattr(speechsdk.ProfanityOption, member_name)
            break
    else:
        _LOGGER.error(f"Invalid profanity level: {profanity}")
        return

    self.speech_config.set_profanity(selected)
    _LOGGER.debug(f"Profanity filter set to {profanity}")

0 commit comments

Comments
 (0)