diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 0d1cde19..468413b0 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -17,7 +17,6 @@ # along with Pyrogram. If not, see . import base64 -import binascii import json import logging import math @@ -25,7 +24,6 @@ import mimetypes import os import re import shutil -import struct import tempfile import threading import time @@ -49,7 +47,7 @@ from pyrogram.errors import ( PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty, PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded, PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned, - VolumeLocNotFound, UserMigrate, FileIdInvalid, ChannelPrivate, PhoneNumberOccupied, + VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied, PasswordRecoveryNa, PasswordEmpty ) from pyrogram.session import Auth, Session @@ -829,85 +827,59 @@ class Client(Methods, BaseClient): log.debug("{} started".format(name)) while True: - media = self.download_queue.get() + packet = self.download_queue.get() - if media is None: + if packet is None: break temp_file_path = "" final_file_path = "" try: - media, file_name, done, progress, progress_args, path = media - - file_id = media.file_id - size = media.file_size + data, file_name, done, progress, progress_args, path = packet + data = data # type: BaseClient.FileData directory, file_name = os.path.split(file_name) directory = directory or "downloads" - try: - decoded = utils.decode(file_id) - fmt = " 24 else " 24: - volume_id = unpacked[4] - secret = unpacked[5] - local_id = unpacked[6] + if not data.file_name: + guessed_extension = self.guess_extension(data.mime_type) - media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None) - - if media_type_str is None: - raise FileIdInvalid("Unknown media type: {}".format(unpacked[0])) - - file_name = file_name or getattr(media, "file_name", None) - - if not file_name: - guessed_extension = self.guess_extension(media.mime_type) - - if media_type in (0, 1, 2): + if data.media_type in (0, 1, 2, 14): extension = ".jpg" - elif media_type == 3: + elif data.media_type == 3: extension = guessed_extension or ".ogg" - elif media_type in (4, 10, 13): + elif data.media_type in (4, 10, 13): extension = guessed_extension or ".mp4" - elif media_type == 5: + elif data.media_type == 5: extension = guessed_extension or ".zip" - elif media_type == 8: + elif data.media_type == 8: extension = guessed_extension or ".webp" - elif media_type == 9: + elif data.media_type == 9: extension = guessed_extension or ".mp3" else: continue file_name = "{}_{}_{}{}".format( media_type_str, - datetime.fromtimestamp( - getattr(media, "date", None) or time.time() - ).strftime("%Y-%m-%d_%H-%M-%S"), + datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"), self.rnd_id(), extension ) temp_file_path = self.get_file( - dc_id=dc_id, - id=id, - access_hash=access_hash, - volume_id=volume_id, - local_id=local_id, - secret=secret, - size=size, + media_type=data.media_type, + dc_id=data.dc_id, + file_id=data.file_id, + access_hash=data.access_hash, + thumb_size=data.thumb_size, + peer_id=data.peer_id, + volume_id=data.volume_id, + local_id=data.local_id, + file_size=data.file_size, + is_big=data.is_big, progress=progress, progress_args=progress_args ) @@ -1549,16 +1521,21 @@ class Client(Methods, BaseClient): finally: session.stop() - def get_file(self, - dc_id: int, - id: int = None, - access_hash: int = None, - volume_id: int = None, - local_id: int = None, - secret: int = None, - size: int = None, - progress: callable = None, - progress_args: tuple = ()) -> str: + def get_file( + self, + media_type: int, + dc_id: int, + file_id: int, + access_hash: int, + thumb_size: str, + peer_id: int, + volume_id: int, + local_id: int, + file_size: int, + is_big: bool, + progress: callable, + progress_args: tuple = () + ) -> str: with self.media_sessions_lock: session = self.media_sessions.get(dc_id, None) @@ -1599,18 +1576,33 @@ class Client(Methods, BaseClient): self.media_sessions[dc_id] = session - if volume_id: # Photos are accessed by volume_id, local_id, secret - location = types.InputFileLocation( + if media_type == 1: + location = types.InputPeerPhotoFileLocation( + peer=self.resolve_peer(peer_id), volume_id=volume_id, local_id=local_id, - secret=secret, - file_reference=b"" + big=is_big or None ) - else: # Any other file can be more easily accessed by id and access_hash - location = types.InputDocumentFileLocation( - id=id, + elif media_type in (0, 2): + location = types.InputPhotoFileLocation( + id=file_id, access_hash=access_hash, - file_reference=b"" + file_reference=b"", + thumb_size=thumb_size + ) + elif media_type == 14: + location = types.InputDocumentFileLocation( + id=file_id, + access_hash=access_hash, + file_reference=b"", + thumb_size=thumb_size + ) + else: + location = types.InputDocumentFileLocation( + id=file_id, + access_hash=access_hash, + file_reference=b"", + thumb_size="" ) limit = 1024 * 1024 @@ -1641,7 +1633,14 @@ class Client(Methods, BaseClient): offset += limit if progress: - progress(self, min(offset, size) if size != 0 else offset, size, *progress_args) + progress( + self, + min(offset, file_size) + if file_size != 0 + else offset, + file_size, + *progress_args + ) r = session.send( functions.upload.GetFile( @@ -1723,7 +1722,14 @@ class Client(Methods, BaseClient): offset += limit if progress: - progress(self, min(offset, size) if size != 0 else offset, size, *progress_args) + progress( + self, + min(offset, file_size) + if file_size != 0 + else offset, + file_size, + *progress_args + ) if len(chunk) < limit: break diff --git a/pyrogram/client/ext/base_client.py b/pyrogram/client/ext/base_client.py index a3816bdf..3397c8d3 100644 --- a/pyrogram/client/ext/base_client.py +++ b/pyrogram/client/ext/base_client.py @@ -19,6 +19,7 @@ import os import platform import re +from collections import namedtuple from queue import Queue from threading import Lock @@ -56,7 +57,7 @@ class BaseClient: CONFIG_FILE = "./config.ini" MEDIA_TYPE_ID = { - 0: "thumbnail", + 0: "photo_thumbnail", 1: "chat_photo", 2: "photo", 3: "voice", @@ -65,7 +66,8 @@ class BaseClient: 8: "sticker", 9: "audio", 10: "animation", - 13: "video_note" + 13: "video_note", + 14: "document_thumbnail" } mime_types_to_extensions = {} @@ -82,6 +84,10 @@ class BaseClient: mime_types_to_extensions[mime_type] = " ".join(extensions) + fields = ("media_type", "dc_id", "file_id", "access_hash", "thumb_size", "peer_id", "volume_id", "local_id", + "is_big", "file_size", "mime_type", "file_name", "date") + FileData = namedtuple("FileData", fields, defaults=(None,) * len(fields)) + def __init__(self): self.is_bot = None self.dc_id = None diff --git a/pyrogram/client/methods/messages/download_media.py b/pyrogram/client/methods/messages/download_media.py index c21a95bf..bba5afd5 100644 --- a/pyrogram/client/methods/messages/download_media.py +++ b/pyrogram/client/methods/messages/download_media.py @@ -16,11 +16,14 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import binascii +import struct from threading import Event from typing import Union import pyrogram -from pyrogram.client.ext import BaseClient +from pyrogram.client.ext import BaseClient, utils +from pyrogram.errors import FileIdInvalid class DownloadMedia(BaseClient): @@ -81,67 +84,91 @@ class DownloadMedia(BaseClient): ``ValueError`` if the message doesn't contain any downloadable media """ error_message = "This message doesn't contain any downloadable media" + available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note") + + file_size = None + mime_type = None + date = None if isinstance(message, pyrogram.Message): - if message.photo: - media = pyrogram.Document( - file_id=message.photo.sizes[-1].file_id, - file_size=message.photo.sizes[-1].file_size, - mime_type="", - date=message.photo.date, - client=self - ) - elif message.audio: - media = message.audio - elif message.document: - media = message.document - elif message.video: - media = message.video - elif message.voice: - media = message.voice - elif message.video_note: - media = message.video_note - elif message.sticker: - media = message.sticker - elif message.animation: - media = message.animation + for kind in available_media: + media = getattr(message, kind, None) + + if media is not None: + break else: raise ValueError(error_message) - elif isinstance(message, ( - pyrogram.Photo, - pyrogram.PhotoSize, - pyrogram.Audio, - pyrogram.Document, - pyrogram.Video, - pyrogram.Voice, - pyrogram.VideoNote, - pyrogram.Sticker, - pyrogram.Animation - )): - if isinstance(message, pyrogram.Photo): - media = pyrogram.Document( - file_id=message.sizes[-1].file_id, - file_size=message.sizes[-1].file_size, - mime_type="", - date=message.date, - client=self + else: + media = message + + if isinstance(media, str): + file_id_str = media + else: + file_id_str = media.file_id + file_name = getattr(media, "file_name", "") + file_size = getattr(media, "file_size", None) + mime_type = getattr(media, "mime_type", None) + date = getattr(media, "date", None) + + data = self.FileData( + file_name=file_name, + file_size=file_size, + mime_type=mime_type, + date=date + ) + + def get_existing_attributes() -> dict: + return dict(filter(lambda x: x[1] is not None, data._asdict().items())) + + try: + decoded = utils.decode(file_id_str) + media_type = decoded[0] + + if media_type == 1: + unpacked = struct.unpack("