diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 765de03b..9ced004f 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -34,6 +34,11 @@ from hashlib import sha256, md5 from queue import Queue from signal import signal, SIGINT, SIGTERM, SIGABRT from threading import Event, Thread +import tempfile + +import shutil + +import errno from pyrogram.api import functions, types from pyrogram.api.core import Object @@ -55,8 +60,6 @@ from pyrogram.session.internals import MsgId from .input_media import InputMedia from .style import Markdown, HTML -from typing import Any - log = logging.getLogger(__name__) ApiKey = namedtuple("ApiKey", ["api_id", "api_hash"]) @@ -510,14 +513,18 @@ class Client: break try: - media, file_dir, file_name, done, progress, path = media + media, file_name, done, progress, path = media + tmp_file_name = None - if file_dir is not None: - # Make file_dir if it was specified - os.makedirs(file_dir, exist_ok=True) + download_directory = "downloads" - if isinstance(file_name, str) and file_name is not None: - os.makedirs(os.path.dirname(file_name), exist_ok=True) + if file_name.endswith('/') or file_name.endswith('\\'): + # treat the file name as a directory + download_directory = file_name + file_name = None + elif '/' in file_name or '\\' in file_name: + # use file_name as a full path instead + download_directory = '' if isinstance(media, types.MessageMediaDocument): document = media.document @@ -543,16 +550,13 @@ class Client: elif isinstance(i, types.DocumentAttributeAnimated): file_name = file_name.replace("doc", "gif") - file_name = os.path.join(file_dir if file_dir is not None else '', file_name) - - self.get_file( + tmp_file_name = self.get_file( dc_id=document.dc_id, id=document.id, access_hash=document.access_hash, version=document.version, size=document.size, - progress=progress, - file_out=file_name + progress=progress ) elif isinstance(media, (types.MessageMediaPhoto, types.Photo)): if isinstance(media, types.MessageMediaPhoto): @@ -567,27 +571,46 @@ class Client: self.rnd_id() ) - file_name = os.path.join(file_dir if file_dir is not None else '', file_name) - photo_loc = photo.sizes[-1].location - self.get_file( + tmp_file_name = self.get_file( dc_id=photo_loc.dc_id, volume_id=photo_loc.volume_id, local_id=photo_loc.local_id, secret=photo_loc.secret, size=photo.sizes[-1].size, - progress=progress, - file_out=file_name + progress=progress ) if file_name is not None: - path[0] = file_name + path[0] = os.path.join(download_directory, file_name) + + try: + os.remove(os.path.join(download_directory, file_name)) + except OSError: + pass + finally: + try: + if download_directory: + os.makedirs(download_directory, exist_ok=True) + else: + os.makedirs(os.path.dirname(file_name), exist_ok=True) + + # avoid errors moving between drives on windows + shutil.move(tmp_file_name, os.path.join(download_directory, file_name)) + except OSError as e: + log.error(e, exc_info=True) except Exception as e: log.error(e, exc_info=True) finally: done.set() + try: + os.remove(tmp_file_name) + except OSError as e: + if not e.errno == errno.ENOENT: + log.error(e, exc_info=True) + log.debug("{} stopped".format(name)) def updates_worker(self): @@ -2176,9 +2199,7 @@ class Client: secret: int = None, version: int = 0, size: int = None, - progress: callable = None, - file_out: Any = None) -> str: - + progress: callable = None) -> str: if dc_id != self.dc_id: exported_auth = self.send( functions.auth.ExportAuthorization( @@ -2226,13 +2247,11 @@ class Client: version=version ) + fd, file_name = tempfile.mkstemp() + limit = 1024 * 1024 offset = 0 - # file object being written - f = None - close_file, call_flush, call_fsync = False, False, False - try: r = session.send( functions.upload.GetFile( @@ -2242,51 +2261,30 @@ class Client: ) ) - if file_out is None: - f = open("download_{}.temp".format(MsgId(), 'wb')) - close_file = True - - elif isinstance(file_out, str): - f = open(file_out, 'wb') - close_file = True - - elif hasattr(file_out, 'write'): - f = file_out - - if hasattr(file_out, 'flush'): - call_flush = True - if hasattr(file_out, 'fileno'): - call_fsync = True - else: - raise ValueError('file_out argument of client.get_file must at least implement a write method if not a ' - 'string.') - if isinstance(r, types.upload.File): - while True: - chunk = r.bytes + with os.fdopen(fd, "wb") as f: + while True: + chunk = r.bytes - if not chunk: - break + if not chunk: + break - f.write(chunk) - - if call_flush: + f.write(chunk) f.flush() - if call_fsync: os.fsync(f.fileno()) - offset += limit + offset += limit - if progress: - progress(min(offset, size), size) + if progress: + progress(min(offset, size), size) - r = session.send( - functions.upload.GetFile( - location=location, - offset=offset, - limit=limit + r = session.send( + functions.upload.GetFile( + location=location, + offset=offset, + limit=limit + ) ) - ) if isinstance(r, types.upload.FileCdnRedirect): cdn_session = Session( @@ -2301,76 +2299,77 @@ class Client: cdn_session.start() try: - while True: - r2 = cdn_session.send( - functions.upload.GetCdnFile( - location=location, - file_token=r.file_token, - offset=offset, - limit=limit - ) - ) - - if isinstance(r2, types.upload.CdnFileReuploadNeeded): - try: - session.send( - functions.upload.ReuploadCdnFile( - file_token=r.file_token, - request_token=r2.request_token - ) + with os.fdopen(fd, "wb") as f: + while True: + r2 = cdn_session.send( + functions.upload.GetCdnFile( + location=location, + file_token=r.file_token, + offset=offset, + limit=limit ) - except VolumeLocNotFound: - break - else: - continue + ) - chunk = r2.bytes + if isinstance(r2, types.upload.CdnFileReuploadNeeded): + try: + session.send( + functions.upload.ReuploadCdnFile( + file_token=r.file_token, + request_token=r2.request_token + ) + ) + except VolumeLocNotFound: + break + else: + continue - # https://core.telegram.org/cdn#decrypting-files - decrypted_chunk = AES.ctr_decrypt( - chunk, - r.encryption_key, - r.encryption_iv, - offset - ) + chunk = r2.bytes - hashes = session.send( - functions.upload.GetCdnFileHashes( - r.file_token, + # https://core.telegram.org/cdn#decrypting-files + decrypted_chunk = AES.ctr_decrypt( + chunk, + r.encryption_key, + r.encryption_iv, offset ) - ) - # https://core.telegram.org/cdn#verifying-files - for i, h in enumerate(hashes): - cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] - assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) + hashes = session.send( + functions.upload.GetCdnFileHashes( + r.file_token, + offset + ) + ) - f.write(decrypted_chunk) + # https://core.telegram.org/cdn#verifying-files + for i, h in enumerate(hashes): + cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] + assert h.hash == sha256(cdn_chunk).digest(), "Invalid CDN hash part {}".format(i) - if call_flush: + f.write(decrypted_chunk) f.flush() - if call_fsync: os.fsync(f.fileno()) - offset += limit + offset += limit - if progress: - progress(min(offset, size), size) + if progress: + progress(min(offset, size), size) - if len(chunk) < limit: - break + if len(chunk) < limit: + break except Exception as e: - log.error(e) + raise e finally: cdn_session.stop() except Exception as e: - log.error(e) + log.error(e, exc_info=True) + + try: + os.remove(file_name) + except OSError: + pass else: - return file_out + return file_name finally: - if close_file and f is not None: - f.close() session.stop() def join_chat(self, chat_id: str): @@ -2627,27 +2626,18 @@ class Client: def download_media(self, message: types.Message, file_name: str = None, - file_dir: str = None, block: bool = True, - progress: callable = None - ): + progress: callable = None): """Use this method to download the media from a Message. + Files are saved in the *downloads* folder. + Args: message (:obj:`Message `): The Message containing the media. file_name (:obj:`str`, optional): Specify a custom *file_name* to be used instead of the one provided by Telegram. - This parameter is expected to be a full file path to the location you want the - file to be placed, or a file like object. If not specified, the file will - be put into the directory specified by *file_dir* with a generated name. - - file_dir (:obj:`str`, optional): - Specify a directory to place the file in if no *file_name* is specified. - If *file_dir* is *None*, the current working directory is used. The default - value is the "downloads" folder in the current working directory. The - directory tree will be created if it does not exist. block (:obj:`bool`, optional): Blocks the code execution until the file has been downloaded. @@ -2669,15 +2659,7 @@ class Client: Raises: :class:`pyrogram.Error` - :class:`ValueError` if both file_name and file_dir are specified. """ - - if file_name is not None and file_dir is not None: - raise ValueError('file_name and file_dir may not be specified together.') - - if file_name is None and file_dir is None: - file_dir = 'downloads' - if isinstance(message, (types.Message, types.Photo)): done = Event() path = [None] @@ -2688,7 +2670,7 @@ class Client: media = message if media is not None: - self.download_queue.put((media, file_dir, file_name, done, progress, path)) + self.download_queue.put((media, file_name, done, progress, path)) else: return @@ -2700,7 +2682,6 @@ class Client: def download_photo(self, photo: types.Photo or types.UserProfilePhoto or types.ChatPhoto, file_name: str = None, - file_dir: str = None, block: bool = True): """Use this method to download a photo not contained inside a Message. For example, a photo of a User or a Chat/Channel. @@ -2712,16 +2693,7 @@ class Client: The photo object. file_name (:obj:`str`, optional): - Specify a custom *file_name* to be used instead of the one provided by Telegram. - This parameter is expected to be a full file path to the location you want the - photo to be placed, or a file like object. If not specified, the photo will - be put into the directory specified by *file_dir* with a generated name. - - file_dir (:obj:`str`, optional): - Specify a directory to place the photo in if no *file_name* is specified. - If *file_dir* is *None*, the current working directory is used. The default - value is the "downloads" folder in the current working directory. The - directory tree will be created if it does not exist. + Specify a custom *file_name* to be used. block (:obj:`bool`, optional): Blocks the code execution until the photo has been downloaded. @@ -2747,7 +2719,7 @@ class Client: )] ) - return self.download_media(photo, file_name, file_dir, block) + return self.download_media(photo, file_name, block) def add_contacts(self, contacts: list): """Use this method to add contacts to your Telegram address book.