2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-22 18:07:21 +00:00

Limit the amount of concurrent transmissions

This commit is contained in:
Dan 2023-01-08 17:11:02 +01:00
parent b19764d5dc
commit 8441ce2f47
2 changed files with 302 additions and 315 deletions

View File

@ -44,7 +44,7 @@ from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import ( from pyrogram.errors import (
SessionPasswordNeeded, SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate, VolumeLocNotFound, ChannelPrivate,
AuthBytesInvalid, BadRequest BadRequest
) )
from pyrogram.handlers.handler import Handler from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods from pyrogram.methods import Methods
@ -266,6 +266,9 @@ class Client(Methods):
self.media_sessions = {} self.media_sessions = {}
self.media_sessions_lock = asyncio.Lock() self.media_sessions_lock = asyncio.Lock()
self.save_file_lock = asyncio.Lock()
self.get_file_lock = asyncio.Lock()
self.is_connected = None self.is_connected = None
self.is_initialized = None self.is_initialized = None
@ -795,49 +798,7 @@ class Client(Methods):
progress: Callable = None, progress: Callable = None,
progress_args: tuple = () progress_args: tuple = ()
) -> Optional[AsyncGenerator[bytes, None]]: ) -> Optional[AsyncGenerator[bytes, None]]:
dc_id = file_id.dc_id async with self.get_file_lock:
async with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
if session is None:
if dc_id != await self.storage.dc_id():
session = Session(
self, dc_id, await Auth(self, dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True
)
await session.start()
for _ in range(3):
exported_auth = await self.invoke(
raw.functions.auth.ExportAuthorization(
dc_id=dc_id
)
)
try:
await session.invoke(
raw.functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
)
)
except AuthBytesInvalid:
continue
else:
break
else:
await session.stop()
raise AuthBytesInvalid
else:
session = Session(
self, dc_id, await self.storage.auth_key(),
await self.storage.test_mode(), is_media=True
)
await session.start()
self.media_sessions[dc_id] = session
file_type = file_id.file_type file_type = file_id.file_type
if file_type == FileType.CHAT_PHOTO: if file_type == FileType.CHAT_PHOTO:
@ -882,7 +843,34 @@ class Client(Methods):
chunk_size = 1024 * 1024 chunk_size = 1024 * 1024
offset_bytes = abs(offset) * chunk_size offset_bytes = abs(offset) * chunk_size
dc_id = file_id.dc_id
session = Session(
self, dc_id,
await Auth(self, dc_id, await self.storage.test_mode()).create()
if dc_id != await self.storage.dc_id()
else await self.storage.auth_key(),
await self.storage.test_mode(),
is_media=True
)
try: try:
await session.start()
if dc_id != await self.storage.dc_id():
exported_auth = await self.invoke(
raw.functions.auth.ExportAuthorization(
dc_id=dc_id
)
)
await session.invoke(
raw.functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
)
)
r = await session.invoke( r = await session.invoke(
raw.functions.upload.GetFile( raw.functions.upload.GetFile(
location=location, location=location,
@ -929,20 +917,14 @@ class Client(Methods):
) )
elif isinstance(r, raw.types.upload.FileCdnRedirect): elif isinstance(r, raw.types.upload.FileCdnRedirect):
async with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None:
cdn_session = Session( cdn_session = Session(
self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(), self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True, is_cdn=True await self.storage.test_mode(), is_media=True, is_cdn=True
) )
try:
await cdn_session.start() await cdn_session.start()
self.media_sessions[r.dc_id] = cdn_session
try:
while True: while True:
r2 = await cdn_session.invoke( r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile( raw.functions.upload.GetCdnFile(
@ -1014,10 +996,14 @@ class Client(Methods):
break break
except Exception as e: except Exception as e:
raise e raise e
finally:
await cdn_session.stop()
except pyrogram.StopTransmission: except pyrogram.StopTransmission:
raise raise
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
finally:
await session.stop()
def guess_mime_type(self, filename: str) -> Optional[str]: def guess_mime_type(self, filename: str) -> Optional[str]:
return self.mimetypes.guess_type(filename)[0] return self.mimetypes.guess_type(filename)[0]

View File

@ -94,6 +94,7 @@ class SaveFile:
Raises: Raises:
RPCError: In case of a Telegram RPC error. RPCError: In case of a Telegram RPC error.
""" """
async with self.save_file_lock:
if path is None: if path is None:
return None return None