diff --git a/pyrogram/client.py b/pyrogram/client.py index 79168a93..f727658d 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -29,10 +29,10 @@ import tempfile from concurrent.futures.thread import ThreadPoolExecutor from hashlib import sha256 from importlib import import_module -from io import StringIO +from io import StringIO, BytesIO from mimetypes import MimeTypes from pathlib import Path -from typing import Union, List, Optional, Callable +from typing import Union, List, Optional, Callable, BinaryIO import pyrogram from pyrogram import __version__, __license__ @@ -482,34 +482,6 @@ class Client(Methods): return is_min - async def handle_download(self, packet): - temp_file_path = "" - final_file_path = "" - - try: - file_id, directory, file_name, file_size, progress, progress_args = packet - - temp_file_path = await self.get_file( - file_id=file_id, - file_size=file_size, - progress=progress, - progress_args=progress_args - ) - - if temp_file_path: - final_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) - os.makedirs(directory, exist_ok=True) - shutil.move(temp_file_path, final_file_path) - except Exception as e: - log.error(e, exc_info=True) - - try: - os.remove(temp_file_path) - except OSError: - pass - else: - return final_file_path or None - async def handle_updates(self, updates): if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)): is_min = (await self.fetch_peers(updates.users)) or (await self.fetch_peers(updates.chats)) @@ -747,13 +719,41 @@ class Client(Methods): else: log.warning(f'[{self.session_name}] No plugin loaded from "{root}"') + async def handle_download(self, packet): + file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet + + file = await self.get_file( + file_id=file_id, + file_size=file_size, + in_memory=in_memory, + progress=progress, + progress_args=progress_args + ) + + if file and not in_memory: + file_path = 
os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) + os.makedirs(directory, exist_ok=True) + shutil.move(file.name, file_path) + + try: + file.close() + except FileNotFoundError: + pass + + return file_path + + if file and in_memory: + file.name = file_name + return file + async def get_file( self, file_id: FileId, file_size: int, + in_memory: bool, progress: Callable, progress_args: tuple = () - ) -> str: + ) -> Optional[BinaryIO]: dc_id = file_id.dc_id async with self.media_sessions_lock: @@ -838,7 +838,8 @@ class Client(Methods): limit = 1024 * 1024 offset = 0 - file_name = "" + + file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb") try: r = await session.invoke( @@ -851,43 +852,40 @@ class Client(Methods): ) if isinstance(r, raw.types.upload.File): - with tempfile.NamedTemporaryFile("wb", delete=False) as f: - file_name = f.name + while True: + chunk = r.bytes - while True: - chunk = r.bytes + file.write(chunk) - f.write(chunk) + offset += limit - offset += limit - - if progress: - func = functools.partial( - progress, - min(offset, file_size) - if file_size != 0 - else offset, - file_size, - *progress_args - ) - - if inspect.iscoroutinefunction(progress): - await func() - else: - await self.loop.run_in_executor(self.executor, func) - - if len(chunk) < limit: - break - - r = await session.invoke( - raw.functions.upload.GetFile( - location=location, - offset=offset, - limit=limit - ), - sleep_threshold=30 + if progress: + func = functools.partial( + progress, + min(offset, file_size) + if file_size != 0 + else offset, + file_size, + *progress_args ) + if inspect.iscoroutinefunction(progress): + await func() + else: + await self.loop.run_in_executor(self.executor, func) + + if len(chunk) < limit: + break + + r = await session.invoke( + raw.functions.upload.GetFile( + location=location, + offset=offset, + limit=limit + ), + sleep_threshold=30 + ) + elif isinstance(r, raw.types.upload.FileCdnRedirect): async with 
self.media_sessions_lock: cdn_session = self.media_sessions.get(r.dc_id, None) @@ -903,88 +901,82 @@ class Client(Methods): self.media_sessions[r.dc_id] = cdn_session try: - with tempfile.NamedTemporaryFile("wb", delete=False) as f: - file_name = f.name - - while True: - r2 = await cdn_session.invoke( - raw.functions.upload.GetCdnFile( - file_token=r.file_token, - offset=offset, - limit=limit - ) + while True: + r2 = await cdn_session.invoke( + raw.functions.upload.GetCdnFile( + file_token=r.file_token, + offset=offset, + limit=limit ) + ) - if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded): - try: - await session.invoke( - raw.functions.upload.ReuploadCdnFile( - file_token=r.file_token, - request_token=r2.request_token - ) + if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded): + try: + await session.invoke( + raw.functions.upload.ReuploadCdnFile( + file_token=r.file_token, + request_token=r2.request_token ) - except VolumeLocNotFound: - break - else: - continue - - chunk = r2.bytes - - # https://core.telegram.org/cdn#decrypting-files - decrypted_chunk = aes.ctr256_decrypt( - chunk, - r.encryption_key, - bytearray( - r.encryption_iv[:-4] - + (offset // 16).to_bytes(4, "big") ) - ) - - hashes = await session.invoke( - raw.functions.upload.GetCdnFileHashes( - file_token=r.file_token, - offset=offset - ) - ) - - # https://core.telegram.org/cdn#verifying-files - for i, h in enumerate(hashes): - cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] - CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest()) - - f.write(decrypted_chunk) - - offset += limit - - if progress: - func = functools.partial( - progress, - min(offset, file_size) if file_size != 0 else offset, - file_size, - *progress_args - ) - - if inspect.iscoroutinefunction(progress): - await func() - else: - await self.loop.run_in_executor(self.executor, func) - - if len(chunk) < limit: + except VolumeLocNotFound: break + else: + continue + + chunk = r2.bytes + + # 
https://core.telegram.org/cdn#decrypting-files + decrypted_chunk = aes.ctr256_decrypt( + chunk, + r.encryption_key, + bytearray( + r.encryption_iv[:-4] + + (offset // 16).to_bytes(4, "big") + ) + ) + + hashes = await session.invoke( + raw.functions.upload.GetCdnFileHashes( + file_token=r.file_token, + offset=offset + ) + ) + + # https://core.telegram.org/cdn#verifying-files + for i, h in enumerate(hashes): + cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] + CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest()) + + file.write(decrypted_chunk) + + offset += limit + + if progress: + func = functools.partial( + progress, + min(offset, file_size) if file_size != 0 else offset, + file_size, + *progress_args + ) + + if inspect.iscoroutinefunction(progress): + await func() + else: + await self.loop.run_in_executor(self.executor, func) + + if len(chunk) < limit: + break except Exception as e: raise e except Exception as e: if not isinstance(e, pyrogram.StopTransmission): log.error(e, exc_info=True) - try: - os.remove(file_name) - except OSError: - pass + file.close() - return "" + return None else: - return file_name + return file def guess_mime_type(self, filename: str) -> Optional[str]: return self.mimetypes.guess_type(filename)[0] diff --git a/pyrogram/methods/messages/download_media.py b/pyrogram/methods/messages/download_media.py index 8b587d2f..d46bd503 100644 --- a/pyrogram/methods/messages/download_media.py +++ b/pyrogram/methods/messages/download_media.py @@ -18,9 +18,8 @@ import asyncio import os -import time from datetime import datetime -from typing import Union, Optional, Callable +from typing import Union, Optional, Callable, BinaryIO import pyrogram from pyrogram import types @@ -34,10 +33,11 @@ class DownloadMedia: self: "pyrogram.Client", message: Union["types.Message", str], file_name: str = DEFAULT_DOWNLOAD_DIR, + in_memory: bool = False, block: bool = True, progress: Callable = None, progress_args: tuple = () - ) -> Optional[str]: + 
) -> Optional[Union[str, BinaryIO]]: """Download the media from a message. Parameters: @@ -51,6 +51,11 @@ class DownloadMedia: You can also specify a path for downloading files in a custom location: paths that end with "/" are considered directories. All non-existent folders will be created automatically. + in_memory (``bool``, *optional*): + Pass True to download the media in-memory. + A binary file-like object with its attribute ".name" set will be returned. + Defaults to False. + block (``bool``, *optional*): Blocks the code execution until the file has been downloaded. Defaults to True. @@ -78,14 +83,17 @@ class DownloadMedia: You can either keep ``*args`` or add every single extra argument in your function signature. Returns: - ``str`` | ``None``: On success, the absolute path of the downloaded file is returned, otherwise, in case - the download failed or was deliberately stopped with :meth:`~pyrogram.Client.stop_transmission`, None is - returned. + ``str`` | ``BinaryIO`` | ``None``: On success, the absolute path of the downloaded file is returned or, + in case ``in_memory=True``, a binary file-like object with its attribute ".name" set is returned. + In case the download failed or was deliberately stopped with + :meth:`~pyrogram.Client.stop_transmission`, None is returned. Raises: ValueError: if the message doesn't contain any downloadable media Example: + Download media to file + .. code-block:: python # Download from Message @@ -99,6 +107,15 @@ class DownloadMedia: print(f"{current * 100 / total:.1f}%") await app.download_media(message, progress=progress) + + Download media in-memory + + .. 
code-block:: python + + file = await app.download_media(message, in_memory=True) + + file_name = file.name + file_bytes = bytes(file.getbuffer()) """ available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note", "new_chat_photo") @@ -125,7 +142,7 @@ class DownloadMedia: media_file_name = getattr(media, "file_name", "") file_size = getattr(media, "file_size", 0) mime_type = getattr(media, "mime_type", "") - date = getattr(media, "date", 0) + date = getattr(media, "date", None) directory, file_name = os.path.split(file_name) file_name = file_name or media_file_name or "" @@ -153,12 +170,14 @@ class DownloadMedia: file_name = "{}_{}_{}{}".format( FileType(file_id_obj.file_type).name.lower(), - datetime.fromtimestamp(date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"), + (date or datetime.now()).strftime("%Y-%m-%d_%H-%M-%S"), self.rnd_id(), extension ) - downloader = self.handle_download((file_id_obj, directory, file_name, file_size, progress, progress_args)) + downloader = self.handle_download( + (file_id_obj, directory, file_name, in_memory, file_size, progress, progress_args) + ) if block: return await downloader diff --git a/pyrogram/types/messages_and_media/message.py b/pyrogram/types/messages_and_media/message.py index ec5eed51..b7b26805 100644 --- a/pyrogram/types/messages_and_media/message.py +++ b/pyrogram/types/messages_and_media/message.py @@ -3329,6 +3329,7 @@ class Message(Object, Update): async def download( self, file_name: str = "", + in_memory: bool = False, block: bool = True, progress: Callable = None, progress_args: tuple = () @@ -3353,6 +3354,11 @@ class Message(Object, Update): You can also specify a path for downloading files in a custom location: paths that end with "/" are considered directories. All non-existent folders will be created automatically. + in_memory (``bool``, *optional*): + Pass True to download the media in-memory. 
+ A binary file-like object with its attribute ".name" set will be returned. + Defaults to False. + block (``bool``, *optional*): Blocks the code execution until the file has been downloaded. Defaults to True. @@ -3389,6 +3395,7 @@ class Message(Object, Update): return await self._client.download_media( message=self, file_name=file_name, + in_memory=in_memory, block=block, progress=progress, progress_args=progress_args,