9
9
import json
10
10
import logging
11
11
import os
12
+ import sys
12
13
from concurrent .futures import ThreadPoolExecutor
13
14
from typing import List , Optional
14
15
22
23
ConversationError ,
23
24
)
24
25
from speechmatics_flow .models import (
25
- ClientMessageType ,
26
- ServerMessageType ,
27
26
AudioSettings ,
27
+ ClientMessageType ,
28
+ ConnectionSettings ,
28
29
ConversationConfig ,
29
30
Interaction ,
30
- ConnectionSettings ,
31
+ PlaybackSettings ,
32
+ ServerMessageType ,
31
33
)
32
34
from speechmatics_flow .tool_function_param import ToolFunctionParam
33
35
from speechmatics_flow .utils import read_in_chunks , json_utf8
@@ -63,6 +65,7 @@ def __init__(
63
65
self .websocket = None
64
66
self .conversation_config = None
65
67
self .audio_settings = None
68
+ self .playback_settings = None
66
69
self .tools = None
67
70
68
71
self .event_handlers = {x : [] for x in ServerMessageType }
@@ -73,13 +76,15 @@ def __init__(
73
76
self .session_running = False
74
77
self .conversation_ended_wait_timeout = 5
75
78
self ._session_needs_closing = False
76
- self ._audio_buffer = None
79
+ self ._audio_buffer = bytearray ()
80
+ self ._audio_buffer_lock = asyncio .Lock ()
77
81
self ._executor = ThreadPoolExecutor ()
78
82
79
83
# The following asyncio fields are fully instantiated in
80
84
# _init_synchronization_primitives
81
85
self ._conversation_started = asyncio .Event
82
86
self ._conversation_ended = asyncio .Event
87
+ self ._response_started = asyncio .Event
83
88
# Semaphore used to ensure that we don't send too much audio data to
84
89
# the server too quickly and burst any buffers downstream.
85
90
self ._buffer_semaphore = asyncio .BoundedSemaphore
@@ -91,24 +96,34 @@ async def _init_synchronization_primitives(self):
91
96
"""
92
97
self ._conversation_started = asyncio .Event ()
93
98
self ._conversation_ended = asyncio .Event ()
99
+ self ._response_started = asyncio .Event ()
94
100
self ._buffer_semaphore = asyncio .BoundedSemaphore (
95
101
self .connection_settings .message_buffer_size
96
102
)
97
103
98
104
def _flag_conversation_started (self ):
99
105
"""
100
106
Handle a
101
- :py:attr:`models.ClientMessageType .ConversationStarted`
107
+ :py:attr:`models.ServerMessageType .ConversationStarted`
102
108
message from the server.
103
109
This updates an internal flag to mark the session started
104
110
as started meaning, AddAudio is now allowed.
105
111
"""
106
112
self ._conversation_started .set ()
107
113
114
+ def _flag_response_started (self ):
115
+ """
116
+ Handle a
117
+ :py:attr:`models.ServerMessageType.ResponseStarted`
118
+ message from the server.
119
+ This updates an internal flag to mark that the server started sending audio.
120
+ """
121
+ self ._response_started .set ()
122
+
108
123
def _flag_conversation_ended (self ):
109
124
"""
110
125
Handle a
111
- :py:attr:`models.ClientMessageType .ConversationEnded`
126
+ :py:attr:`models.ServerMessageType .ConversationEnded`
112
127
message from the server.
113
128
This updates an internal flag to mark the session ended
114
129
and server connection is closed
@@ -158,7 +173,7 @@ def _audio_received(self):
158
173
msg = {
159
174
"message" : ClientMessageType .AudioReceived ,
160
175
"seq_no" : self .server_seq_no ,
161
- "buffering" : 0.01 , # 10ms
176
+ "buffering" : self . playback_settings . buffering / 1000 ,
162
177
}
163
178
self ._call_middleware (ClientMessageType .AudioReceived , msg , False )
164
179
LOGGER .debug (msg )
@@ -169,9 +184,12 @@ async def _wait_for_conversation_ended(self):
169
184
Waits for :py:attr:`models.ClientMessageType.ConversationEnded`
170
185
message from the server.
171
186
"""
172
- await asyncio .wait_for (
173
- self ._conversation_ended .wait (), self .conversation_ended_wait_timeout
174
- )
187
+ try :
188
+ await asyncio .wait_for (
189
+ self ._conversation_ended .wait (), self .conversation_ended_wait_timeout
190
+ )
191
+ except asyncio .TimeoutError :
192
+ LOGGER .warning ("Timeout waiting for ConversationEnded message." )
175
193
176
194
async def _consumer (self , message , from_cli = False ):
177
195
"""
@@ -192,7 +210,8 @@ async def _consumer(self, message, from_cli=False):
192
210
await self .websocket .send (self ._audio_received ())
193
211
# add an audio message to local buffer only when running from cli
194
212
if from_cli :
195
- await self ._audio_buffer .put (message )
213
+ async with self ._audio_buffer_lock :
214
+ self ._audio_buffer .extend (message )
196
215
# Implicit name for all inbound binary messages.
197
216
# We must manually set it for event handler subscribed
198
217
# to `ServerMessageType.AddAudio` messages to work.
@@ -226,6 +245,13 @@ async def _consumer(self, message, from_cli=False):
226
245
227
246
if message_type == ServerMessageType .ConversationStarted :
228
247
self ._flag_conversation_started ()
248
+ if message_type == ServerMessageType .ResponseStarted :
249
+ self ._flag_response_started ()
250
+ if message_type in [
251
+ ServerMessageType .ResponseCompleted ,
252
+ ServerMessageType .ResponseInterrupted ,
253
+ ]:
254
+ self ._response_started .clear ()
229
255
elif message_type == ServerMessageType .AudioAdded :
230
256
self ._buffer_semaphore .release ()
231
257
elif message_type == ServerMessageType .ConversationEnded :
@@ -313,20 +339,31 @@ async def _producer_handler(self, interactions: List[Interaction]):
313
339
Controls the producer loop for sending messages to the server.
314
340
"""
315
341
await self ._conversation_started .wait ()
316
- if interactions [0 ].stream .name == "<stdin>" :
342
+ # Stream audio from microphone when running from the terminal and input is not piped
343
+ if (
344
+ sys .stdin .isatty ()
345
+ and hasattr (interactions [0 ].stream , "name" )
346
+ and interactions [0 ].stream .name == "<stdin>"
347
+ ):
317
348
return await self ._read_from_microphone ()
318
349
319
350
for interaction in interactions :
320
- async for message in self ._stream_producer (
321
- interaction .stream , self .audio_settings .chunk_size
322
- ):
323
- try :
324
- await self .websocket .send (message )
325
- except Exception as e :
326
- LOGGER .error (f"error sending message: { e } " )
327
- return
328
- if interaction .callback :
329
- interaction .callback (self )
351
+ try :
352
+ async for message in self ._stream_producer (
353
+ interaction .stream , self .audio_settings .chunk_size
354
+ ):
355
+ try :
356
+ await self .websocket .send (message )
357
+ except Exception as e :
358
+ LOGGER .error (f"Error sending message: { e } " )
359
+ return
360
+
361
+ if interaction .callback :
362
+ LOGGER .debug ("Executing callback for interaction." )
363
+ interaction .callback (self )
364
+
365
+ except Exception as e :
366
+ LOGGER .error (f"Error processing interaction: { e } " )
330
367
331
368
await self .websocket .send (self ._end_of_audio ())
332
369
await self ._wait_for_conversation_ended ()
@@ -339,26 +376,38 @@ async def _playback_handler(self):
339
376
stream = _pyaudio .open (
340
377
format = pyaudio .paInt16 ,
341
378
channels = 1 ,
342
- rate = self .audio_settings .sample_rate ,
343
- frames_per_buffer = 128 ,
379
+ rate = self .playback_settings .sample_rate ,
380
+ frames_per_buffer = self . playback_settings . chunk_size ,
344
381
output = True ,
345
382
)
383
+ chunk_size = self .playback_settings .chunk_size
384
+
346
385
try :
347
- while True :
348
- if self ._session_needs_closing or self ._conversation_ended .is_set ():
349
- break
386
+ while not self ._session_needs_closing or self ._conversation_ended .is_set ():
387
+ # Wait for the server to start sending audio
388
+ await self ._response_started .wait ()
389
+
390
+ # Ensure enough data is added to the buffer before starting playback
391
+ await asyncio .sleep (self .playback_settings .buffering / 1000 )
392
+
393
+ # Start playback
350
394
try :
351
- audio_message = await self ._audio_buffer .get ()
352
- stream .write (audio_message )
353
- self ._audio_buffer .task_done ()
354
- # read from buffer at a constant rate
355
- await asyncio .sleep (0.005 )
395
+ while self ._audio_buffer :
396
+ if len (self ._audio_buffer ) >= chunk_size :
397
+ async with self ._audio_buffer_lock :
398
+ audio_chunk = bytes (self ._audio_buffer [:chunk_size ])
399
+ self ._audio_buffer = self ._audio_buffer [chunk_size :]
400
+ stream .write (audio_chunk )
401
+ await asyncio .sleep (0.005 )
356
402
except Exception as e :
357
- LOGGER .error (f"Error during audio playback: { e } " )
403
+ LOGGER .error (f"Error during audio playback: { e } " , exc_info = True )
358
404
raise e
405
+
406
+ except asyncio .CancelledError :
407
+ LOGGER .info ("Playback handler cancelled." )
359
408
finally :
360
- stream .close ()
361
409
stream .stop_stream ()
410
+ stream .close ()
362
411
_pyaudio .terminate ()
363
412
364
413
def _call_middleware (self , event_name , * args ):
@@ -482,7 +531,6 @@ async def _communicate(self, interactions: List[Interaction], from_cli=False):
482
531
483
532
# Run the playback task that plays audio messages to the user when started from cli
484
533
if from_cli :
485
- self ._audio_buffer = asyncio .Queue ()
486
534
tasks .append (asyncio .create_task (self ._playback_handler ()))
487
535
488
536
(done , pending ) = await asyncio .wait (
@@ -509,6 +557,7 @@ async def run(
509
557
conversation_config : ConversationConfig = None ,
510
558
from_cli : bool = False ,
511
559
tools : Optional [List [ToolFunctionParam ]] = None ,
560
+ playback_settings : PlaybackSettings = PlaybackSettings (),
512
561
):
513
562
"""
514
563
Begin a new recognition session.
@@ -528,13 +577,18 @@ async def run(
528
577
:param tools: Optional list of tool functions.
529
578
:type tools: List[ToolFunctionParam]
530
579
580
+ :param playback_settings: Configuration for the playback stream.
581
+ :type playback_settings: models.PlaybackSettings
582
+
531
583
:raises Exception: Can raise any exception returned by the
532
584
consumer/producer tasks.
585
+
533
586
"""
534
587
self .client_seq_no = 0
535
588
self .server_seq_no = 0
536
589
self .conversation_config = conversation_config
537
590
self .audio_settings = audio_settings
591
+ self .playback_settings = playback_settings
538
592
self .tools = tools
539
593
540
594
await self ._init_synchronization_primitives ()
0 commit comments