2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-22 01:47:34 +00:00

Limit the amount of concurrent transmissions

This commit is contained in:
Dan 2023-01-08 17:11:02 +01:00
parent b19764d5dc
commit 8441ce2f47
2 changed files with 302 additions and 315 deletions

View File

@ -44,7 +44,7 @@ from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import (
SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate,
AuthBytesInvalid, BadRequest
BadRequest
)
from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods
@ -266,6 +266,9 @@ class Client(Methods):
self.media_sessions = {}
self.media_sessions_lock = asyncio.Lock()
self.save_file_lock = asyncio.Lock()
self.get_file_lock = asyncio.Lock()
self.is_connected = None
self.is_initialized = None
@ -795,204 +798,93 @@ class Client(Methods):
progress: Callable = None,
progress_args: tuple = ()
) -> Optional[AsyncGenerator[bytes, None]]:
dc_id = file_id.dc_id
async with self.get_file_lock:
file_type = file_id.file_type
async with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
if session is None:
if dc_id != await self.storage.dc_id():
session = Session(
self, dc_id, await Auth(self, dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True
)
await session.start()
for _ in range(3):
exported_auth = await self.invoke(
raw.functions.auth.ExportAuthorization(
dc_id=dc_id
)
)
try:
await session.invoke(
raw.functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
)
)
except AuthBytesInvalid:
continue
else:
break
else:
await session.stop()
raise AuthBytesInvalid
else:
session = Session(
self, dc_id, await self.storage.auth_key(),
await self.storage.test_mode(), is_media=True
)
await session.start()
self.media_sessions[dc_id] = session
file_type = file_id.file_type
if file_type == FileType.CHAT_PHOTO:
if file_id.chat_id > 0:
peer = raw.types.InputPeerUser(
user_id=file_id.chat_id,
access_hash=file_id.chat_access_hash
)
else:
if file_id.chat_access_hash == 0:
peer = raw.types.InputPeerChat(
chat_id=-file_id.chat_id
)
else:
peer = raw.types.InputPeerChannel(
channel_id=utils.get_channel_id(file_id.chat_id),
if file_type == FileType.CHAT_PHOTO:
if file_id.chat_id > 0:
peer = raw.types.InputPeerUser(
user_id=file_id.chat_id,
access_hash=file_id.chat_access_hash
)
location = raw.types.InputPeerPhotoFileLocation(
peer=peer,
photo_id=file_id.media_id,
big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
)
elif file_type == FileType.PHOTO:
location = raw.types.InputPhotoFileLocation(
id=file_id.media_id,
access_hash=file_id.access_hash,
file_reference=file_id.file_reference,
thumb_size=file_id.thumbnail_size
)
else:
location = raw.types.InputDocumentFileLocation(
id=file_id.media_id,
access_hash=file_id.access_hash,
file_reference=file_id.file_reference,
thumb_size=file_id.thumbnail_size
)
current = 0
total = abs(limit) or (1 << 31) - 1
chunk_size = 1024 * 1024
offset_bytes = abs(offset) * chunk_size
try:
r = await session.invoke(
raw.functions.upload.GetFile(
location=location,
offset=offset_bytes,
limit=chunk_size
),
sleep_threshold=30
)
if isinstance(r, raw.types.upload.File):
while True:
chunk = r.bytes
yield chunk
current += 1
offset_bytes += chunk_size
if progress:
func = functools.partial(
progress,
min(offset_bytes, file_size)
if file_size != 0
else offset_bytes,
file_size,
*progress_args
else:
if file_id.chat_access_hash == 0:
peer = raw.types.InputPeerChat(
chat_id=-file_id.chat_id
)
else:
peer = raw.types.InputPeerChannel(
channel_id=utils.get_channel_id(file_id.chat_id),
access_hash=file_id.chat_access_hash
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
location = raw.types.InputPeerPhotoFileLocation(
peer=peer,
photo_id=file_id.media_id,
big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
)
elif file_type == FileType.PHOTO:
location = raw.types.InputPhotoFileLocation(
id=file_id.media_id,
access_hash=file_id.access_hash,
file_reference=file_id.file_reference,
thumb_size=file_id.thumbnail_size
)
else:
location = raw.types.InputDocumentFileLocation(
id=file_id.media_id,
access_hash=file_id.access_hash,
file_reference=file_id.file_reference,
thumb_size=file_id.thumbnail_size
)
if len(chunk) < chunk_size or current >= total:
break
current = 0
total = abs(limit) or (1 << 31) - 1
chunk_size = 1024 * 1024
offset_bytes = abs(offset) * chunk_size
r = await session.invoke(
raw.functions.upload.GetFile(
location=location,
offset=offset_bytes,
limit=chunk_size
),
sleep_threshold=30
dc_id = file_id.dc_id
session = Session(
self, dc_id,
await Auth(self, dc_id, await self.storage.test_mode()).create()
if dc_id != await self.storage.dc_id()
else await self.storage.auth_key(),
await self.storage.test_mode(),
is_media=True
)
try:
await session.start()
if dc_id != await self.storage.dc_id():
exported_auth = await self.invoke(
raw.functions.auth.ExportAuthorization(
dc_id=dc_id
)
)
elif isinstance(r, raw.types.upload.FileCdnRedirect):
async with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None:
cdn_session = Session(
self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True, is_cdn=True
await session.invoke(
raw.functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
)
)
await cdn_session.start()
r = await session.invoke(
raw.functions.upload.GetFile(
location=location,
offset=offset_bytes,
limit=chunk_size
),
sleep_threshold=30
)
self.media_sessions[r.dc_id] = cdn_session
try:
if isinstance(r, raw.types.upload.File):
while True:
r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset_bytes,
limit=chunk_size
)
)
chunk = r.bytes
if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
try:
await session.invoke(
raw.functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
)
except VolumeLocNotFound:
break
else:
continue
chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = aes.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset_bytes // 16).to_bytes(4, "big")
)
)
hashes = await session.invoke(
raw.functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset_bytes
)
)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
CDNFileHashMismatch.check(
h.hash == sha256(cdn_chunk).digest(),
"h.hash == sha256(cdn_chunk).digest()"
)
yield decrypted_chunk
yield chunk
current += 1
offset_bytes += chunk_size
@ -1000,7 +892,9 @@ class Client(Methods):
if progress:
func = functools.partial(
progress,
min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
min(offset_bytes, file_size)
if file_size != 0
else offset_bytes,
file_size,
*progress_args
)
@ -1012,12 +906,104 @@ class Client(Methods):
if len(chunk) < chunk_size or current >= total:
break
except Exception as e:
raise e
except pyrogram.StopTransmission:
raise
except Exception as e:
log.exception(e)
r = await session.invoke(
raw.functions.upload.GetFile(
location=location,
offset=offset_bytes,
limit=chunk_size
),
sleep_threshold=30
)
elif isinstance(r, raw.types.upload.FileCdnRedirect):
cdn_session = Session(
self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True, is_cdn=True
)
try:
await cdn_session.start()
while True:
r2 = await cdn_session.invoke(
raw.functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset_bytes,
limit=chunk_size
)
)
if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
try:
await session.invoke(
raw.functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
)
)
except VolumeLocNotFound:
break
else:
continue
chunk = r2.bytes
# https://core.telegram.org/cdn#decrypting-files
decrypted_chunk = aes.ctr256_decrypt(
chunk,
r.encryption_key,
bytearray(
r.encryption_iv[:-4]
+ (offset_bytes // 16).to_bytes(4, "big")
)
)
hashes = await session.invoke(
raw.functions.upload.GetCdnFileHashes(
file_token=r.file_token,
offset=offset_bytes
)
)
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
CDNFileHashMismatch.check(
h.hash == sha256(cdn_chunk).digest(),
"h.hash == sha256(cdn_chunk).digest()"
)
yield decrypted_chunk
current += 1
offset_bytes += chunk_size
if progress:
func = functools.partial(
progress,
min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
if len(chunk) < chunk_size or current >= total:
break
except Exception as e:
raise e
finally:
await cdn_session.stop()
except pyrogram.StopTransmission:
raise
except Exception as e:
log.exception(e)
finally:
await session.stop()
def guess_mime_type(self, filename: str) -> Optional[str]:
    """Return the MIME type guessed for *filename*.

    The lookup is delegated to ``self.mimetypes.guess_type``; only the
    first element of its result is returned.

    Parameters:
        filename: File name or path whose type should be guessed.

    Returns:
        The first item of the guess result (presumably ``None`` when the
        type cannot be determined, as the ``Optional[str]`` annotation
        suggests).
    """
    guess = self.mimetypes.guess_type(filename)
    return guess[0]

View File

@ -94,132 +94,133 @@ class SaveFile:
Raises:
RPCError: In case of a Telegram RPC error.
"""
if path is None:
return None
async with self.save_file_lock:
if path is None:
return None
async def worker(session):
while True:
data = await queue.get()
async def worker(session):
while True:
data = await queue.get()
if data is None:
return
if data is None:
return
try:
await session.invoke(data)
except Exception as e:
log.exception(e)
try:
await session.invoke(data)
except Exception as e:
log.exception(e)
part_size = 512 * 1024
if isinstance(path, (str, PurePath)):
fp = open(path, "rb")
elif isinstance(path, io.IOBase):
fp = path
else:
raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer")
file_name = getattr(fp, "name", "file.jpg")
fp.seek(0, os.SEEK_END)
file_size = fp.tell()
fp.seek(0)
if file_size == 0:
raise ValueError("File size equals to 0 B")
file_size_limit_mib = 4000 if self.me.is_premium else 2000
if file_size > file_size_limit_mib * 1024 * 1024:
raise ValueError(f"Can't upload files bigger than {file_size_limit_mib} MiB")
file_total_parts = int(math.ceil(file_size / part_size))
is_big = file_size > 10 * 1024 * 1024
workers_count = 4 if is_big else 1
is_missing_part = file_id is not None
file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(
self, await self.storage.dc_id(), await self.storage.auth_key(),
await self.storage.test_mode(), is_media=True
)
workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)]
queue = asyncio.Queue(1)
try:
await session.start()
fp.seek(part_size * file_part)
while True:
chunk = fp.read(part_size)
if not chunk:
if not is_big and not is_missing_part:
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break
if is_big:
rpc = raw.functions.upload.SaveBigFilePart(
file_id=file_id,
file_part=file_part,
file_total_parts=file_total_parts,
bytes=chunk
)
else:
rpc = raw.functions.upload.SaveFilePart(
file_id=file_id,
file_part=file_part,
bytes=chunk
)
await queue.put(rpc)
if is_missing_part:
return
if not is_big and not is_missing_part:
md5_sum.update(chunk)
file_part += 1
if progress:
func = functools.partial(
progress,
min(file_part * part_size, file_size),
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
except StopTransmission:
raise
except Exception as e:
log.exception(e)
else:
if is_big:
return raw.types.InputFileBig(
id=file_id,
parts=file_total_parts,
name=file_name,
)
else:
return raw.types.InputFile(
id=file_id,
parts=file_total_parts,
name=file_name,
md5_checksum=md5_sum
)
finally:
for _ in workers:
await queue.put(None)
await asyncio.gather(*workers)
await session.stop()
part_size = 512 * 1024
if isinstance(path, (str, PurePath)):
fp.close()
fp = open(path, "rb")
elif isinstance(path, io.IOBase):
fp = path
else:
raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer")
file_name = getattr(fp, "name", "file.jpg")
fp.seek(0, os.SEEK_END)
file_size = fp.tell()
fp.seek(0)
if file_size == 0:
raise ValueError("File size equals to 0 B")
file_size_limit_mib = 4000 if self.me.is_premium else 2000
if file_size > file_size_limit_mib * 1024 * 1024:
raise ValueError(f"Can't upload files bigger than {file_size_limit_mib} MiB")
file_total_parts = int(math.ceil(file_size / part_size))
is_big = file_size > 10 * 1024 * 1024
workers_count = 4 if is_big else 1
is_missing_part = file_id is not None
file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(
self, await self.storage.dc_id(), await self.storage.auth_key(),
await self.storage.test_mode(), is_media=True
)
workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)]
queue = asyncio.Queue(1)
try:
await session.start()
fp.seek(part_size * file_part)
while True:
chunk = fp.read(part_size)
if not chunk:
if not is_big and not is_missing_part:
md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
break
if is_big:
rpc = raw.functions.upload.SaveBigFilePart(
file_id=file_id,
file_part=file_part,
file_total_parts=file_total_parts,
bytes=chunk
)
else:
rpc = raw.functions.upload.SaveFilePart(
file_id=file_id,
file_part=file_part,
bytes=chunk
)
await queue.put(rpc)
if is_missing_part:
return
if not is_big and not is_missing_part:
md5_sum.update(chunk)
file_part += 1
if progress:
func = functools.partial(
progress,
min(file_part * part_size, file_size),
file_size,
*progress_args
)
if inspect.iscoroutinefunction(progress):
await func()
else:
await self.loop.run_in_executor(self.executor, func)
except StopTransmission:
raise
except Exception as e:
log.exception(e)
else:
if is_big:
return raw.types.InputFileBig(
id=file_id,
parts=file_total_parts,
name=file_name,
)
else:
return raw.types.InputFile(
id=file_id,
parts=file_total_parts,
name=file_name,
md5_checksum=md5_sum
)
finally:
for _ in workers:
await queue.put(None)
await asyncio.gather(*workers)
await session.stop()
if isinstance(path, (str, PurePath)):
fp.close()