2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 12:57:52 +00:00

Rework download_media to accommodate L100 changes

This commit is contained in:
Dan 2019-05-29 10:40:37 +02:00
parent 3208b22849
commit 55599e33c6
3 changed files with 166 additions and 127 deletions

View File

@ -17,7 +17,6 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import base64 import base64
import binascii
import json import json
import logging import logging
import math import math
@ -25,7 +24,6 @@ import mimetypes
import os import os
import re import re
import shutil import shutil
import struct
import tempfile import tempfile
import threading import threading
import time import time
@ -49,7 +47,7 @@ from pyrogram.errors import (
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty, PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded, PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned, PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
VolumeLocNotFound, UserMigrate, FileIdInvalid, ChannelPrivate, PhoneNumberOccupied, VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied,
PasswordRecoveryNa, PasswordEmpty PasswordRecoveryNa, PasswordEmpty
) )
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
@ -829,85 +827,59 @@ class Client(Methods, BaseClient):
log.debug("{} started".format(name)) log.debug("{} started".format(name))
while True: while True:
media = self.download_queue.get() packet = self.download_queue.get()
if media is None: if packet is None:
break break
temp_file_path = "" temp_file_path = ""
final_file_path = "" final_file_path = ""
try: try:
media, file_name, done, progress, progress_args, path = media data, file_name, done, progress, progress_args, path = packet
data = data # type: BaseClient.FileData
file_id = media.file_id
size = media.file_size
directory, file_name = os.path.split(file_name) directory, file_name = os.path.split(file_name)
directory = directory or "downloads" directory = directory or "downloads"
try: media_type_str = Client.MEDIA_TYPE_ID[data.media_type]
decoded = utils.decode(file_id)
fmt = "<iiqqqqi" if len(decoded) > 24 else "<iiqq"
unpacked = struct.unpack(fmt, decoded)
except (AssertionError, binascii.Error, struct.error):
raise FileIdInvalid from None
else:
media_type = unpacked[0]
dc_id = unpacked[1]
id = unpacked[2]
access_hash = unpacked[3]
volume_id = None
secret = None
local_id = None
if len(decoded) > 24: if not data.file_name:
volume_id = unpacked[4] guessed_extension = self.guess_extension(data.mime_type)
secret = unpacked[5]
local_id = unpacked[6]
media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None) if data.media_type in (0, 1, 2, 14):
if media_type_str is None:
raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
file_name = file_name or getattr(media, "file_name", None)
if not file_name:
guessed_extension = self.guess_extension(media.mime_type)
if media_type in (0, 1, 2):
extension = ".jpg" extension = ".jpg"
elif media_type == 3: elif data.media_type == 3:
extension = guessed_extension or ".ogg" extension = guessed_extension or ".ogg"
elif media_type in (4, 10, 13): elif data.media_type in (4, 10, 13):
extension = guessed_extension or ".mp4" extension = guessed_extension or ".mp4"
elif media_type == 5: elif data.media_type == 5:
extension = guessed_extension or ".zip" extension = guessed_extension or ".zip"
elif media_type == 8: elif data.media_type == 8:
extension = guessed_extension or ".webp" extension = guessed_extension or ".webp"
elif media_type == 9: elif data.media_type == 9:
extension = guessed_extension or ".mp3" extension = guessed_extension or ".mp3"
else: else:
continue continue
file_name = "{}_{}_{}{}".format( file_name = "{}_{}_{}{}".format(
media_type_str, media_type_str,
datetime.fromtimestamp( datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
getattr(media, "date", None) or time.time()
).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(), self.rnd_id(),
extension extension
) )
temp_file_path = self.get_file( temp_file_path = self.get_file(
dc_id=dc_id, media_type=data.media_type,
id=id, dc_id=data.dc_id,
access_hash=access_hash, file_id=data.file_id,
volume_id=volume_id, access_hash=data.access_hash,
local_id=local_id, thumb_size=data.thumb_size,
secret=secret, peer_id=data.peer_id,
size=size, volume_id=data.volume_id,
local_id=data.local_id,
file_size=data.file_size,
is_big=data.is_big,
progress=progress, progress=progress,
progress_args=progress_args progress_args=progress_args
) )
@ -1549,16 +1521,21 @@ class Client(Methods, BaseClient):
finally: finally:
session.stop() session.stop()
def get_file(self, def get_file(
self,
media_type: int,
dc_id: int, dc_id: int,
id: int = None, file_id: int,
access_hash: int = None, access_hash: int,
volume_id: int = None, thumb_size: str,
local_id: int = None, peer_id: int,
secret: int = None, volume_id: int,
size: int = None, local_id: int,
progress: callable = None, file_size: int,
progress_args: tuple = ()) -> str: is_big: bool,
progress: callable,
progress_args: tuple = ()
) -> str:
with self.media_sessions_lock: with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None) session = self.media_sessions.get(dc_id, None)
@ -1599,18 +1576,33 @@ class Client(Methods, BaseClient):
self.media_sessions[dc_id] = session self.media_sessions[dc_id] = session
if volume_id: # Photos are accessed by volume_id, local_id, secret if media_type == 1:
location = types.InputFileLocation( location = types.InputPeerPhotoFileLocation(
peer=self.resolve_peer(peer_id),
volume_id=volume_id, volume_id=volume_id,
local_id=local_id, local_id=local_id,
secret=secret, big=is_big or None
file_reference=b""
) )
else: # Any other file can be more easily accessed by id and access_hash elif media_type in (0, 2):
location = types.InputDocumentFileLocation( location = types.InputPhotoFileLocation(
id=id, id=file_id,
access_hash=access_hash, access_hash=access_hash,
file_reference=b"" file_reference=b"",
thumb_size=thumb_size
)
elif media_type == 14:
location = types.InputDocumentFileLocation(
id=file_id,
access_hash=access_hash,
file_reference=b"",
thumb_size=thumb_size
)
else:
location = types.InputDocumentFileLocation(
id=file_id,
access_hash=access_hash,
file_reference=b"",
thumb_size=""
) )
limit = 1024 * 1024 limit = 1024 * 1024
@ -1641,7 +1633,14 @@ class Client(Methods, BaseClient):
offset += limit offset += limit
if progress: if progress:
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args) progress(
self,
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
r = session.send( r = session.send(
functions.upload.GetFile( functions.upload.GetFile(
@ -1723,7 +1722,14 @@ class Client(Methods, BaseClient):
offset += limit offset += limit
if progress: if progress:
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args) progress(
self,
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if len(chunk) < limit: if len(chunk) < limit:
break break

View File

@ -19,6 +19,7 @@
import os import os
import platform import platform
import re import re
from collections import namedtuple
from queue import Queue from queue import Queue
from threading import Lock from threading import Lock
@ -56,7 +57,7 @@ class BaseClient:
CONFIG_FILE = "./config.ini" CONFIG_FILE = "./config.ini"
MEDIA_TYPE_ID = { MEDIA_TYPE_ID = {
0: "thumbnail", 0: "photo_thumbnail",
1: "chat_photo", 1: "chat_photo",
2: "photo", 2: "photo",
3: "voice", 3: "voice",
@ -65,7 +66,8 @@ class BaseClient:
8: "sticker", 8: "sticker",
9: "audio", 9: "audio",
10: "animation", 10: "animation",
13: "video_note" 13: "video_note",
14: "document_thumbnail"
} }
mime_types_to_extensions = {} mime_types_to_extensions = {}
@ -82,6 +84,10 @@ class BaseClient:
mime_types_to_extensions[mime_type] = " ".join(extensions) mime_types_to_extensions[mime_type] = " ".join(extensions)
fields = ("media_type", "dc_id", "file_id", "access_hash", "thumb_size", "peer_id", "volume_id", "local_id",
"is_big", "file_size", "mime_type", "file_name", "date")
FileData = namedtuple("FileData", fields, defaults=(None,) * len(fields))
def __init__(self): def __init__(self):
self.is_bot = None self.is_bot = None
self.dc_id = None self.dc_id = None

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import binascii
import struct
from threading import Event from threading import Event
from typing import Union from typing import Union
import pyrogram import pyrogram
from pyrogram.client.ext import BaseClient from pyrogram.client.ext import BaseClient, utils
from pyrogram.errors import FileIdInvalid
class DownloadMedia(BaseClient): class DownloadMedia(BaseClient):
@ -81,67 +84,91 @@ class DownloadMedia(BaseClient):
``ValueError`` if the message doesn't contain any downloadable media ``ValueError`` if the message doesn't contain any downloadable media
""" """
error_message = "This message doesn't contain any downloadable media" error_message = "This message doesn't contain any downloadable media"
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note")
file_size = None
mime_type = None
date = None
if isinstance(message, pyrogram.Message): if isinstance(message, pyrogram.Message):
if message.photo: for kind in available_media:
media = pyrogram.Document( media = getattr(message, kind, None)
file_id=message.photo.sizes[-1].file_id,
file_size=message.photo.sizes[-1].file_size, if media is not None:
mime_type="", break
date=message.photo.date,
client=self
)
elif message.audio:
media = message.audio
elif message.document:
media = message.document
elif message.video:
media = message.video
elif message.voice:
media = message.voice
elif message.video_note:
media = message.video_note
elif message.sticker:
media = message.sticker
elif message.animation:
media = message.animation
else: else:
raise ValueError(error_message) raise ValueError(error_message)
elif isinstance(message, (
pyrogram.Photo,
pyrogram.PhotoSize,
pyrogram.Audio,
pyrogram.Document,
pyrogram.Video,
pyrogram.Voice,
pyrogram.VideoNote,
pyrogram.Sticker,
pyrogram.Animation
)):
if isinstance(message, pyrogram.Photo):
media = pyrogram.Document(
file_id=message.sizes[-1].file_id,
file_size=message.sizes[-1].file_size,
mime_type="",
date=message.date,
client=self
)
else: else:
media = message media = message
elif isinstance(message, str):
media = pyrogram.Document( if isinstance(media, str):
file_id=message, file_id_str = media
file_size=0, else:
mime_type="", file_id_str = media.file_id
client=self file_name = getattr(media, "file_name", "")
file_size = getattr(media, "file_size", None)
mime_type = getattr(media, "mime_type", None)
date = getattr(media, "date", None)
data = self.FileData(
file_name=file_name,
file_size=file_size,
mime_type=mime_type,
date=date
)
def get_existing_attributes() -> dict:
return dict(filter(lambda x: x[1] is not None, data._asdict().items()))
try:
decoded = utils.decode(file_id_str)
media_type = decoded[0]
if media_type == 1:
unpacked = struct.unpack("<iiqqib", decoded)
dc_id, peer_id, volume_id, local_id, is_big = unpacked[1:]
data = self.FileData(
**get_existing_attributes(),
media_type=media_type,
dc_id=dc_id,
peer_id=peer_id,
volume_id=volume_id,
local_id=local_id,
is_big=bool(is_big)
)
elif media_type in (0, 2, 14):
unpacked = struct.unpack("<iiqqc", decoded)
dc_id, file_id, access_hash, thumb_size = unpacked[1:]
data = self.FileData(
**get_existing_attributes(),
media_type=media_type,
dc_id=dc_id,
file_id=file_id,
access_hash=access_hash,
thumb_size=thumb_size.decode()
)
elif media_type in (3, 4, 5, 8, 9, 10, 13):
unpacked = struct.unpack("<iiqq", decoded)
dc_id, file_id, access_hash = unpacked[1:]
data = self.FileData(
**get_existing_attributes(),
media_type=media_type,
dc_id=dc_id,
file_id=file_id,
access_hash=access_hash
) )
else: else:
raise ValueError(error_message) raise ValueError("Unknown media type: {}".format(file_id_str))
except (AssertionError, binascii.Error, struct.error):
raise FileIdInvalid from None
done = Event() done = Event()
path = [None] path = [None]
self.download_queue.put((media, file_name, done, progress, progress_args, path)) self.download_queue.put((data, file_name, done, progress, progress_args, path))
if block: if block:
done.wait() done.wait()