mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-28 12:57:52 +00:00
Rework download_media to accommodate L100 changes
This commit is contained in:
parent
3208b22849
commit
55599e33c6
@ -17,7 +17,6 @@
|
|||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@ -25,7 +24,6 @@ import mimetypes
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import struct
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -49,7 +47,7 @@ from pyrogram.errors import (
|
|||||||
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
|
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
|
||||||
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
|
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
|
||||||
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
|
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
|
||||||
VolumeLocNotFound, UserMigrate, FileIdInvalid, ChannelPrivate, PhoneNumberOccupied,
|
VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied,
|
||||||
PasswordRecoveryNa, PasswordEmpty
|
PasswordRecoveryNa, PasswordEmpty
|
||||||
)
|
)
|
||||||
from pyrogram.session import Auth, Session
|
from pyrogram.session import Auth, Session
|
||||||
@ -829,85 +827,59 @@ class Client(Methods, BaseClient):
|
|||||||
log.debug("{} started".format(name))
|
log.debug("{} started".format(name))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
media = self.download_queue.get()
|
packet = self.download_queue.get()
|
||||||
|
|
||||||
if media is None:
|
if packet is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
temp_file_path = ""
|
temp_file_path = ""
|
||||||
final_file_path = ""
|
final_file_path = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
media, file_name, done, progress, progress_args, path = media
|
data, file_name, done, progress, progress_args, path = packet
|
||||||
|
data = data # type: BaseClient.FileData
|
||||||
file_id = media.file_id
|
|
||||||
size = media.file_size
|
|
||||||
|
|
||||||
directory, file_name = os.path.split(file_name)
|
directory, file_name = os.path.split(file_name)
|
||||||
directory = directory or "downloads"
|
directory = directory or "downloads"
|
||||||
|
|
||||||
try:
|
media_type_str = Client.MEDIA_TYPE_ID[data.media_type]
|
||||||
decoded = utils.decode(file_id)
|
|
||||||
fmt = "<iiqqqqi" if len(decoded) > 24 else "<iiqq"
|
|
||||||
unpacked = struct.unpack(fmt, decoded)
|
|
||||||
except (AssertionError, binascii.Error, struct.error):
|
|
||||||
raise FileIdInvalid from None
|
|
||||||
else:
|
|
||||||
media_type = unpacked[0]
|
|
||||||
dc_id = unpacked[1]
|
|
||||||
id = unpacked[2]
|
|
||||||
access_hash = unpacked[3]
|
|
||||||
volume_id = None
|
|
||||||
secret = None
|
|
||||||
local_id = None
|
|
||||||
|
|
||||||
if len(decoded) > 24:
|
if not data.file_name:
|
||||||
volume_id = unpacked[4]
|
guessed_extension = self.guess_extension(data.mime_type)
|
||||||
secret = unpacked[5]
|
|
||||||
local_id = unpacked[6]
|
|
||||||
|
|
||||||
media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None)
|
if data.media_type in (0, 1, 2, 14):
|
||||||
|
|
||||||
if media_type_str is None:
|
|
||||||
raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
|
|
||||||
|
|
||||||
file_name = file_name or getattr(media, "file_name", None)
|
|
||||||
|
|
||||||
if not file_name:
|
|
||||||
guessed_extension = self.guess_extension(media.mime_type)
|
|
||||||
|
|
||||||
if media_type in (0, 1, 2):
|
|
||||||
extension = ".jpg"
|
extension = ".jpg"
|
||||||
elif media_type == 3:
|
elif data.media_type == 3:
|
||||||
extension = guessed_extension or ".ogg"
|
extension = guessed_extension or ".ogg"
|
||||||
elif media_type in (4, 10, 13):
|
elif data.media_type in (4, 10, 13):
|
||||||
extension = guessed_extension or ".mp4"
|
extension = guessed_extension or ".mp4"
|
||||||
elif media_type == 5:
|
elif data.media_type == 5:
|
||||||
extension = guessed_extension or ".zip"
|
extension = guessed_extension or ".zip"
|
||||||
elif media_type == 8:
|
elif data.media_type == 8:
|
||||||
extension = guessed_extension or ".webp"
|
extension = guessed_extension or ".webp"
|
||||||
elif media_type == 9:
|
elif data.media_type == 9:
|
||||||
extension = guessed_extension or ".mp3"
|
extension = guessed_extension or ".mp3"
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
file_name = "{}_{}_{}{}".format(
|
file_name = "{}_{}_{}{}".format(
|
||||||
media_type_str,
|
media_type_str,
|
||||||
datetime.fromtimestamp(
|
datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
|
||||||
getattr(media, "date", None) or time.time()
|
|
||||||
).strftime("%Y-%m-%d_%H-%M-%S"),
|
|
||||||
self.rnd_id(),
|
self.rnd_id(),
|
||||||
extension
|
extension
|
||||||
)
|
)
|
||||||
|
|
||||||
temp_file_path = self.get_file(
|
temp_file_path = self.get_file(
|
||||||
dc_id=dc_id,
|
media_type=data.media_type,
|
||||||
id=id,
|
dc_id=data.dc_id,
|
||||||
access_hash=access_hash,
|
file_id=data.file_id,
|
||||||
volume_id=volume_id,
|
access_hash=data.access_hash,
|
||||||
local_id=local_id,
|
thumb_size=data.thumb_size,
|
||||||
secret=secret,
|
peer_id=data.peer_id,
|
||||||
size=size,
|
volume_id=data.volume_id,
|
||||||
|
local_id=data.local_id,
|
||||||
|
file_size=data.file_size,
|
||||||
|
is_big=data.is_big,
|
||||||
progress=progress,
|
progress=progress,
|
||||||
progress_args=progress_args
|
progress_args=progress_args
|
||||||
)
|
)
|
||||||
@ -1549,16 +1521,21 @@ class Client(Methods, BaseClient):
|
|||||||
finally:
|
finally:
|
||||||
session.stop()
|
session.stop()
|
||||||
|
|
||||||
def get_file(self,
|
def get_file(
|
||||||
dc_id: int,
|
self,
|
||||||
id: int = None,
|
media_type: int,
|
||||||
access_hash: int = None,
|
dc_id: int,
|
||||||
volume_id: int = None,
|
file_id: int,
|
||||||
local_id: int = None,
|
access_hash: int,
|
||||||
secret: int = None,
|
thumb_size: str,
|
||||||
size: int = None,
|
peer_id: int,
|
||||||
progress: callable = None,
|
volume_id: int,
|
||||||
progress_args: tuple = ()) -> str:
|
local_id: int,
|
||||||
|
file_size: int,
|
||||||
|
is_big: bool,
|
||||||
|
progress: callable,
|
||||||
|
progress_args: tuple = ()
|
||||||
|
) -> str:
|
||||||
with self.media_sessions_lock:
|
with self.media_sessions_lock:
|
||||||
session = self.media_sessions.get(dc_id, None)
|
session = self.media_sessions.get(dc_id, None)
|
||||||
|
|
||||||
@ -1599,18 +1576,33 @@ class Client(Methods, BaseClient):
|
|||||||
|
|
||||||
self.media_sessions[dc_id] = session
|
self.media_sessions[dc_id] = session
|
||||||
|
|
||||||
if volume_id: # Photos are accessed by volume_id, local_id, secret
|
if media_type == 1:
|
||||||
location = types.InputFileLocation(
|
location = types.InputPeerPhotoFileLocation(
|
||||||
|
peer=self.resolve_peer(peer_id),
|
||||||
volume_id=volume_id,
|
volume_id=volume_id,
|
||||||
local_id=local_id,
|
local_id=local_id,
|
||||||
secret=secret,
|
big=is_big or None
|
||||||
file_reference=b""
|
|
||||||
)
|
)
|
||||||
else: # Any other file can be more easily accessed by id and access_hash
|
elif media_type in (0, 2):
|
||||||
location = types.InputDocumentFileLocation(
|
location = types.InputPhotoFileLocation(
|
||||||
id=id,
|
id=file_id,
|
||||||
access_hash=access_hash,
|
access_hash=access_hash,
|
||||||
file_reference=b""
|
file_reference=b"",
|
||||||
|
thumb_size=thumb_size
|
||||||
|
)
|
||||||
|
elif media_type == 14:
|
||||||
|
location = types.InputDocumentFileLocation(
|
||||||
|
id=file_id,
|
||||||
|
access_hash=access_hash,
|
||||||
|
file_reference=b"",
|
||||||
|
thumb_size=thumb_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
location = types.InputDocumentFileLocation(
|
||||||
|
id=file_id,
|
||||||
|
access_hash=access_hash,
|
||||||
|
file_reference=b"",
|
||||||
|
thumb_size=""
|
||||||
)
|
)
|
||||||
|
|
||||||
limit = 1024 * 1024
|
limit = 1024 * 1024
|
||||||
@ -1641,7 +1633,14 @@ class Client(Methods, BaseClient):
|
|||||||
offset += limit
|
offset += limit
|
||||||
|
|
||||||
if progress:
|
if progress:
|
||||||
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
|
progress(
|
||||||
|
self,
|
||||||
|
min(offset, file_size)
|
||||||
|
if file_size != 0
|
||||||
|
else offset,
|
||||||
|
file_size,
|
||||||
|
*progress_args
|
||||||
|
)
|
||||||
|
|
||||||
r = session.send(
|
r = session.send(
|
||||||
functions.upload.GetFile(
|
functions.upload.GetFile(
|
||||||
@ -1723,7 +1722,14 @@ class Client(Methods, BaseClient):
|
|||||||
offset += limit
|
offset += limit
|
||||||
|
|
||||||
if progress:
|
if progress:
|
||||||
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
|
progress(
|
||||||
|
self,
|
||||||
|
min(offset, file_size)
|
||||||
|
if file_size != 0
|
||||||
|
else offset,
|
||||||
|
file_size,
|
||||||
|
*progress_args
|
||||||
|
)
|
||||||
|
|
||||||
if len(chunk) < limit:
|
if len(chunk) < limit:
|
||||||
break
|
break
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import re
|
import re
|
||||||
|
from collections import namedtuple
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
@ -56,7 +57,7 @@ class BaseClient:
|
|||||||
CONFIG_FILE = "./config.ini"
|
CONFIG_FILE = "./config.ini"
|
||||||
|
|
||||||
MEDIA_TYPE_ID = {
|
MEDIA_TYPE_ID = {
|
||||||
0: "thumbnail",
|
0: "photo_thumbnail",
|
||||||
1: "chat_photo",
|
1: "chat_photo",
|
||||||
2: "photo",
|
2: "photo",
|
||||||
3: "voice",
|
3: "voice",
|
||||||
@ -65,7 +66,8 @@ class BaseClient:
|
|||||||
8: "sticker",
|
8: "sticker",
|
||||||
9: "audio",
|
9: "audio",
|
||||||
10: "animation",
|
10: "animation",
|
||||||
13: "video_note"
|
13: "video_note",
|
||||||
|
14: "document_thumbnail"
|
||||||
}
|
}
|
||||||
|
|
||||||
mime_types_to_extensions = {}
|
mime_types_to_extensions = {}
|
||||||
@ -82,6 +84,10 @@ class BaseClient:
|
|||||||
|
|
||||||
mime_types_to_extensions[mime_type] = " ".join(extensions)
|
mime_types_to_extensions[mime_type] = " ".join(extensions)
|
||||||
|
|
||||||
|
fields = ("media_type", "dc_id", "file_id", "access_hash", "thumb_size", "peer_id", "volume_id", "local_id",
|
||||||
|
"is_big", "file_size", "mime_type", "file_name", "date")
|
||||||
|
FileData = namedtuple("FileData", fields, defaults=(None,) * len(fields))
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.is_bot = None
|
self.is_bot = None
|
||||||
self.dc_id = None
|
self.dc_id = None
|
||||||
|
@ -16,11 +16,14 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import binascii
|
||||||
|
import struct
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram.client.ext import BaseClient
|
from pyrogram.client.ext import BaseClient, utils
|
||||||
|
from pyrogram.errors import FileIdInvalid
|
||||||
|
|
||||||
|
|
||||||
class DownloadMedia(BaseClient):
|
class DownloadMedia(BaseClient):
|
||||||
@ -81,67 +84,91 @@ class DownloadMedia(BaseClient):
|
|||||||
``ValueError`` if the message doesn't contain any downloadable media
|
``ValueError`` if the message doesn't contain any downloadable media
|
||||||
"""
|
"""
|
||||||
error_message = "This message doesn't contain any downloadable media"
|
error_message = "This message doesn't contain any downloadable media"
|
||||||
|
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note")
|
||||||
|
|
||||||
|
file_size = None
|
||||||
|
mime_type = None
|
||||||
|
date = None
|
||||||
|
|
||||||
if isinstance(message, pyrogram.Message):
|
if isinstance(message, pyrogram.Message):
|
||||||
if message.photo:
|
for kind in available_media:
|
||||||
media = pyrogram.Document(
|
media = getattr(message, kind, None)
|
||||||
file_id=message.photo.sizes[-1].file_id,
|
|
||||||
file_size=message.photo.sizes[-1].file_size,
|
if media is not None:
|
||||||
mime_type="",
|
break
|
||||||
date=message.photo.date,
|
|
||||||
client=self
|
|
||||||
)
|
|
||||||
elif message.audio:
|
|
||||||
media = message.audio
|
|
||||||
elif message.document:
|
|
||||||
media = message.document
|
|
||||||
elif message.video:
|
|
||||||
media = message.video
|
|
||||||
elif message.voice:
|
|
||||||
media = message.voice
|
|
||||||
elif message.video_note:
|
|
||||||
media = message.video_note
|
|
||||||
elif message.sticker:
|
|
||||||
media = message.sticker
|
|
||||||
elif message.animation:
|
|
||||||
media = message.animation
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(error_message)
|
raise ValueError(error_message)
|
||||||
elif isinstance(message, (
|
else:
|
||||||
pyrogram.Photo,
|
media = message
|
||||||
pyrogram.PhotoSize,
|
|
||||||
pyrogram.Audio,
|
if isinstance(media, str):
|
||||||
pyrogram.Document,
|
file_id_str = media
|
||||||
pyrogram.Video,
|
else:
|
||||||
pyrogram.Voice,
|
file_id_str = media.file_id
|
||||||
pyrogram.VideoNote,
|
file_name = getattr(media, "file_name", "")
|
||||||
pyrogram.Sticker,
|
file_size = getattr(media, "file_size", None)
|
||||||
pyrogram.Animation
|
mime_type = getattr(media, "mime_type", None)
|
||||||
)):
|
date = getattr(media, "date", None)
|
||||||
if isinstance(message, pyrogram.Photo):
|
|
||||||
media = pyrogram.Document(
|
data = self.FileData(
|
||||||
file_id=message.sizes[-1].file_id,
|
file_name=file_name,
|
||||||
file_size=message.sizes[-1].file_size,
|
file_size=file_size,
|
||||||
mime_type="",
|
mime_type=mime_type,
|
||||||
date=message.date,
|
date=date
|
||||||
client=self
|
)
|
||||||
|
|
||||||
|
def get_existing_attributes() -> dict:
|
||||||
|
return dict(filter(lambda x: x[1] is not None, data._asdict().items()))
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoded = utils.decode(file_id_str)
|
||||||
|
media_type = decoded[0]
|
||||||
|
|
||||||
|
if media_type == 1:
|
||||||
|
unpacked = struct.unpack("<iiqqib", decoded)
|
||||||
|
dc_id, peer_id, volume_id, local_id, is_big = unpacked[1:]
|
||||||
|
|
||||||
|
data = self.FileData(
|
||||||
|
**get_existing_attributes(),
|
||||||
|
media_type=media_type,
|
||||||
|
dc_id=dc_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
volume_id=volume_id,
|
||||||
|
local_id=local_id,
|
||||||
|
is_big=bool(is_big)
|
||||||
|
)
|
||||||
|
elif media_type in (0, 2, 14):
|
||||||
|
unpacked = struct.unpack("<iiqqc", decoded)
|
||||||
|
dc_id, file_id, access_hash, thumb_size = unpacked[1:]
|
||||||
|
|
||||||
|
data = self.FileData(
|
||||||
|
**get_existing_attributes(),
|
||||||
|
media_type=media_type,
|
||||||
|
dc_id=dc_id,
|
||||||
|
file_id=file_id,
|
||||||
|
access_hash=access_hash,
|
||||||
|
thumb_size=thumb_size.decode()
|
||||||
|
)
|
||||||
|
elif media_type in (3, 4, 5, 8, 9, 10, 13):
|
||||||
|
unpacked = struct.unpack("<iiqq", decoded)
|
||||||
|
dc_id, file_id, access_hash = unpacked[1:]
|
||||||
|
|
||||||
|
data = self.FileData(
|
||||||
|
**get_existing_attributes(),
|
||||||
|
media_type=media_type,
|
||||||
|
dc_id=dc_id,
|
||||||
|
file_id=file_id,
|
||||||
|
access_hash=access_hash
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
media = message
|
raise ValueError("Unknown media type: {}".format(file_id_str))
|
||||||
elif isinstance(message, str):
|
except (AssertionError, binascii.Error, struct.error):
|
||||||
media = pyrogram.Document(
|
raise FileIdInvalid from None
|
||||||
file_id=message,
|
|
||||||
file_size=0,
|
|
||||||
mime_type="",
|
|
||||||
client=self
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(error_message)
|
|
||||||
|
|
||||||
done = Event()
|
done = Event()
|
||||||
path = [None]
|
path = [None]
|
||||||
|
|
||||||
self.download_queue.put((media, file_name, done, progress, progress_args, path))
|
self.download_queue.put((data, file_name, done, progress, progress_args, path))
|
||||||
|
|
||||||
if block:
|
if block:
|
||||||
done.wait()
|
done.wait()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user