Skip to content

Commit bd5f5b7

Browse files
committed
allow to download multiple media simultaneously
1 parent ac6c2e3 commit bd5f5b7

File tree

4 files changed

+30
-14
lines changed

4 files changed

+30
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "t-export"
3-
version = "0.1.3b3"
3+
version = "0.1.3b4"
44
description = "Telegram chats export tool."
55
authors = ["RuslanUC <dev_ruslan_uc@protonmail.com>"]
66
readme = "README.md"

texport/export_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from datetime import datetime
33
from pathlib import Path
44
from typing import Union
@@ -32,10 +32,15 @@ class ExportConfig:
3232
to_date: datetime = datetime.now()
3333
print: bool = False
3434
preload: bool = False
35+
max_concurrent_downloads: int = 4
3536

3637
def excluded_media(self) -> set[MessageMediaType]:
3738
result = set()
3839
for media_type in EXPORT_MEDIA:
3940
if not getattr(self, f"export_{media_type}"):
4041
result.add(EXPORT_MEDIA[media_type])
4142
return result
43+
44+
def __post_init__(self):
45+
if self.max_concurrent_downloads <= 0:
46+
self.max_concurrent_downloads = 4

texport/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ async def _main(session_name: str, api_id: int, api_hash: str, config: ExportCon
5252
@click.option("--documents/--no-documents", default=True, help="Download documents or not.")
5353
@click.option("--quiet", "-q", is_flag=True, default=False, help="Do not print progress to console.")
5454
@click.option("--no-preload", is_flag=True, default=False, help="Do not preload all messages.")
55+
@click.option("--max-concurrent-downloads", "-d", type=click.INT, default=4,
56+
help="Number of concurrent media downloads.")
5557
def main(
5658
session_name: str, api_id: int, api_hash: str, chat_id: str, output: str, size_limit: int, from_date: str,
5759
to_date: str, photos: bool, videos: bool, voice: bool, video_notes: bool, stickers: bool, gifs: bool,
58-
documents: bool, quiet: bool, no_preload: bool,
60+
documents: bool, quiet: bool, no_preload: bool, max_concurrent_downloads: int,
5961
) -> None:
6062
home = Path.home()
6163
texport_dir = home / ".texport"
@@ -77,6 +79,7 @@ def main(
7779
export_files=documents,
7880
print=not quiet,
7981
preload=not no_preload,
82+
max_concurrent_downloads=max_concurrent_downloads,
8083
)
8184

8285
if session_name.endswith(".session"):

texport/media_downloader.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from asyncio import sleep
2+
from asyncio import sleep, Task, Semaphore, create_task
33
from os.path import relpath
44
from typing import Union, Optional
55

@@ -22,6 +22,8 @@ def __init__(self, client: Client, config: ExportConfig, media_dict: dict, progr
2222
self.progress = progress
2323

2424
self._running = False
25+
self._downloading: dict[Union[str, int], ...] = {}
26+
self._sem = Semaphore(self.config.max_concurrent_downloads)
2527

2628
def add(self, file_id: str, download_dir: str, out_id: Union[str, int]) -> None:
2729
if out_id in self.all_ids: return
@@ -31,28 +33,34 @@ def add(self, file_id: str, download_dir: str, out_id: Union[str, int]) -> None:
3133
self._status()
3234

3335
async def _download(self, file_id: str, download_dir: str, out_id: Union[str, int]) -> None:
34-
try:
35-
path = await self.client.download_media(file_id, file_name=download_dir)
36-
except RPCError:
37-
return
36+
async with self._sem:
37+
try:
38+
path = await self.client.download_media(file_id, file_name=download_dir)
39+
except RPCError:
40+
return
41+
finally:
42+
self._downloading.pop(out_id, None)
43+
self.ids.discard(out_id)
44+
3845
path = relpath(path, self.config.output_dir.absolute())
3946
self.output[out_id] = path
4047

4148
def _status(self, status: str=None) -> None:
4249
with self.progress.update():
4350
self.progress.media_status = status or self.progress.media_status
44-
self.progress.media_queue = len(self.queue)
51+
self.progress.media_queue = len(self.queue) + len(self._downloading)
4552

4653
async def _task(self) -> None:
54+
# use create_task and semaphore
55+
downloading: dict[Union[str, int], Task] = {}
4756
while self._running:
48-
if not self.queue:
57+
if not self.queue and not downloading:
4958
self._status("Idle...")
5059
await sleep(.1)
5160
continue
5261
self._status("Downloading...")
53-
await self._download(*self.queue[0])
54-
_, _, task_id = self.queue.pop(0)
55-
self.ids.discard(task_id)
62+
*args, task_id = self.queue.pop(0)
63+
self._downloading[task_id] = create_task(self._download(*args, task_id))
5664

5765
self._status("Stopped...")
5866

@@ -66,7 +74,7 @@ async def stop(self) -> None:
6674

6775
async def wait(self, messages: Optional[list[int]]=None) -> None:
6876
messages = set(messages) if messages is not None else None
69-
while self._running and self.queue:
77+
while self._running and (self.queue or self._downloading):
7078
if messages is not None and not messages.intersection(self.ids):
7179
break
7280
await sleep(.1)

0 commit comments

Comments
 (0)