@@ -62,6 +62,11 @@ def enqueue_stream(self) -> AsyncGenerator[NP_AUDIO, None]:
62
62
"Method 'enqueue_stream' must be implemented by subclass"
63
63
)
64
64
65
+ def get_sample_rate (self ) -> int :
66
+ raise NotImplementedError (
67
+ "Method 'get_sample_rate' must be implemented by subclass"
68
+ )
69
+
65
70
def set_current_request (self , request : Request ):
66
71
assert self .current_request is None , "current_request has been set"
67
72
assert isinstance (
@@ -93,7 +98,9 @@ def get_encoder(self) -> StreamEncoder:
93
98
else :
94
99
raise ValueError (f"Unsupported audio format: { format } " )
95
100
101
+ encoder .set_header (sample_rate = self .get_sample_rate ())
96
102
encoder .open (bitrate = bitrate , acodec = acodec )
103
+ encoder .write_header_data ()
97
104
98
105
return encoder
99
106
@@ -104,9 +111,6 @@ async def enqueue_to_stream(self) -> AsyncGenerator[bytes, None]:
104
111
105
112
chunk_data = bytes ()
106
113
async for sample_rate , audio_data in self .enqueue_stream ():
107
- encoder .set_header (
108
- sample_rate = sample_rate , sample_width = audio_data .dtype .itemsize
109
- )
110
114
audio_bytes = covert_to_s16le (audio_data = audio_data )
111
115
112
116
logger .debug (f"write audio_bytes len: { len (audio_bytes )} " )
@@ -153,9 +157,6 @@ async def enqueue_to_stream_join(self) -> AsyncGenerator[bytes, None]:
153
157
encoder = self .get_encoder ()
154
158
chunk_data = bytes ()
155
159
async for sample_rate , audio_data in self .enqueue_stream ():
156
- encoder .set_header (
157
- sample_rate = sample_rate , sample_width = audio_data .dtype .itemsize
158
- )
159
160
audio_bytes = covert_to_s16le (audio_data = audio_data )
160
161
encoder .write (audio_bytes )
161
162
@@ -166,26 +167,30 @@ async def enqueue_to_stream_join(self) -> AsyncGenerator[bytes, None]:
166
167
167
168
encoder .terminate ()
168
169
170
+ async def _enqueue_to_bytes (self ) -> bytes :
171
+ """
172
+ 为了测试拆分的函数
173
+ 这个函数不依赖 current_request 状态
174
+ """
175
+ encoder = self .get_encoder ()
176
+ buffer = bytes ()
177
+ try :
178
+ sample_rate , audio_data = await self .enqueue ()
179
+ audio_bytes = covert_to_s16le (audio_data = audio_data )
180
+ encoder .write (audio_bytes )
181
+ encoder .close ()
182
+ buffer = encoder .read_all ()
183
+ finally :
184
+ encoder .terminate ()
185
+ return buffer
186
+
169
187
async def enqueue_to_bytes (self ) -> bytes :
170
188
if self .current_request is None :
171
189
raise ValueError ("current_request is not set" )
172
190
173
- encoder = self .get_encoder ()
174
-
175
191
# NOTE: 这里的逻辑类似 goto
176
192
async with cancel_on_disconnect (self .current_request ):
177
- try :
178
- sample_rate , audio_data = await self .enqueue ()
179
- audio_bytes = covert_to_s16le (audio_data = audio_data )
180
- encoder .set_header (
181
- sample_rate = sample_rate , sample_width = audio_data .dtype .itemsize
182
- )
183
- encoder .write (audio_bytes )
184
- encoder .close ()
185
- buffer = encoder .read_all ()
186
- finally :
187
- encoder .terminate ()
188
- return buffer
193
+ return self ._enqueue_to_bytes ()
189
194
190
195
logger .debug (f"disconnected" )
191
196
self .interrupt ()
0 commit comments