diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py
index 0d1cde19..468413b0 100644
--- a/pyrogram/client/client.py
+++ b/pyrogram/client/client.py
@@ -17,7 +17,6 @@
# along with Pyrogram. If not, see .
import base64
-import binascii
import json
import logging
import math
@@ -25,7 +24,6 @@ import mimetypes
import os
import re
import shutil
-import struct
import tempfile
import threading
import time
@@ -49,7 +47,7 @@ from pyrogram.errors import (
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
- VolumeLocNotFound, UserMigrate, FileIdInvalid, ChannelPrivate, PhoneNumberOccupied,
+ VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied,
PasswordRecoveryNa, PasswordEmpty
)
from pyrogram.session import Auth, Session
@@ -829,85 +827,59 @@ class Client(Methods, BaseClient):
log.debug("{} started".format(name))
while True:
- media = self.download_queue.get()
+ packet = self.download_queue.get()
- if media is None:
+ if packet is None:
break
temp_file_path = ""
final_file_path = ""
try:
- media, file_name, done, progress, progress_args, path = media
-
- file_id = media.file_id
- size = media.file_size
+ data, file_name, done, progress, progress_args, path = packet
+ data = data # type: BaseClient.FileData
directory, file_name = os.path.split(file_name)
directory = directory or "downloads"
- try:
- decoded = utils.decode(file_id)
- fmt = " 24 else " 24:
- volume_id = unpacked[4]
- secret = unpacked[5]
- local_id = unpacked[6]
+ if not data.file_name:
+ guessed_extension = self.guess_extension(data.mime_type)
- media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None)
-
- if media_type_str is None:
- raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
-
- file_name = file_name or getattr(media, "file_name", None)
-
- if not file_name:
- guessed_extension = self.guess_extension(media.mime_type)
-
- if media_type in (0, 1, 2):
+ if data.media_type in (0, 1, 2, 14):
extension = ".jpg"
- elif media_type == 3:
+ elif data.media_type == 3:
extension = guessed_extension or ".ogg"
- elif media_type in (4, 10, 13):
+ elif data.media_type in (4, 10, 13):
extension = guessed_extension or ".mp4"
- elif media_type == 5:
+ elif data.media_type == 5:
extension = guessed_extension or ".zip"
- elif media_type == 8:
+ elif data.media_type == 8:
extension = guessed_extension or ".webp"
- elif media_type == 9:
+ elif data.media_type == 9:
extension = guessed_extension or ".mp3"
else:
continue
file_name = "{}_{}_{}{}".format(
media_type_str,
- datetime.fromtimestamp(
- getattr(media, "date", None) or time.time()
- ).strftime("%Y-%m-%d_%H-%M-%S"),
+ datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
)
temp_file_path = self.get_file(
- dc_id=dc_id,
- id=id,
- access_hash=access_hash,
- volume_id=volume_id,
- local_id=local_id,
- secret=secret,
- size=size,
+ media_type=data.media_type,
+ dc_id=data.dc_id,
+ file_id=data.file_id,
+ access_hash=data.access_hash,
+ thumb_size=data.thumb_size,
+ peer_id=data.peer_id,
+ volume_id=data.volume_id,
+ local_id=data.local_id,
+ file_size=data.file_size,
+ is_big=data.is_big,
progress=progress,
progress_args=progress_args
)
@@ -1549,16 +1521,21 @@ class Client(Methods, BaseClient):
finally:
session.stop()
- def get_file(self,
- dc_id: int,
- id: int = None,
- access_hash: int = None,
- volume_id: int = None,
- local_id: int = None,
- secret: int = None,
- size: int = None,
- progress: callable = None,
- progress_args: tuple = ()) -> str:
+ def get_file(
+ self,
+ media_type: int,
+ dc_id: int,
+ file_id: int,
+ access_hash: int,
+ thumb_size: str,
+ peer_id: int,
+ volume_id: int,
+ local_id: int,
+ file_size: int,
+ is_big: bool,
+ progress: callable,
+ progress_args: tuple = ()
+ ) -> str:
with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
@@ -1599,18 +1576,33 @@ class Client(Methods, BaseClient):
self.media_sessions[dc_id] = session
- if volume_id: # Photos are accessed by volume_id, local_id, secret
- location = types.InputFileLocation(
+ if media_type == 1:
+ location = types.InputPeerPhotoFileLocation(
+ peer=self.resolve_peer(peer_id),
volume_id=volume_id,
local_id=local_id,
- secret=secret,
- file_reference=b""
+ big=is_big or None
)
- else: # Any other file can be more easily accessed by id and access_hash
- location = types.InputDocumentFileLocation(
- id=id,
+ elif media_type in (0, 2):
+ location = types.InputPhotoFileLocation(
+ id=file_id,
access_hash=access_hash,
- file_reference=b""
+ file_reference=b"",
+ thumb_size=thumb_size
+ )
+ elif media_type == 14:
+ location = types.InputDocumentFileLocation(
+ id=file_id,
+ access_hash=access_hash,
+ file_reference=b"",
+ thumb_size=thumb_size
+ )
+ else:
+ location = types.InputDocumentFileLocation(
+ id=file_id,
+ access_hash=access_hash,
+ file_reference=b"",
+ thumb_size=""
)
limit = 1024 * 1024
@@ -1641,7 +1633,14 @@ class Client(Methods, BaseClient):
offset += limit
if progress:
- progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
+ progress(
+ self,
+ min(offset, file_size)
+ if file_size != 0
+ else offset,
+ file_size,
+ *progress_args
+ )
r = session.send(
functions.upload.GetFile(
@@ -1723,7 +1722,14 @@ class Client(Methods, BaseClient):
offset += limit
if progress:
- progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
+ progress(
+ self,
+ min(offset, file_size)
+ if file_size != 0
+ else offset,
+ file_size,
+ *progress_args
+ )
if len(chunk) < limit:
break
diff --git a/pyrogram/client/ext/base_client.py b/pyrogram/client/ext/base_client.py
index a3816bdf..3397c8d3 100644
--- a/pyrogram/client/ext/base_client.py
+++ b/pyrogram/client/ext/base_client.py
@@ -19,6 +19,7 @@
import os
import platform
import re
+from collections import namedtuple
from queue import Queue
from threading import Lock
@@ -56,7 +57,7 @@ class BaseClient:
CONFIG_FILE = "./config.ini"
MEDIA_TYPE_ID = {
- 0: "thumbnail",
+ 0: "photo_thumbnail",
1: "chat_photo",
2: "photo",
3: "voice",
@@ -65,7 +66,8 @@ class BaseClient:
8: "sticker",
9: "audio",
10: "animation",
- 13: "video_note"
+ 13: "video_note",
+ 14: "document_thumbnail"
}
mime_types_to_extensions = {}
@@ -82,6 +84,10 @@ class BaseClient:
mime_types_to_extensions[mime_type] = " ".join(extensions)
+ fields = ("media_type", "dc_id", "file_id", "access_hash", "thumb_size", "peer_id", "volume_id", "local_id",
+ "is_big", "file_size", "mime_type", "file_name", "date")
+ FileData = namedtuple("FileData", fields, defaults=(None,) * len(fields))
+
def __init__(self):
self.is_bot = None
self.dc_id = None
diff --git a/pyrogram/client/methods/messages/download_media.py b/pyrogram/client/methods/messages/download_media.py
index c21a95bf..bba5afd5 100644
--- a/pyrogram/client/methods/messages/download_media.py
+++ b/pyrogram/client/methods/messages/download_media.py
@@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
+import binascii
+import struct
from threading import Event
from typing import Union
import pyrogram
-from pyrogram.client.ext import BaseClient
+from pyrogram.client.ext import BaseClient, utils
+from pyrogram.errors import FileIdInvalid
class DownloadMedia(BaseClient):
@@ -81,67 +84,91 @@ class DownloadMedia(BaseClient):
``ValueError`` if the message doesn't contain any downloadable media
"""
error_message = "This message doesn't contain any downloadable media"
+ available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note")
+
+ file_size = None
+ mime_type = None
+ date = None
if isinstance(message, pyrogram.Message):
- if message.photo:
- media = pyrogram.Document(
- file_id=message.photo.sizes[-1].file_id,
- file_size=message.photo.sizes[-1].file_size,
- mime_type="",
- date=message.photo.date,
- client=self
- )
- elif message.audio:
- media = message.audio
- elif message.document:
- media = message.document
- elif message.video:
- media = message.video
- elif message.voice:
- media = message.voice
- elif message.video_note:
- media = message.video_note
- elif message.sticker:
- media = message.sticker
- elif message.animation:
- media = message.animation
+ for kind in available_media:
+ media = getattr(message, kind, None)
+
+ if media is not None:
+ break
else:
raise ValueError(error_message)
- elif isinstance(message, (
- pyrogram.Photo,
- pyrogram.PhotoSize,
- pyrogram.Audio,
- pyrogram.Document,
- pyrogram.Video,
- pyrogram.Voice,
- pyrogram.VideoNote,
- pyrogram.Sticker,
- pyrogram.Animation
- )):
- if isinstance(message, pyrogram.Photo):
- media = pyrogram.Document(
- file_id=message.sizes[-1].file_id,
- file_size=message.sizes[-1].file_size,
- mime_type="",
- date=message.date,
- client=self
+ else:
+ media = message
+
+ if isinstance(media, str):
+ file_id_str = media
+ else:
+ file_id_str = media.file_id
+ file_name = getattr(media, "file_name", "")
+ file_size = getattr(media, "file_size", None)
+ mime_type = getattr(media, "mime_type", None)
+ date = getattr(media, "date", None)
+
+ data = self.FileData(
+ file_name=file_name,
+ file_size=file_size,
+ mime_type=mime_type,
+ date=date
+ )
+
+ def get_existing_attributes() -> dict:
+ return dict(filter(lambda x: x[1] is not None, data._asdict().items()))
+
+ try:
+ decoded = utils.decode(file_id_str)
+ media_type = decoded[0]
+
+ if media_type == 1:
+ unpacked = struct.unpack("