2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-22 09:57:19 +00:00

Limit the amount of concurrent transmissions

This commit is contained in:
Dan 2023-01-08 17:11:02 +01:00
parent b19764d5dc
commit 8441ce2f47
2 changed files with 302 additions and 315 deletions

View File

@ -44,7 +44,7 @@ from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import (
SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate,
AuthBytesInvalid, BadRequest
BadRequest
)
from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods
@ -266,6 +266,9 @@ class Client(Methods):
self.media_sessions = {}
self.media_sessions_lock = asyncio.Lock()
self.save_file_lock = asyncio.Lock()
self.get_file_lock = asyncio.Lock()
self.is_connected = None
self.is_initialized = None
@ -795,49 +798,7 @@ class Client(Methods):
progress: Callable = None,
progress_args: tuple = ()
) -> Optional[AsyncGenerator[bytes, None]]:
dc_id = file_id.dc_id
async with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
if session is None:
if dc_id != await self.storage.dc_id():
session = Session(
self, dc_id, await Auth(self, dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True
)
await session.start()
for _ in range(3):
exported_auth = await self.invoke(
raw.functions.auth.ExportAuthorization(
dc_id=dc_id
)
)
try:
await session.invoke(
raw.functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
)
)
except AuthBytesInvalid:
continue
else:
break
else:
await session.stop()
raise AuthBytesInvalid
else:
session = Session(
self, dc_id, await self.storage.auth_key(),
await self.storage.test_mode(), is_media=True
)
await session.start()
self.media_sessions[dc_id] = session
async with self.get_file_lock:
file_type = file_id.file_type
if file_type == FileType.CHAT_PHOTO:
@ -882,7 +843,34 @@ class Client(Methods):
chunk_size = 1024 * 1024
offset_bytes = abs(offset) * chunk_size
dc_id = file_id.dc_id
session = Session(
self, dc_id,
await Auth(self, dc_id, await self.storage.test_mode()).create()
if dc_id != await self.storage.dc_id()
else await self.storage.auth_key(),
await self.storage.test_mode(),
is_media=True
)
try:
await session.start()
if dc_id != await self.storage.dc_id():
exported_auth = await self.invoke(
raw.functions.auth.ExportAuthorization(
dc_id=dc_id
)
)
await session.invoke(
raw.functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
)
)
r = await session.invoke(
raw.functions.upload.GetFile(
location=location,
@ -929,20 +917,14 @@ class Client(Methods):
)
elif isinstance(r, raw.types.upload.FileCdnRedirect):
async with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None:
cdn_session = Session(
self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True, is_cdn=True
)
try:
await cdn_session.start()
self.media_sessions[r.dc_id] = cdn_session
try:
while True:
r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile(
@ -1014,10 +996,14 @@ class Client(Methods):
break
except Exception as e:
raise e
finally:
await cdn_session.stop()
except pyrogram.StopTransmission:
raise
except Exception as e:
log.exception(e)
finally:
await session.stop()
def guess_mime_type(self, filename: str) -> Optional[str]:
return self.mimetypes.guess_type(filename)[0]

View File

@ -94,6 +94,7 @@ class SaveFile:
Raises:
RPCError: In case of a Telegram RPC error.
"""
async with self.save_file_lock:
if path is None:
return None