1
1
import asyncio
2
- from asyncio import sleep
2
+ from asyncio import sleep , Task , Semaphore , create_task
3
3
from os .path import relpath
4
4
from typing import Union , Optional
5
5
@@ -22,6 +22,8 @@ def __init__(self, client: Client, config: ExportConfig, media_dict: dict, progr
22
22
self .progress = progress
23
23
24
24
self ._running = False
25
+ self ._downloading : dict [Union [str , int ], ...] = {}
26
+ self ._sem = Semaphore (self .config .max_concurrent_downloads )
25
27
26
28
def add (self , file_id : str , download_dir : str , out_id : Union [str , int ]) -> None :
27
29
if out_id in self .all_ids : return
@@ -31,28 +33,34 @@ def add(self, file_id: str, download_dir: str, out_id: Union[str, int]) -> None:
31
33
self ._status ()
32
34
33
35
async def _download (self , file_id : str , download_dir : str , out_id : Union [str , int ]) -> None :
34
- try :
35
- path = await self .client .download_media (file_id , file_name = download_dir )
36
- except RPCError :
37
- return
36
+ async with self ._sem :
37
+ try :
38
+ path = await self .client .download_media (file_id , file_name = download_dir )
39
+ except RPCError :
40
+ return
41
+ finally :
42
+ self ._downloading .pop (out_id , None )
43
+ self .ids .discard (out_id )
44
+
38
45
path = relpath (path , self .config .output_dir .absolute ())
39
46
self .output [out_id ] = path
40
47
41
48
def _status (self , status : str = None ) -> None :
42
49
with self .progress .update ():
43
50
self .progress .media_status = status or self .progress .media_status
44
- self .progress .media_queue = len (self .queue )
51
+ self .progress .media_queue = len (self .queue ) + len ( self . _downloading )
45
52
46
53
async def _task (self ) -> None :
54
+ # use create_task and semaphore
55
+ downloading : dict [Union [str , int ], Task ] = {}
47
56
while self ._running :
48
- if not self .queue :
57
+ if not self .queue and not downloading :
49
58
self ._status ("Idle..." )
50
59
await sleep (.1 )
51
60
continue
52
61
self ._status ("Downloading..." )
53
- await self ._download (* self .queue [0 ])
54
- _ , _ , task_id = self .queue .pop (0 )
55
- self .ids .discard (task_id )
62
+ * args , task_id = self .queue .pop (0 )
63
+ self ._downloading [task_id ] = create_task (self ._download (* args , task_id ))
56
64
57
65
self ._status ("Stopped..." )
58
66
@@ -66,7 +74,7 @@ async def stop(self) -> None:
66
74
67
75
async def wait (self , messages : Optional [list [int ]]= None ) -> None :
68
76
messages = set (messages ) if messages is not None else None
69
- while self ._running and self .queue :
77
+ while self ._running and ( self .queue or self . _downloading ) :
70
78
if messages is not None and not messages .intersection (self .ids ):
71
79
break
72
80
await sleep (.1 )
0 commit comments