diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 1106a416..9001a37e 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -28,7 +28,6 @@ import tempfile import threading import time from configparser import ConfigParser -from datetime import datetime from hashlib import sha256, md5 from importlib import import_module from pathlib import Path @@ -842,39 +841,7 @@ class Client(Methods, BaseClient): final_file_path = "" try: - data, file_name, done, progress, progress_args, path = packet - - directory, file_name = os.path.split(file_name) - directory = directory or "downloads" - - media_type_str = Client.MEDIA_TYPE_ID[data.media_type] - - file_name = file_name or data.file_name - - if not file_name: - guessed_extension = self.guess_extension(data.mime_type) - - if data.media_type in (0, 1, 2, 14): - extension = ".jpg" - elif data.media_type == 3: - extension = guessed_extension or ".ogg" - elif data.media_type in (4, 10, 13): - extension = guessed_extension or ".mp4" - elif data.media_type == 5: - extension = guessed_extension or ".zip" - elif data.media_type == 8: - extension = guessed_extension or ".webp" - elif data.media_type == 9: - extension = guessed_extension or ".mp3" - else: - continue - - file_name = "{}_{}_{}{}".format( - media_type_str, - datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"), - self.rnd_id(), - extension - ) + data, directory, file_name, done, progress, progress_args, path = packet temp_file_path = self.get_file( media_type=data.media_type, diff --git a/pyrogram/client/ext/base_client.py b/pyrogram/client/ext/base_client.py index c8d1beab..def290e6 100644 --- a/pyrogram/client/ext/base_client.py +++ b/pyrogram/client/ext/base_client.py @@ -19,6 +19,8 @@ import os import platform import re +import sys +from pathlib import Path from queue import Queue from threading import Lock @@ -45,6 +47,8 @@ class BaseClient: LANG_CODE = "en" + PARENT_DIR = Path(sys.argv[0]).parent + INVITE_LINK_RE = re.compile(r"^(?:https?://)?(?:www\.)?(?:t(?:elegram)?\.(?:org|me|dog)/joinchat/)([\w-]+)$") BOT_TOKEN_RE = re.compile(r"^\d+:[\w-]+$") DIALOGS_AT_ONCE = 100 @@ -52,8 +56,8 @@ class BaseClient: DOWNLOAD_WORKERS = 1 OFFLINE_SLEEP = 900 WORKERS = 4 - WORKDIR = "." - CONFIG_FILE = "./config.ini" + WORKDIR = PARENT_DIR + CONFIG_FILE = PARENT_DIR / "config.ini" MEDIA_TYPE_ID = { 0: "photo_thumbnail", diff --git a/pyrogram/client/methods/messages/download_media.py b/pyrogram/client/methods/messages/download_media.py index bd8de2d6..143349f7 100644 --- a/pyrogram/client/methods/messages/download_media.py +++ b/pyrogram/client/methods/messages/download_media.py @@ -17,7 +17,10 @@ # along with Pyrogram. If not, see . import binascii +import os import struct +import time +from datetime import datetime from threading import Event from typing import Union @@ -25,12 +28,14 @@ import pyrogram from pyrogram.client.ext import BaseClient, FileData, utils from pyrogram.errors import FileIdInvalid +DEFAULT_DOWNLOAD_DIR = "downloads/" + class DownloadMedia(BaseClient): def download_media( self, message: Union["pyrogram.Message", str], - file_name: str = "", + file_name: str = DEFAULT_DOWNLOAD_DIR, block: bool = True, progress: callable = None, progress_args: tuple = () @@ -169,7 +174,40 @@ class DownloadMedia(BaseClient): done = Event() path = [None] - self.download_queue.put((data, file_name, done, progress, progress_args, path)) + directory, file_name = os.path.split(file_name) + file_name = file_name or data.file_name or "" + + if not os.path.isabs(file_name): + directory = self.PARENT_DIR / (directory or DEFAULT_DOWNLOAD_DIR) + + media_type_str = self.MEDIA_TYPE_ID[data.media_type] + + if not file_name: + guessed_extension = self.guess_extension(data.mime_type) + + if data.media_type in (0, 1, 2, 14): + extension = ".jpg" + elif data.media_type == 3: + extension = guessed_extension or ".ogg" + elif data.media_type in (4, 10, 13): + extension = guessed_extension or ".mp4" + elif data.media_type == 5: + extension = guessed_extension or ".zip" + elif data.media_type == 8: + extension = guessed_extension or ".webp" + elif data.media_type == 9: + extension = guessed_extension or ".mp3" + else: + extension = ".unknown" + + file_name = "{}_{}_{}{}".format( + media_type_str, + datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"), + self.rnd_id(), + extension + ) + + self.download_queue.put((data, directory, file_name, done, progress, progress_args, path)) if block: done.wait()