diff --git a/compiler/docs/compiler.py b/compiler/docs/compiler.py index 9d1b7434..b5382438 100644 --- a/compiler/docs/compiler.py +++ b/compiler/docs/compiler.py @@ -187,6 +187,7 @@ def pyrogram_api(): search_global search_global_count download_media + stream_media get_discussion_message get_discussion_replies get_discussion_replies_count diff --git a/pyrogram/client.py b/pyrogram/client.py index f727658d..18e5253b 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -32,7 +32,7 @@ from importlib import import_module from io import StringIO, BytesIO from mimetypes import MimeTypes from pathlib import Path -from typing import Union, List, Optional, Callable, BinaryIO +from typing import Union, List, Optional, Callable, AsyncGenerator import pyrogram from pyrogram import __version__, __license__ @@ -722,13 +722,10 @@ class Client(Methods): async def handle_download(self, packet): file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet - file = await self.get_file( - file_id=file_id, - file_size=file_size, - in_memory=in_memory, - progress=progress, - progress_args=progress_args - ) + file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False) + + async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args): + file.write(chunk) if file and not in_memory: file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) @@ -749,11 +746,12 @@ class Client(Methods): async def get_file( self, file_id: FileId, - file_size: int, - in_memory: bool, - progress: Callable, + file_size: int = 0, + limit: int = 0, + offset: int = 0, + progress: Callable = None, progress_args: tuple = () - ) -> Optional[BinaryIO]: + ) -> Optional[AsyncGenerator[bytes, None]]: dc_id = file_id.dc_id async with self.media_sessions_lock: @@ -836,17 +834,17 @@ class Client(Methods): thumb_size=file_id.thumbnail_size ) - limit = 1024 * 1024 - offset = 0 - - file = BytesIO() if in_memory else 
tempfile.NamedTemporaryFile("wb") + current = 0 + total = abs(limit) or (1 << 31) - 1 + chunk_size = 1024 * 1024 + offset_bytes = abs(offset) * chunk_size try: r = await session.invoke( raw.functions.upload.GetFile( location=location, - offset=offset, - limit=limit + offset=offset_bytes, + limit=chunk_size ), sleep_threshold=30 ) @@ -855,16 +853,17 @@ class Client(Methods): while True: chunk = r.bytes - file.write(chunk) + yield chunk - offset += limit + current += 1 + offset_bytes += chunk_size if progress: func = functools.partial( progress, - min(offset, file_size) + min(offset_bytes, file_size) if file_size != 0 - else offset, + else offset_bytes, file_size, *progress_args ) @@ -874,14 +873,14 @@ class Client(Methods): else: await self.loop.run_in_executor(self.executor, func) - if len(chunk) < limit: + if len(chunk) < chunk_size or current >= total: break r = await session.invoke( raw.functions.upload.GetFile( location=location, - offset=offset, - limit=limit + offset=offset_bytes, + limit=chunk_size ), sleep_threshold=30 ) @@ -905,8 +904,8 @@ class Client(Methods): r2 = await cdn_session.invoke( raw.functions.upload.GetCdnFile( file_token=r.file_token, - offset=offset, - limit=limit + offset=offset_bytes, + limit=chunk_size ) ) @@ -931,14 +930,14 @@ class Client(Methods): r.encryption_key, bytearray( r.encryption_iv[:-4] - + (offset // 16).to_bytes(4, "big") + + (offset_bytes // 16).to_bytes(4, "big") ) ) hashes = await session.invoke( raw.functions.upload.GetCdnFileHashes( file_token=r.file_token, - offset=offset + offset=offset_bytes ) ) @@ -947,14 +946,15 @@ class Client(Methods): cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest()) - file.write(decrypted_chunk) + yield decrypted_chunk - offset += limit + current += 1 + offset_bytes += chunk_size if progress: func = functools.partial( progress, - min(offset, file_size) if file_size != 0 else offset, + min(offset_bytes, file_size) if 
file_size != 0 else offset_bytes, file_size, *progress_args ) @@ -964,7 +964,7 @@ class Client(Methods): else: await self.loop.run_in_executor(self.executor, func) - if len(chunk) < limit: + if len(chunk) < chunk_size or current >= total: break except Exception as e: raise e @@ -972,12 +972,6 @@ class Client(Methods): if not isinstance(e, pyrogram.StopTransmission): log.error(e, exc_info=True) - file.close() - - return None - else: - return file - def guess_mime_type(self, filename: str) -> Optional[str]: return self.mimetypes.guess_type(filename)[0] diff --git a/pyrogram/methods/messages/__init__.py b/pyrogram/methods/messages/__init__.py index 0bf34900..dafce11e 100644 --- a/pyrogram/methods/messages/__init__.py +++ b/pyrogram/methods/messages/__init__.py @@ -61,6 +61,7 @@ from .send_video import SendVideo from .send_video_note import SendVideoNote from .send_voice import SendVoice from .stop_poll import StopPoll +from .stream_media import StreamMedia from .vote_poll import VotePoll @@ -110,6 +111,7 @@ class Messages( GetDiscussionMessage, SendReaction, GetDiscussionReplies, - GetDiscussionRepliesCount + GetDiscussionRepliesCount, + StreamMedia ): pass diff --git a/pyrogram/methods/messages/stream_media.py b/pyrogram/methods/messages/stream_media.py new file mode 100644 index 00000000..0daaa556 --- /dev/null +++ b/pyrogram/methods/messages/stream_media.py @@ -0,0 +1,93 @@ +# Pyrogram - Telegram MTProto API Client Library for Python +# Copyright (C) 2017-present Dan +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
# See the GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram.  If not, see <http://www.gnu.org/licenses/>.

from typing import AsyncGenerator, BinaryIO, Optional, Union

import pyrogram
from pyrogram import types
from pyrogram.file_id import FileId


class StreamMedia:
    async def stream_media(
        self: "pyrogram.Client",
        message: Union["types.Message", str],
        limit: int = 0,
        offset: int = 0
    ) -> AsyncGenerator[bytes, None]:
        """Stream the media from a message chunk by chunk.

        The chunk size is 1 MiB (1024 * 1024 bytes).

        Parameters:
            message (:obj:`~pyrogram.types.Message` | ``str``):
                Pass a Message containing the media, the media itself (message.audio, message.video, ...) or a file id
                as string.

            limit (``int``, *optional*):
                Limit the amount of chunks to stream.
                Defaults to 0 (stream the whole media).

            offset (``int``, *optional*):
                How many chunks to skip before starting to stream.
                Defaults to 0 (start from the beginning).

        Returns:
            ``AsyncGenerator``: An async generator yielding ``bytes``, chunk by chunk. Iterate it with ``async for``.

        Raises:
            ValueError: In case a Message is passed and it doesn't contain any downloadable media.

        Example:
            .. code-block:: python

                # Stream the whole media
                async for chunk in app.stream_media(message):
                    print(len(chunk))

                # Stream the first 3 chunks only
                async for chunk in app.stream_media(message, limit=3):
                    print(len(chunk))

                # Stream the last 3 chunks only
                import math
                chunks = math.ceil(message.document.file_size / 1024 / 1024)
                async for chunk in app.stream_media(message, offset=chunks - 3):
                    print(len(chunk))
        """
        # Message attributes probed, in order, for a media object.
        available_media = (
            "audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
            "new_chat_photo"
        )

        if isinstance(message, types.Message):
            # Take the first media attribute that is set on the message.
            for kind in available_media:
                media = getattr(message, kind, None)

                if media is not None:
                    break
            else:
                # Loop finished without a break: no media attribute was set.
                raise ValueError("This message doesn't contain any downloadable media")
        else:
            # Caller passed the media object itself or a raw file id string.
            media = message

        file_id_str = media if isinstance(media, str) else media.file_id

        file_id_obj = FileId.decode(file_id_str)
        # A plain file id string carries no size information, hence the 0 fallback.
        file_size = getattr(media, "file_size", 0)

        # limit and offset are expressed in chunks and forwarded untouched;
        # the chunk arithmetic happens inside Client.get_file.
        async for chunk in self.get_file(file_id_obj, file_size, limit, offset):
            yield chunk