mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-22 09:57:19 +00:00
Limit the amount of concurrent transmissions
This commit is contained in:
parent
b19764d5dc
commit
8441ce2f47
@ -44,7 +44,7 @@ from pyrogram.errors import CDNFileHashMismatch
|
||||
from pyrogram.errors import (
|
||||
SessionPasswordNeeded,
|
||||
VolumeLocNotFound, ChannelPrivate,
|
||||
AuthBytesInvalid, BadRequest
|
||||
BadRequest
|
||||
)
|
||||
from pyrogram.handlers.handler import Handler
|
||||
from pyrogram.methods import Methods
|
||||
@ -266,6 +266,9 @@ class Client(Methods):
|
||||
self.media_sessions = {}
|
||||
self.media_sessions_lock = asyncio.Lock()
|
||||
|
||||
self.save_file_lock = asyncio.Lock()
|
||||
self.get_file_lock = asyncio.Lock()
|
||||
|
||||
self.is_connected = None
|
||||
self.is_initialized = None
|
||||
|
||||
@ -795,49 +798,7 @@ class Client(Methods):
|
||||
progress: Callable = None,
|
||||
progress_args: tuple = ()
|
||||
) -> Optional[AsyncGenerator[bytes, None]]:
|
||||
dc_id = file_id.dc_id
|
||||
|
||||
async with self.media_sessions_lock:
|
||||
session = self.media_sessions.get(dc_id, None)
|
||||
|
||||
if session is None:
|
||||
if dc_id != await self.storage.dc_id():
|
||||
session = Session(
|
||||
self, dc_id, await Auth(self, dc_id, await self.storage.test_mode()).create(),
|
||||
await self.storage.test_mode(), is_media=True
|
||||
)
|
||||
await session.start()
|
||||
|
||||
for _ in range(3):
|
||||
exported_auth = await self.invoke(
|
||||
raw.functions.auth.ExportAuthorization(
|
||||
dc_id=dc_id
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await session.invoke(
|
||||
raw.functions.auth.ImportAuthorization(
|
||||
id=exported_auth.id,
|
||||
bytes=exported_auth.bytes
|
||||
)
|
||||
)
|
||||
except AuthBytesInvalid:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
await session.stop()
|
||||
raise AuthBytesInvalid
|
||||
else:
|
||||
session = Session(
|
||||
self, dc_id, await self.storage.auth_key(),
|
||||
await self.storage.test_mode(), is_media=True
|
||||
)
|
||||
await session.start()
|
||||
|
||||
self.media_sessions[dc_id] = session
|
||||
|
||||
async with self.get_file_lock:
|
||||
file_type = file_id.file_type
|
||||
|
||||
if file_type == FileType.CHAT_PHOTO:
|
||||
@ -882,7 +843,34 @@ class Client(Methods):
|
||||
chunk_size = 1024 * 1024
|
||||
offset_bytes = abs(offset) * chunk_size
|
||||
|
||||
dc_id = file_id.dc_id
|
||||
|
||||
session = Session(
|
||||
self, dc_id,
|
||||
await Auth(self, dc_id, await self.storage.test_mode()).create()
|
||||
if dc_id != await self.storage.dc_id()
|
||||
else await self.storage.auth_key(),
|
||||
await self.storage.test_mode(),
|
||||
is_media=True
|
||||
)
|
||||
|
||||
try:
|
||||
await session.start()
|
||||
|
||||
if dc_id != await self.storage.dc_id():
|
||||
exported_auth = await self.invoke(
|
||||
raw.functions.auth.ExportAuthorization(
|
||||
dc_id=dc_id
|
||||
)
|
||||
)
|
||||
|
||||
await session.invoke(
|
||||
raw.functions.auth.ImportAuthorization(
|
||||
id=exported_auth.id,
|
||||
bytes=exported_auth.bytes
|
||||
)
|
||||
)
|
||||
|
||||
r = await session.invoke(
|
||||
raw.functions.upload.GetFile(
|
||||
location=location,
|
||||
@ -929,20 +917,14 @@ class Client(Methods):
|
||||
)
|
||||
|
||||
elif isinstance(r, raw.types.upload.FileCdnRedirect):
|
||||
async with self.media_sessions_lock:
|
||||
cdn_session = self.media_sessions.get(r.dc_id, None)
|
||||
|
||||
if cdn_session is None:
|
||||
cdn_session = Session(
|
||||
self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
|
||||
await self.storage.test_mode(), is_media=True, is_cdn=True
|
||||
)
|
||||
|
||||
try:
|
||||
await cdn_session.start()
|
||||
|
||||
self.media_sessions[r.dc_id] = cdn_session
|
||||
|
||||
try:
|
||||
while True:
|
||||
r2 = await cdn_session.invoke(
|
||||
raw.functions.upload.GetCdnFile(
|
||||
@ -1014,10 +996,14 @@ class Client(Methods):
|
||||
break
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
await cdn_session.stop()
|
||||
except pyrogram.StopTransmission:
|
||||
raise
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
finally:
|
||||
await session.stop()
|
||||
|
||||
def guess_mime_type(self, filename: str) -> Optional[str]:
|
||||
return self.mimetypes.guess_type(filename)[0]
|
||||
|
@ -94,6 +94,7 @@ class SaveFile:
|
||||
Raises:
|
||||
RPCError: In case of a Telegram RPC error.
|
||||
"""
|
||||
async with self.save_file_lock:
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user