From 15561d19d577206461bdffb871e3ee3a9914ab10 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Sun, 18 Feb 2018 18:11:33 +0100 Subject: [PATCH] Add initial support for downloading media --- pyrogram/client/client.py | 181 ++++++++++++++++++++++++++------------ 1 file changed, 124 insertions(+), 57 deletions(-) diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 54ddf256..993b9bf9 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -27,6 +27,7 @@ import threading import time from collections import namedtuple from configparser import ConfigParser +from datetime import datetime from hashlib import sha256, md5 from queue import Queue from signal import signal, SIGINT, SIGTERM, SIGABRT @@ -39,8 +40,8 @@ from pyrogram.api.errors import ( PhoneNumberUnoccupied, PhoneCodeInvalid, PhoneCodeHashEmpty, PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded, PasswordHashInvalid, FloodWait, PeerIdInvalid, FilePartMissing, - ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned -) + ChatAdminRequired, FirstnameInvalid, PhoneNumberBanned, + VolumeLocNotFound) from pyrogram.api.types import ( User, Chat, Channel, PeerUser, PeerChannel, @@ -49,6 +50,7 @@ from pyrogram.api.types import ( ) from pyrogram.crypto import AES from pyrogram.session import Auth, Session +from pyrogram.session.internals import MsgId from .input_media import InputMedia from .style import Markdown, HTML @@ -103,6 +105,7 @@ class Client: INVITE_LINK_RE = re.compile(r"^(?:https?://)?t\.me/joinchat/(.+)$") DIALOGS_AT_ONCE = 100 UPDATES_WORKERS = 2 + DOWNLOAD_WORKERS = 1 def __init__(self, session_name: str, @@ -148,6 +151,8 @@ class Client: self.update_queue = Queue() self.update_handler = None + self.download_queue = Queue() + def start(self): """Use this method to start the Client after creating it. Requires no parameters. @@ -176,7 +181,7 @@ class Client: self.password = None self.save_session() - self.rnd_id = self.session.msg_id + self.rnd_id = MsgId self.get_dialogs() for i in range(self.UPDATES_WORKERS): @@ -185,6 +190,9 @@ class Client: for i in range(self.workers): Thread(target=self.update_worker, name="UpdateWorker#{}".format(i + 1)).start() + for i in range(self.DOWNLOAD_WORKERS): + Thread(target=self.download_worker, name="DownloadWorker#{}".format(i + 1)).start() + mimetypes.init() def stop(self): @@ -199,6 +207,9 @@ class Client: for _ in range(self.workers): self.update_queue.put(None) + for _ in range(self.DOWNLOAD_WORKERS): + self.download_queue.put(None) + def fetch_peers(self, entities: list): for entity in entities: if isinstance(entity, User): @@ -260,6 +271,67 @@ class Client: if username is not None: self.peers_by_username[username] = input_peer + def download_worker(self): + name = threading.current_thread().name + log.debug("{} started".format(name)) + + while True: + message = self.download_queue.get() + + if message is None: + break + + message, done = message + + try: + if isinstance(message.media, types.MessageMediaDocument): + document = message.media.document + + if isinstance(document, types.Document): + file_name = "doc_{}{}".format( + datetime.fromtimestamp(document.date).strftime("%Y-%m-%d_%H-%M-%S"), + mimetypes.guess_extension(document.mime_type) or ".unknown" + ) + + for i in document.attributes: + if isinstance(i, types.DocumentAttributeFilename): + file_name = i.file_name + break + elif isinstance(i, types.DocumentAttributeSticker): + file_name = file_name.replace("doc", "sticker") + elif isinstance(i, types.DocumentAttributeAudio): + file_name = file_name.replace("doc", "audio") + elif isinstance(i, types.DocumentAttributeVideo): + file_name = file_name.replace("doc", "video") + elif isinstance(i, types.DocumentAttributeAnimated): + file_name = file_name.replace("doc", "gif") + + tmp_file_name = self.get_file( + dc_id=document.dc_id, + id=document.id, + access_hash=document.access_hash, + version=document.version + ) + + i = 1 + while True: + try: + os.renames("./{}".format(tmp_file_name), "./downloads/{}".format( + ".".join(file_name.split(".")[:-1]) + + (" ({}).".format(i) if i > 1 else ".") + + file_name.split(".")[-1] + )) + except FileExistsError: + i += 1 + else: + break + + done.set() + except Exception as e: + log.error(e, exc_info=True) + + log.debug("{} stopped".format(name)) + def updates_worker(self): name = threading.current_thread().name log.debug("{} started".format(name)) @@ -1667,8 +1739,7 @@ class Client: part_size = 512 * 1024 file_size = os.path.getsize(path) file_total_parts = math.ceil(file_size / part_size) - # is_big = True if file_size > 10 * 1024 * 1024 else False - is_big = False # Treat all files as not-big to have the server check for the md5 sum + is_big = True if file_size > 10 * 1024 * 1024 else False is_missing_part = True if file_id is not None else False file_id = file_id or self.rnd_id() md5_sum = md5() if not is_big and not is_missing_part else None @@ -1759,22 +1830,19 @@ class Client: session.start() if volume_id: # Photos are accessed by volume_id, local_id, secret - file_name = "_".join(str(i) for i in [dc_id, volume_id, local_id, secret]) - location = types.InputFileLocation( volume_id=volume_id, local_id=local_id, secret=secret ) else: # Any other file can be more easily accessed by id and access_hash - file_name = "_".join(str(i) for i in [dc_id, id, access_hash, version]) - location = types.InputDocumentFileLocation( id=id, access_hash=access_hash, version=version ) + file_name = str(MsgId()) limit = 1024 * 1024 offset = 0 @@ -1822,63 +1890,57 @@ class Client: cdn_session.start() try: - r2 = cdn_session.send( - functions.upload.GetCdnFile( - location=location, - file_token=r.file_token, - offset=offset, - limit=limit - ) - ) - - if isinstance(r2, types.upload.CdnFileReuploadNeeded): - session.send( - functions.upload.ReuploadCdnFile( - file_token=r.file_token, - request_token=r2.request_token + with open(file_name, "wb") as f: + while True: + r2 = cdn_session.send( + functions.upload.GetCdnFile( + location=location, + file_token=r.file_token, + offset=offset, + limit=limit + ) ) - ) - else: - with open(file_name, "wb") as f: - while True: - if not isinstance(r2, types.upload.CdnFile): + + if isinstance(r2, types.upload.CdnFileReuploadNeeded): + try: + session.send( + functions.upload.ReuploadCdnFile( + file_token=r.file_token, + request_token=r2.request_token + ) + ) + except VolumeLocNotFound: break + else: + continue - chunk = r2.bytes + chunk = r2.bytes - # https://core.telegram.org/cdn#decrypting-files - decrypted_chunk = AES.ctr_decrypt( - chunk, - r.encryption_key, - r.encryption_iv, + # https://core.telegram.org/cdn#decrypting-files + decrypted_chunk = AES.ctr_decrypt( + chunk, + r.encryption_key, + r.encryption_iv, + offset + ) + + hashes = session.send( + functions.upload.GetCdnFileHashes( + r.file_token, offset ) + ) - hashes = session.send( - functions.upload.GetCdnFileHashes( - r.file_token, - offset - ) - ) + # https://core.telegram.org/cdn#verifying-files + for i, h in enumerate(hashes): + cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] + assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) - for i, h in enumerate(hashes): - cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] - assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) + f.write(decrypted_chunk) + f.flush() + os.fsync(f.fileno()) - f.write(decrypted_chunk) - f.flush() - os.fsync(f.fileno()) - - offset += limit - - r2 = cdn_session.send( - functions.upload.GetCdnFile( - location=location, - file_token=r.file_token, - offset=offset, - limit=limit - ) - ) + offset += limit except Exception as e: log.error(e) finally: @@ -2238,3 +2300,8 @@ class Client: reply_to_msg_id=reply_to_message_id ) ) + + def download_media(self, message: types.Message): + done = Event() + self.download_queue.put((message, done)) + done.wait()