33import time
44from io import BytesIO
55from pathlib import Path
6- from typing import BinaryIO , Dict , List , Optional , Union , Sequence , Any
7-
6+ from typing import (
7+ BinaryIO ,
8+ Dict ,
9+ List ,
10+ Optional ,
11+ Union ,
12+ Sequence ,
13+ Any ,
14+ Callable ,
15+ Awaitable ,
16+ )
817import openai
918from aiogram .types import Message as AiogramMessage
1019from loguru import logger
@@ -190,6 +199,9 @@ async def track_costs(
190199
191200 # Estimate cost
192201 if operation == "transcription" :
202+ assert (
203+ audio_duration_seconds is not None
204+ ), "audio_duration_seconds is required for transcription cost estimation"
193205 estimated_cost = estimate_whisper_cost (
194206 file_size_mb or (input_length / (1024 * 1024 )),
195207 model ,
@@ -239,22 +251,29 @@ async def track_costs(
239251 message_id = message_id ,
240252 )
241253
242- async def write_accumulated_costs (self , username : str , operation_type : Optional [str ] = None ) -> None :
254+ async def write_accumulated_costs (
255+ self , username : str , operation_type : Optional [str ] = None
256+ ) -> None :
243257 """
244258 Write accumulated costs to the database and clear the accumulator.
245259
246260 Args:
247261 username: Username for cost tracking
248262 operation_type: Optional operation type to filter costs (e.g., 'transcription', 'formatting')
249263 """
250- if username not in self ._accumulated_costs or not self ._accumulated_costs [username ]:
264+ if (
265+ username not in self ._accumulated_costs
266+ or not self ._accumulated_costs [username ]
267+ ):
251268 return
252269
253270 costs_to_write = self ._accumulated_costs [username ]
254271
255272 # Filter by operation type if specified
256273 if operation_type :
257- costs_to_write = [cost for cost in costs_to_write if cost ["operation" ] == operation_type ]
274+ costs_to_write = [
275+ cost for cost in costs_to_write if cost ["operation" ] == operation_type
276+ ]
258277
259278 if not costs_to_write :
260279 return
@@ -275,9 +294,15 @@ async def write_accumulated_costs(self, username: str, operation_type: Optional[
275294 }
276295
277296 operation_totals [operation ]["cost" ] += cost_data ["cost" ]
278- operation_totals [operation ]["input_length" ] += cost_data ["usage" ].get ("input_length" , 0 )
279- operation_totals [operation ]["output_length" ] += cost_data ["usage" ].get ("output_length" , 0 )
280- operation_totals [operation ]["processing_time" ] += cost_data ["usage" ].get ("processing_time" , 0.0 )
297+ operation_totals [operation ]["input_length" ] += cost_data ["usage" ].get (
298+ "input_length" , 0
299+ )
300+ operation_totals [operation ]["output_length" ] += cost_data ["usage" ].get (
301+ "output_length" , 0
302+ )
303+ operation_totals [operation ]["processing_time" ] += cost_data ["usage" ].get (
304+ "processing_time" , 0.0
305+ )
281306 operation_totals [operation ]["count" ] += 1
282307
283308 # Write aggregated costs to database
@@ -301,7 +326,8 @@ async def write_accumulated_costs(self, username: str, operation_type: Optional[
301326 # Clear the accumulated costs for this user and operation
302327 if operation_type :
303328 self ._accumulated_costs [username ] = [
304- cost for cost in self ._accumulated_costs [username ]
329+ cost
330+ for cost in self ._accumulated_costs [username ]
305331 if cost ["operation" ] != operation_type
306332 ]
307333 else :
@@ -427,7 +453,10 @@ async def get_total_cost(
427453
428454 # Main method
429455 async def process_message (
430- self , message : AiogramMessage , whisper_model : Optional [str ] = None
456+ self ,
457+ message : AiogramMessage ,
458+ whisper_model : Optional [str ] = None ,
459+ status_callback : Optional [Callable [[str ], Awaitable [None ]]] = None ,
431460 ) -> str :
432461 """
433462 [x] process_message
@@ -465,18 +494,25 @@ async def process_message(
465494 logger .info (f"User { username } submitted a file for processing: { file_info } " )
466495 await self .log_event ("file_submission" , username , file_info )
467496
468- media_file = await self .download_attachment (message )
497+ media_file = await self .download_attachment (
498+ message , status_callback = status_callback
499+ )
469500
470- parts = await self .prepare_parts (media_file )
501+ parts = await self .prepare_parts (media_file , status_callback = status_callback )
471502
472503 # Store message_id per user for async safety
473504 self ._user_message_ids [username ] = message_id
474505
475506 chunks = await self .process_parts (
476- parts , username = username , whisper_model = whisper_model
507+ parts ,
508+ username = username ,
509+ whisper_model = whisper_model ,
510+ status_callback = status_callback ,
477511 )
478512
479- chunks = await self .format_chunks_with_llm (chunks , username = username )
513+ chunks = await self .format_chunks_with_llm (
514+ chunks , username = username , status_callback = status_callback
515+ )
480516
481517 result = merge_all_chunks (chunks )
482518
@@ -514,8 +550,12 @@ def _get_file_type(self, message: AiogramMessage) -> str:
514550 return "unknown"
515551
516552 async def download_attachment (
517- self , message : AiogramMessage
553+ self ,
554+ message : AiogramMessage ,
555+ status_callback : Optional [Callable [[str ], Awaitable [None ]]] = None ,
518556 ) -> Union [BinaryIO , Path ]:
557+ if status_callback is not None :
558+ await status_callback ("Downloading media file..." )
519559 return await download_file (
520560 message = message ,
521561 target_dir = self .config .downloads_dir ,
@@ -528,8 +568,16 @@ async def download_attachment(
528568 )
529569
530570 # region prepare_parts
531- async def prepare_parts (self , media_file : Union [BinaryIO , Path ]) -> Sequence [Audio ]:
571+ async def prepare_parts (
572+ self ,
573+ media_file : Union [BinaryIO , Path ],
574+ status_callback : Optional [Callable [[str ], Awaitable [None ]]] = None ,
575+ ) -> Sequence [Audio ]:
532576 if isinstance (media_file , Path ):
577+ if status_callback is not None :
578+ await status_callback (
579+ "Preparing audio - converting to mp3 and cutting into manageable parts..."
580+ )
533581 # process file on disk - with
534582 parts = await self .process_file_on_disk (media_file )
535583 if self .config .cleanup_downloads :
@@ -580,6 +628,7 @@ async def process_parts(
580628 parts : Sequence [Audio ],
581629 username : Optional [str ] = None ,
582630 whisper_model : Optional [str ] = None ,
631+ status_callback : Optional [Callable [[str ], Awaitable [None ]]] = None ,
583632 ) -> List [str ]:
584633 """
585634 Process multiple audio parts.
@@ -596,13 +645,18 @@ async def process_parts(
596645 logger .info (f"Processing { len (parts )} audio parts" )
597646 start_time = time .time ()
598647
648+ if status_callback is not None :
649+ await status_callback (f"Parsing audio - { len (parts )} parts..." )
650+
599651 # this has to be done one by one, do NOT parallelize.
600652 for i , part in enumerate (parts ):
601653 logger .info (f"Processing part { i + 1 } /{ len (parts )} " )
602654 # todo: make sure memory is freed after each part.
603655 chunks += await self .process_part (
604656 part , username = username , whisper_model = whisper_model
605657 )
658+ if status_callback is not None :
659+ await status_callback (f"Part { i + 1 } /{ len (parts )} done" )
606660 if isinstance (part , Path ):
607661 if self .config .cleanup_downloads :
608662 part .unlink (missing_ok = True )
@@ -666,9 +720,14 @@ async def process_part(
666720 )
667721
668722 audio_chunks = split_audio (
669- audio , period = period , buffer = self .config .overlap_duration , return_as_files = False
723+ audio ,
724+ period = period ,
725+ buffer = self .config .overlap_duration ,
726+ return_as_files = False ,
727+ )
728+ assert all (
729+ isinstance (audio_chunk , AudioSegment ) for audio_chunk in audio_chunks
670730 )
671- assert all (isinstance (audio_chunk , AudioSegment ) for audio_chunk in audio_chunks )
672731
673732 logger .info (
674733 f"Split audio into { len (audio_chunks )} chunks with period { period / 1000 :.2f} seconds"
@@ -770,7 +829,7 @@ async def parse_audio_chunk(
770829 processing_time = processing_time ,
771830 file_name = audio_file .name ,
772831 audio_duration_seconds = audio_duration_sec ,
773- write_to_db = False
832+ write_to_db = False ,
774833 )
775834
776835 return transcription
@@ -802,11 +861,11 @@ async def parse_audio_chunks(
802861 # Create tasks for all chunks
803862 tasks = [
804863 self .parse_audio_chunk (
805- chunk ,
806- model_name = model_name ,
807- language = language ,
864+ chunk ,
865+ model_name = model_name ,
866+ language = language ,
808867 username = username ,
809- chunk_index = i
868+ chunk_index = i ,
810869 )
811870 for i , chunk in enumerate (audio_chunks )
812871 ]
@@ -830,7 +889,10 @@ async def parse_audio_chunks(
830889 # region format_chunks_with_llm
831890
832891 async def format_chunks_with_llm (
833- self , chunks : List [str ], username : str
892+ self ,
893+ chunks : List [str ],
894+ username : str ,
895+ status_callback : Optional [Callable [[str ], Awaitable [None ]]] = None ,
834896 ) -> List [str ]:
835897 """
836898 Format multiple text chunks using LLM.
@@ -852,6 +914,9 @@ async def format_chunks_with_llm(
852914 )
853915 return chunks
854916
917+ if status_callback is not None :
918+ await status_callback ("Formatting punctuation and capitalization..." )
919+
855920 logger .info (
856921 f"Formatting { len (chunks )} text chunks with LLM (total: { total_length } chars)"
857922 )
@@ -914,14 +979,16 @@ async def format_chunk(self, chunk: str, username: str) -> str:
914979 input_data = chunk ,
915980 output_data = formatted_text ,
916981 processing_time = processing_time ,
917- write_to_db = False
982+ write_to_db = False ,
918983 )
919984
920985 return formatted_text
921986
922987 # endregion format_chunks_with_llm
923988
924- async def create_summary (self , transcript : str , username : str , message_id : Optional [int ] = None ) -> str :
989+ async def create_summary (
990+ self , transcript : str , username : str , message_id : Optional [int ] = None
991+ ) -> str :
925992 """
926993 Create a summary of the transcript using the configured model.
927994
@@ -975,7 +1042,7 @@ async def create_summary(self, transcript: str, username: str, message_id: Optio
9751042 input_data = transcript ,
9761043 output_data = summary ,
9771044 processing_time = processing_time ,
978- write_to_db = True
1045+ write_to_db = True ,
9791046 )
9801047
9811048 # Write accumulated summary costs to the database
@@ -984,7 +1051,9 @@ async def create_summary(self, transcript: str, username: str, message_id: Optio
9841051
9851052 return summary
9861053
987- async def chat_about_transcript (self , full_prompt : str , username : str , model : str = None ) -> str :
1054+ async def chat_about_transcript (
1055+ self , full_prompt : str , username : str , model : str = None
1056+ ) -> str :
9881057 """
9891058 Chat about the transcript using the configured chat model.
9901059
@@ -1043,7 +1112,7 @@ async def chat_about_transcript(self, full_prompt: str, username: str, model: st
10431112 input_data = full_prompt ,
10441113 output_data = response ,
10451114 processing_time = processing_time ,
1046- write_to_db = False
1115+ write_to_db = False ,
10471116 )
10481117
10491118 # Write accumulated chat costs to the database
0 commit comments