2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 04:48:06 +00:00

Rework download_media to accommodate L100 changes

This commit is contained in:
Dan 2019-05-29 10:40:37 +02:00
parent 3208b22849
commit 55599e33c6
3 changed files with 166 additions and 127 deletions

View File

@ -17,7 +17,6 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import base64
import binascii
import json
import logging
import math
@ -25,7 +24,6 @@ import mimetypes
import os
import re
import shutil
import struct
import tempfile
import threading
import time
@ -49,7 +47,7 @@ from pyrogram.errors import (
PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty,
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
VolumeLocNotFound, UserMigrate, FileIdInvalid, ChannelPrivate, PhoneNumberOccupied,
VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied,
PasswordRecoveryNa, PasswordEmpty
)
from pyrogram.session import Auth, Session
@ -829,85 +827,59 @@ class Client(Methods, BaseClient):
log.debug("{} started".format(name))
while True:
media = self.download_queue.get()
packet = self.download_queue.get()
if media is None:
if packet is None:
break
temp_file_path = ""
final_file_path = ""
try:
media, file_name, done, progress, progress_args, path = media
file_id = media.file_id
size = media.file_size
data, file_name, done, progress, progress_args, path = packet
data = data # type: BaseClient.FileData
directory, file_name = os.path.split(file_name)
directory = directory or "downloads"
try:
decoded = utils.decode(file_id)
fmt = "<iiqqqqi" if len(decoded) > 24 else "<iiqq"
unpacked = struct.unpack(fmt, decoded)
except (AssertionError, binascii.Error, struct.error):
raise FileIdInvalid from None
else:
media_type = unpacked[0]
dc_id = unpacked[1]
id = unpacked[2]
access_hash = unpacked[3]
volume_id = None
secret = None
local_id = None
media_type_str = Client.MEDIA_TYPE_ID[data.media_type]
if len(decoded) > 24:
volume_id = unpacked[4]
secret = unpacked[5]
local_id = unpacked[6]
if not data.file_name:
guessed_extension = self.guess_extension(data.mime_type)
media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None)
if media_type_str is None:
raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
file_name = file_name or getattr(media, "file_name", None)
if not file_name:
guessed_extension = self.guess_extension(media.mime_type)
if media_type in (0, 1, 2):
if data.media_type in (0, 1, 2, 14):
extension = ".jpg"
elif media_type == 3:
elif data.media_type == 3:
extension = guessed_extension or ".ogg"
elif media_type in (4, 10, 13):
elif data.media_type in (4, 10, 13):
extension = guessed_extension or ".mp4"
elif media_type == 5:
elif data.media_type == 5:
extension = guessed_extension or ".zip"
elif media_type == 8:
elif data.media_type == 8:
extension = guessed_extension or ".webp"
elif media_type == 9:
elif data.media_type == 9:
extension = guessed_extension or ".mp3"
else:
continue
file_name = "{}_{}_{}{}".format(
media_type_str,
datetime.fromtimestamp(
getattr(media, "date", None) or time.time()
).strftime("%Y-%m-%d_%H-%M-%S"),
datetime.fromtimestamp(data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
)
temp_file_path = self.get_file(
dc_id=dc_id,
id=id,
access_hash=access_hash,
volume_id=volume_id,
local_id=local_id,
secret=secret,
size=size,
media_type=data.media_type,
dc_id=data.dc_id,
file_id=data.file_id,
access_hash=data.access_hash,
thumb_size=data.thumb_size,
peer_id=data.peer_id,
volume_id=data.volume_id,
local_id=data.local_id,
file_size=data.file_size,
is_big=data.is_big,
progress=progress,
progress_args=progress_args
)
@ -1549,16 +1521,21 @@ class Client(Methods, BaseClient):
finally:
session.stop()
def get_file(self,
def get_file(
self,
media_type: int,
dc_id: int,
id: int = None,
access_hash: int = None,
volume_id: int = None,
local_id: int = None,
secret: int = None,
size: int = None,
progress: callable = None,
progress_args: tuple = ()) -> str:
file_id: int,
access_hash: int,
thumb_size: str,
peer_id: int,
volume_id: int,
local_id: int,
file_size: int,
is_big: bool,
progress: callable,
progress_args: tuple = ()
) -> str:
with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
@ -1599,18 +1576,33 @@ class Client(Methods, BaseClient):
self.media_sessions[dc_id] = session
if volume_id: # Photos are accessed by volume_id, local_id, secret
location = types.InputFileLocation(
if media_type == 1:
location = types.InputPeerPhotoFileLocation(
peer=self.resolve_peer(peer_id),
volume_id=volume_id,
local_id=local_id,
secret=secret,
file_reference=b""
big=is_big or None
)
else: # Any other file can be more easily accessed by id and access_hash
location = types.InputDocumentFileLocation(
id=id,
elif media_type in (0, 2):
location = types.InputPhotoFileLocation(
id=file_id,
access_hash=access_hash,
file_reference=b""
file_reference=b"",
thumb_size=thumb_size
)
elif media_type == 14:
location = types.InputDocumentFileLocation(
id=file_id,
access_hash=access_hash,
file_reference=b"",
thumb_size=thumb_size
)
else:
location = types.InputDocumentFileLocation(
id=file_id,
access_hash=access_hash,
file_reference=b"",
thumb_size=""
)
limit = 1024 * 1024
@ -1641,7 +1633,14 @@ class Client(Methods, BaseClient):
offset += limit
if progress:
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
progress(
self,
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
r = session.send(
functions.upload.GetFile(
@ -1723,7 +1722,14 @@ class Client(Methods, BaseClient):
offset += limit
if progress:
progress(self, min(offset, size) if size != 0 else offset, size, *progress_args)
progress(
self,
min(offset, file_size)
if file_size != 0
else offset,
file_size,
*progress_args
)
if len(chunk) < limit:
break

View File

@ -19,6 +19,7 @@
import os
import platform
import re
from collections import namedtuple
from queue import Queue
from threading import Lock
@ -56,7 +57,7 @@ class BaseClient:
CONFIG_FILE = "./config.ini"
MEDIA_TYPE_ID = {
0: "thumbnail",
0: "photo_thumbnail",
1: "chat_photo",
2: "photo",
3: "voice",
@ -65,7 +66,8 @@ class BaseClient:
8: "sticker",
9: "audio",
10: "animation",
13: "video_note"
13: "video_note",
14: "document_thumbnail"
}
mime_types_to_extensions = {}
@ -82,6 +84,10 @@ class BaseClient:
mime_types_to_extensions[mime_type] = " ".join(extensions)
fields = ("media_type", "dc_id", "file_id", "access_hash", "thumb_size", "peer_id", "volume_id", "local_id",
"is_big", "file_size", "mime_type", "file_name", "date")
FileData = namedtuple("FileData", fields, defaults=(None,) * len(fields))
def __init__(self):
self.is_bot = None
self.dc_id = None

View File

@ -16,11 +16,14 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import binascii
import struct
from threading import Event
from typing import Union
import pyrogram
from pyrogram.client.ext import BaseClient
from pyrogram.client.ext import BaseClient, utils
from pyrogram.errors import FileIdInvalid
class DownloadMedia(BaseClient):
@ -81,67 +84,91 @@ class DownloadMedia(BaseClient):
``ValueError`` if the message doesn't contain any downloadable media
"""
error_message = "This message doesn't contain any downloadable media"
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note")
file_size = None
mime_type = None
date = None
if isinstance(message, pyrogram.Message):
if message.photo:
media = pyrogram.Document(
file_id=message.photo.sizes[-1].file_id,
file_size=message.photo.sizes[-1].file_size,
mime_type="",
date=message.photo.date,
client=self
)
elif message.audio:
media = message.audio
elif message.document:
media = message.document
elif message.video:
media = message.video
elif message.voice:
media = message.voice
elif message.video_note:
media = message.video_note
elif message.sticker:
media = message.sticker
elif message.animation:
media = message.animation
for kind in available_media:
media = getattr(message, kind, None)
if media is not None:
break
else:
raise ValueError(error_message)
elif isinstance(message, (
pyrogram.Photo,
pyrogram.PhotoSize,
pyrogram.Audio,
pyrogram.Document,
pyrogram.Video,
pyrogram.Voice,
pyrogram.VideoNote,
pyrogram.Sticker,
pyrogram.Animation
)):
if isinstance(message, pyrogram.Photo):
media = pyrogram.Document(
file_id=message.sizes[-1].file_id,
file_size=message.sizes[-1].file_size,
mime_type="",
date=message.date,
client=self
)
else:
media = message
elif isinstance(message, str):
media = pyrogram.Document(
file_id=message,
file_size=0,
mime_type="",
client=self
if isinstance(media, str):
file_id_str = media
else:
file_id_str = media.file_id
file_name = getattr(media, "file_name", "")
file_size = getattr(media, "file_size", None)
mime_type = getattr(media, "mime_type", None)
date = getattr(media, "date", None)
data = self.FileData(
file_name=file_name,
file_size=file_size,
mime_type=mime_type,
date=date
)
def get_existing_attributes() -> dict:
return dict(filter(lambda x: x[1] is not None, data._asdict().items()))
try:
decoded = utils.decode(file_id_str)
media_type = decoded[0]
if media_type == 1:
unpacked = struct.unpack("<iiqqib", decoded)
dc_id, peer_id, volume_id, local_id, is_big = unpacked[1:]
data = self.FileData(
**get_existing_attributes(),
media_type=media_type,
dc_id=dc_id,
peer_id=peer_id,
volume_id=volume_id,
local_id=local_id,
is_big=bool(is_big)
)
elif media_type in (0, 2, 14):
unpacked = struct.unpack("<iiqqc", decoded)
dc_id, file_id, access_hash, thumb_size = unpacked[1:]
data = self.FileData(
**get_existing_attributes(),
media_type=media_type,
dc_id=dc_id,
file_id=file_id,
access_hash=access_hash,
thumb_size=thumb_size.decode()
)
elif media_type in (3, 4, 5, 8, 9, 10, 13):
unpacked = struct.unpack("<iiqq", decoded)
dc_id, file_id, access_hash = unpacked[1:]
data = self.FileData(
**get_existing_attributes(),
media_type=media_type,
dc_id=dc_id,
file_id=file_id,
access_hash=access_hash
)
else:
raise ValueError(error_message)
raise ValueError("Unknown media type: {}".format(file_id_str))
except (AssertionError, binascii.Error, struct.error):
raise FileIdInvalid from None
done = Event()
path = [None]
self.download_queue.put((media, file_name, done, progress, progress_args, path))
self.download_queue.put((data, file_name, done, progress, progress_args, path))
if block:
done.wait()