From 19b1bbb94297a54ece9bb310ea274796e9121e8a Mon Sep 17 00:00:00 2001 From: Eric Blundell Date: Tue, 20 Mar 2018 07:04:35 -0500 Subject: [PATCH] Allow download_media to download media to anywhere Remove the use of a temporary file in the programs working directory. --- pyrogram/client/client.py | 180 ++++++++++++++++++++------------------ 1 file changed, 97 insertions(+), 83 deletions(-) diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 20ac58bc..a20ba521 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -55,6 +55,8 @@ from pyrogram.session.internals import MsgId from .input_media import InputMedia from .style import Markdown, HTML +from typing import Any + log = logging.getLogger(__name__) ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"]) @@ -509,7 +511,6 @@ class Client: try: media, file_name, done, progress, path = media - tmp_file_name = None if isinstance(media, types.MessageMediaDocument): document = media.document @@ -535,13 +536,14 @@ class Client: elif isinstance(i, types.DocumentAttributeAnimated): file_name = file_name.replace("doc", "gif") - tmp_file_name = self.get_file( + self.get_file( dc_id=document.dc_id, id=document.id, access_hash=document.access_hash, version=document.version, size=document.size, - progress=progress + progress=progress, + file_out=file_name ) elif isinstance(media, (types.MessageMediaPhoto, types.Photo)): if isinstance(media, types.MessageMediaPhoto): @@ -558,37 +560,23 @@ class Client: photo_loc = photo.sizes[-1].location - tmp_file_name = self.get_file( + self.get_file( dc_id=photo_loc.dc_id, volume_id=photo_loc.volume_id, local_id=photo_loc.local_id, secret=photo_loc.secret, size=photo.sizes[-1].size, - progress=progress + progress=progress, + file_out=file_name ) if file_name is not None: - path[0] = "downloads/{}".format(file_name) - - try: - os.remove("downloads/{}".format(file_name)) - except OSError: - pass - finally: - try: - os.renames("{}".format(tmp_file_name), "downloads/{}".format(file_name)) - except OSError: - pass + path[0] = file_name except Exception as e: log.error(e, exc_info=True) finally: done.set() - try: - os.remove("{}".format(tmp_file_name)) - except OSError: - pass - log.debug("{} stopped".format(name)) def updates_worker(self): @@ -2177,7 +2165,9 @@ class Client: secret: int = None, version: int = 0, size: int = None, - progress: callable = None) -> str: + progress: callable = None, + file_out: Any = None) -> str: + if dc_id != self.dc_id: exported_auth = self.send( functions.auth.ExportAuthorization( @@ -2225,10 +2215,13 @@ class Client: version=version ) - file_name = "download_{}.temp".format(MsgId()) limit = 1024 * 1024 offset = 0 + # file object being written + f = None + close_file, call_flush, call_fsync = False, False, False + try: r = session.send( functions.upload.GetFile( @@ -2238,30 +2231,49 @@ class Client: ) ) + if file_out is None: + f = open("download_{}.temp".format(MsgId(), 'wb')) + close_file = True + + elif isinstance(file_out, str): + f = open(file_out, 'wb') + elif hasattr(file_out, 'write'): + f = file_out + + if hasattr(file_out, 'flush'): + call_flush = True + if hasattr(file_out, 'fileno'): + call_fsync = True + else: + raise ValueError('file_out argument of client.get_file must at least implement a write method if not a ' + 'string.') + if isinstance(r, types.upload.File): - with open(file_name, "wb") as f: - while True: - chunk = r.bytes + while True: + chunk = r.bytes - if not chunk: - break + if not chunk: + break - f.write(chunk) + f.write(chunk) + + if call_flush: f.flush() + if call_fsync: os.fsync(f.fileno()) - offset += limit + offset += limit - if progress: - progress(min(offset, size), size) + if progress: + progress(min(offset, size), size) - r = session.send( - functions.upload.GetFile( - location=location, - offset=offset, - limit=limit - ) + r = session.send( + functions.upload.GetFile( + location=location, + offset=offset, + limit=limit ) + ) if isinstance(r, types.upload.FileCdnRedirect): cdn_session = Session( @@ -2276,63 +2288,65 @@ class Client: cdn_session.start() try: - with open(file_name, "wb") as f: - while True: - r2 = cdn_session.send( - functions.upload.GetCdnFile( - location=location, - file_token=r.file_token, - offset=offset, - limit=limit - ) + while True: + r2 = cdn_session.send( + functions.upload.GetCdnFile( + location=location, + file_token=r.file_token, + offset=offset, + limit=limit ) + ) - if isinstance(r2, types.upload.CdnFileReuploadNeeded): - try: - session.send( - functions.upload.ReuploadCdnFile( - file_token=r.file_token, - request_token=r2.request_token - ) + if isinstance(r2, types.upload.CdnFileReuploadNeeded): + try: + session.send( + functions.upload.ReuploadCdnFile( + file_token=r.file_token, + request_token=r2.request_token ) - except VolumeLocNotFound: - break - else: - continue + ) + except VolumeLocNotFound: + break + else: + continue - chunk = r2.bytes + chunk = r2.bytes - # https://core.telegram.org/cdn#decrypting-files - decrypted_chunk = AES.ctr_decrypt( - chunk, - r.encryption_key, - r.encryption_iv, + # https://core.telegram.org/cdn#decrypting-files + decrypted_chunk = AES.ctr_decrypt( + chunk, + r.encryption_key, + r.encryption_iv, + offset + ) + + hashes = session.send( + functions.upload.GetCdnFileHashes( + r.file_token, offset ) + ) - hashes = session.send( - functions.upload.GetCdnFileHashes( - r.file_token, - offset - ) - ) + # https://core.telegram.org/cdn#verifying-files + for i, h in enumerate(hashes): + cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] + assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) - # https://core.telegram.org/cdn#verifying-files - for i, h in enumerate(hashes): - cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] - assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) + f.write(decrypted_chunk) - f.write(decrypted_chunk) + if call_flush: f.flush() + if call_fsync: os.fsync(f.fileno()) - offset += limit + offset += limit - if progress: - progress(min(offset, size), size) + if progress: + progress(min(offset, size), size) - if len(chunk) < limit: - break + if len(chunk) < limit: + break except Exception as e: log.error(e) finally: @@ -2340,8 +2354,10 @@ class Client: except Exception as e: log.error(e) else: - return file_name + return file_out finally: + if close_file and f and hasattr(f, 'close'): + f.close() session.stop() def join_chat(self, chat_id: str): @@ -2602,8 +2618,6 @@ class Client: progress: callable = None): """Use this method to download the media from a Message. - Files are saved in the *downloads* folder. - Args: message (:obj:`Message `): The Message containing the media.