2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-23 10:28:00 +00:00

Add support for media streams with the method stream_media

This commit is contained in:
Dan 2022-04-24 11:56:07 +02:00
parent b2c4d26ce6
commit 3e33ef0c0d
4 changed files with 130 additions and 40 deletions

View File

@ -187,6 +187,7 @@ def pyrogram_api():
search_global search_global
search_global_count search_global_count
download_media download_media
stream_media
get_discussion_message get_discussion_message
get_discussion_replies get_discussion_replies
get_discussion_replies_count get_discussion_replies_count

View File

@ -32,7 +32,7 @@ from importlib import import_module
from io import StringIO, BytesIO from io import StringIO, BytesIO
from mimetypes import MimeTypes from mimetypes import MimeTypes
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional, Callable, BinaryIO from typing import Union, List, Optional, Callable, AsyncGenerator
import pyrogram import pyrogram
from pyrogram import __version__, __license__ from pyrogram import __version__, __license__
@ -722,13 +722,10 @@ class Client(Methods):
async def handle_download(self, packet): async def handle_download(self, packet):
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
file = await self.get_file( file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False)
file_id=file_id,
file_size=file_size, async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args):
in_memory=in_memory, file.write(chunk)
progress=progress,
progress_args=progress_args
)
if file and not in_memory: if file and not in_memory:
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
@ -749,11 +746,12 @@ class Client(Methods):
async def get_file( async def get_file(
self, self,
file_id: FileId, file_id: FileId,
file_size: int, file_size: int = 0,
in_memory: bool, limit: int = 0,
progress: Callable, offset: int = 0,
progress: Callable = None,
progress_args: tuple = () progress_args: tuple = ()
) -> Optional[BinaryIO]: ) -> Optional[AsyncGenerator[bytes, None]]:
dc_id = file_id.dc_id dc_id = file_id.dc_id
async with self.media_sessions_lock: async with self.media_sessions_lock:
@ -836,17 +834,17 @@ class Client(Methods):
thumb_size=file_id.thumbnail_size thumb_size=file_id.thumbnail_size
) )
limit = 1024 * 1024 current = 0
offset = 0 total = abs(limit) or (1 << 31) - 1
chunk_size = 1024 * 1024
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb") offset_bytes = abs(offset) * chunk_size
try: try:
r = await session.invoke( r = await session.invoke(
raw.functions.upload.GetFile( raw.functions.upload.GetFile(
location=location, location=location,
offset=offset, offset=offset_bytes,
limit=limit limit=chunk_size
), ),
sleep_threshold=30 sleep_threshold=30
) )
@ -855,16 +853,17 @@ class Client(Methods):
while True: while True:
chunk = r.bytes chunk = r.bytes
file.write(chunk) yield chunk
offset += limit current += 1
offset_bytes += chunk_size
if progress: if progress:
func = functools.partial( func = functools.partial(
progress, progress,
min(offset, file_size) min(offset_bytes, file_size)
if file_size != 0 if file_size != 0
else offset, else offset_bytes,
file_size, file_size,
*progress_args *progress_args
) )
@ -874,14 +873,14 @@ class Client(Methods):
else: else:
await self.loop.run_in_executor(self.executor, func) await self.loop.run_in_executor(self.executor, func)
if len(chunk) < limit: if len(chunk) < chunk_size or current >= total:
break break
r = await session.invoke( r = await session.invoke(
raw.functions.upload.GetFile( raw.functions.upload.GetFile(
location=location, location=location,
offset=offset, offset=offset_bytes,
limit=limit limit=chunk_size
), ),
sleep_threshold=30 sleep_threshold=30
) )
@ -905,8 +904,8 @@ class Client(Methods):
r2 = await cdn_session.invoke( r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile( raw.functions.upload.GetCdnFile(
file_token=r.file_token, file_token=r.file_token,
offset=offset, offset=offset_bytes,
limit=limit limit=chunk_size
) )
) )
@ -931,14 +930,14 @@ class Client(Methods):
r.encryption_key, r.encryption_key,
bytearray( bytearray(
r.encryption_iv[:-4] r.encryption_iv[:-4]
+ (offset // 16).to_bytes(4, "big") + (offset_bytes // 16).to_bytes(4, "big")
) )
) )
hashes = await session.invoke( hashes = await session.invoke(
raw.functions.upload.GetCdnFileHashes( raw.functions.upload.GetCdnFileHashes(
file_token=r.file_token, file_token=r.file_token,
offset=offset offset=offset_bytes
) )
) )
@ -947,14 +946,15 @@ class Client(Methods):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest()) CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
file.write(decrypted_chunk) yield decrypted_chunk
offset += limit current += 1
offset_bytes += chunk_size
if progress: if progress:
func = functools.partial( func = functools.partial(
progress, progress,
min(offset, file_size) if file_size != 0 else offset, min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
file_size, file_size,
*progress_args *progress_args
) )
@ -964,7 +964,7 @@ class Client(Methods):
else: else:
await self.loop.run_in_executor(self.executor, func) await self.loop.run_in_executor(self.executor, func)
if len(chunk) < limit: if len(chunk) < chunk_size or current >= total:
break break
except Exception as e: except Exception as e:
raise e raise e
@ -972,12 +972,6 @@ class Client(Methods):
if not isinstance(e, pyrogram.StopTransmission): if not isinstance(e, pyrogram.StopTransmission):
log.error(e, exc_info=True) log.error(e, exc_info=True)
file.close()
return None
else:
return file
def guess_mime_type(self, filename: str) -> Optional[str]: def guess_mime_type(self, filename: str) -> Optional[str]:
return self.mimetypes.guess_type(filename)[0] return self.mimetypes.guess_type(filename)[0]

View File

@ -61,6 +61,7 @@ from .send_video import SendVideo
from .send_video_note import SendVideoNote from .send_video_note import SendVideoNote
from .send_voice import SendVoice from .send_voice import SendVoice
from .stop_poll import StopPoll from .stop_poll import StopPoll
from .stream_media import StreamMedia
from .vote_poll import VotePoll from .vote_poll import VotePoll
@ -110,6 +111,7 @@ class Messages(
GetDiscussionMessage, GetDiscussionMessage,
SendReaction, SendReaction,
GetDiscussionReplies, GetDiscussionReplies,
GetDiscussionRepliesCount GetDiscussionRepliesCount,
StreamMedia
): ):
pass pass

View File

@ -0,0 +1,93 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-present Dan <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Union, Optional, BinaryIO
import pyrogram
from pyrogram import types
from pyrogram.file_id import FileId
class StreamMedia:
async def stream_media(
self: "pyrogram.Client",
message: Union["types.Message", str],
limit: int = 0,
offset: int = 0
) -> Optional[Union[str, BinaryIO]]:
"""Stream the media from a message chunk by chunk.
The chunk size is 1 MiB (1024 * 1024 bytes).
Parameters:
message (:obj:`~pyrogram.types.Message` | ``str``):
Pass a Message containing the media, the media itself (message.audio, message.video, ...) or a file id
as string.
limit (``int``, *optional*):
Limit the amount of chunks to stream.
Defaults to 0 (stream the whole media).
offset (``int``, *optional*):
How many chunks to skip before starting to stream.
Defaults to 0 (start from the beginning).
Returns:
``Generator``: A generator yielding bytes chunk by chunk
Example:
.. code-block:: python
# Stream the whole media
async for chunk in app.stream_media(message):
print(len(chunk))
# Stream the first 3 chunks only
async for chunk in app.stream_media(message, limit=3):
print(len(chunk))
# Stream the last 3 chunks only
import math
chunks = math.ceil(message.document.file_size / 1024 / 1024)
async for chunk in app.stream_media(message, offset=chunks - 3):
print(len(chunk))
"""
available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note",
"new_chat_photo")
if isinstance(message, types.Message):
for kind in available_media:
media = getattr(message, kind, None)
if media is not None:
break
else:
raise ValueError("This message doesn't contain any downloadable media")
else:
media = message
if isinstance(media, str):
file_id_str = media
else:
file_id_str = media.file_id
file_id_obj = FileId.decode(file_id_str)
file_size = getattr(media, "file_size", 0)
async for chunk in self.get_file(file_id_obj, file_size, limit, offset):
yield chunk