Limit the amount of concurrent transmissions

commit 8441ce2f47 (parent b19764d5dc)
mirror of https://github.com/pyrogram/pyrogram
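In short: transfers are now serialized. Client gains two dedicated locks, get_file_lock and save_file_lock; get_file holds the former for the whole download and save_file holds the latter for the whole upload. get_file also stops caching per-DC media sessions: it now starts a fresh media Session for each download (importing an exported authorization when the target DC differs) and always stops it in a finally block, doing the same for the temporary CDN session on a FileCdnRedirect. The AuthBytesInvalid retry loop around ImportAuthorization is dropped along with its import.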
@@ -44,7 +44,7 @@ from pyrogram.errors import CDNFileHashMismatch
 from pyrogram.errors import (
     SessionPasswordNeeded,
     VolumeLocNotFound, ChannelPrivate,
-    AuthBytesInvalid, BadRequest
+    BadRequest
 )
 from pyrogram.handlers.handler import Handler
 from pyrogram.methods import Methods
@@ -266,6 +266,9 @@ class Client(Methods):
         self.media_sessions = {}
         self.media_sessions_lock = asyncio.Lock()
 
+        self.save_file_lock = asyncio.Lock()
+        self.get_file_lock = asyncio.Lock()
+
         self.is_connected = None
         self.is_initialized = None
 
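The two new locks are plain asyncio.Lock objects, so "limiting" here means full serialization: a second download (or upload) requested while one is running simply waits its turn. A minimal standalone sketch of the mechanism, outside of pyrogram (hypothetical names; an asyncio.Semaphore(n) in the same position would allow up to n concurrent transmissions instead):

import asyncio

async def transfer(lock: asyncio.Lock, name: str, chunks: int) -> None:
    # Concurrent callers queue here until the previous transfer finishes.
    async with lock:
        for i in range(chunks):
            await asyncio.sleep(0.1)  # stand-in for one 1 MiB GetFile round trip
            print(f"{name}: chunk {i + 1}/{chunks}")

async def main() -> None:
    transmission_lock = asyncio.Lock()  # plays the role of get_file_lock
    # Both transfers start together, but run strictly one after the other.
    await asyncio.gather(
        transfer(transmission_lock, "a", 3),
        transfer(transmission_lock, "b", 3),
    )

asyncio.run(main())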
@@ -795,204 +798,93 @@ class Client(Methods):
         progress: Callable = None,
         progress_args: tuple = ()
     ) -> Optional[AsyncGenerator[bytes, None]]:
-        dc_id = file_id.dc_id
-
-        async with self.media_sessions_lock:
-            session = self.media_sessions.get(dc_id, None)
-
-            if session is None:
-                if dc_id != await self.storage.dc_id():
-                    session = Session(
-                        self, dc_id, await Auth(self, dc_id, await self.storage.test_mode()).create(),
-                        await self.storage.test_mode(), is_media=True
-                    )
-                    await session.start()
-
-                    for _ in range(3):
-                        exported_auth = await self.invoke(
-                            raw.functions.auth.ExportAuthorization(
-                                dc_id=dc_id
-                            )
-                        )
-
-                        try:
-                            await session.invoke(
-                                raw.functions.auth.ImportAuthorization(
-                                    id=exported_auth.id,
-                                    bytes=exported_auth.bytes
-                                )
-                            )
-                        except AuthBytesInvalid:
-                            continue
-                        else:
-                            break
-                    else:
-                        await session.stop()
-                        raise AuthBytesInvalid
-                else:
-                    session = Session(
-                        self, dc_id, await self.storage.auth_key(),
-                        await self.storage.test_mode(), is_media=True
-                    )
-                    await session.start()
-
-                self.media_sessions[dc_id] = session
-
-        file_type = file_id.file_type
-
-        if file_type == FileType.CHAT_PHOTO:
-            if file_id.chat_id > 0:
-                peer = raw.types.InputPeerUser(
-                    user_id=file_id.chat_id,
-                    access_hash=file_id.chat_access_hash
-                )
-            else:
-                if file_id.chat_access_hash == 0:
-                    peer = raw.types.InputPeerChat(
-                        chat_id=-file_id.chat_id
-                    )
-                else:
-                    peer = raw.types.InputPeerChannel(
-                        channel_id=utils.get_channel_id(file_id.chat_id),
-                        access_hash=file_id.chat_access_hash
-                    )
-
-            location = raw.types.InputPeerPhotoFileLocation(
-                peer=peer,
-                photo_id=file_id.media_id,
-                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
-            )
-        elif file_type == FileType.PHOTO:
-            location = raw.types.InputPhotoFileLocation(
-                id=file_id.media_id,
-                access_hash=file_id.access_hash,
-                file_reference=file_id.file_reference,
-                thumb_size=file_id.thumbnail_size
-            )
-        else:
-            location = raw.types.InputDocumentFileLocation(
-                id=file_id.media_id,
-                access_hash=file_id.access_hash,
-                file_reference=file_id.file_reference,
-                thumb_size=file_id.thumbnail_size
-            )
-
-        current = 0
-        total = abs(limit) or (1 << 31) - 1
-        chunk_size = 1024 * 1024
-        offset_bytes = abs(offset) * chunk_size
-
-        try:
-            r = await session.invoke(
-                raw.functions.upload.GetFile(
-                    location=location,
-                    offset=offset_bytes,
-                    limit=chunk_size
-                ),
-                sleep_threshold=30
-            )
-
-            if isinstance(r, raw.types.upload.File):
-                while True:
-                    chunk = r.bytes
-
-                    yield chunk
-
-                    current += 1
-                    offset_bytes += chunk_size
-
-                    if progress:
-                        func = functools.partial(
-                            progress,
-                            min(offset_bytes, file_size)
-                            if file_size != 0
-                            else offset_bytes,
-                            file_size,
-                            *progress_args
-                        )
-
-                        if inspect.iscoroutinefunction(progress):
-                            await func()
-                        else:
-                            await self.loop.run_in_executor(self.executor, func)
-
-                    if len(chunk) < chunk_size or current >= total:
-                        break
-
-                    r = await session.invoke(
-                        raw.functions.upload.GetFile(
-                            location=location,
-                            offset=offset_bytes,
-                            limit=chunk_size
-                        ),
-                        sleep_threshold=30
-                    )
-
-            elif isinstance(r, raw.types.upload.FileCdnRedirect):
-                async with self.media_sessions_lock:
-                    cdn_session = self.media_sessions.get(r.dc_id, None)
-
-                    if cdn_session is None:
-                        cdn_session = Session(
-                            self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
-                            await self.storage.test_mode(), is_media=True, is_cdn=True
-                        )
-
-                        await cdn_session.start()
-
-                        self.media_sessions[r.dc_id] = cdn_session
-
-                try:
-                    while True:
-                        r2 = await cdn_session.invoke(
-                            raw.functions.upload.GetCdnFile(
-                                file_token=r.file_token,
-                                offset=offset_bytes,
-                                limit=chunk_size
-                            )
-                        )
-
-                        if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
-                            try:
-                                await session.invoke(
-                                    raw.functions.upload.ReuploadCdnFile(
-                                        file_token=r.file_token,
-                                        request_token=r2.request_token
-                                    )
-                                )
-                            except VolumeLocNotFound:
-                                break
-                            else:
-                                continue
-
-                        chunk = r2.bytes
-
-                        # https://core.telegram.org/cdn#decrypting-files
-                        decrypted_chunk = aes.ctr256_decrypt(
-                            chunk,
-                            r.encryption_key,
-                            bytearray(
-                                r.encryption_iv[:-4]
-                                + (offset_bytes // 16).to_bytes(4, "big")
-                            )
-                        )
-
-                        hashes = await session.invoke(
-                            raw.functions.upload.GetCdnFileHashes(
-                                file_token=r.file_token,
-                                offset=offset_bytes
-                            )
-                        )
-
-                        # https://core.telegram.org/cdn#verifying-files
-                        for i, h in enumerate(hashes):
-                            cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
-                            CDNFileHashMismatch.check(
-                                h.hash == sha256(cdn_chunk).digest(),
-                                "h.hash == sha256(cdn_chunk).digest()"
-                            )
-
-                        yield decrypted_chunk
-
+        async with self.get_file_lock:
+            file_type = file_id.file_type
+
+            if file_type == FileType.CHAT_PHOTO:
+                if file_id.chat_id > 0:
+                    peer = raw.types.InputPeerUser(
+                        user_id=file_id.chat_id,
+                        access_hash=file_id.chat_access_hash
+                    )
+                else:
+                    if file_id.chat_access_hash == 0:
+                        peer = raw.types.InputPeerChat(
+                            chat_id=-file_id.chat_id
+                        )
+                    else:
+                        peer = raw.types.InputPeerChannel(
+                            channel_id=utils.get_channel_id(file_id.chat_id),
+                            access_hash=file_id.chat_access_hash
+                        )
+
+                location = raw.types.InputPeerPhotoFileLocation(
+                    peer=peer,
+                    photo_id=file_id.media_id,
+                    big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
+                )
+            elif file_type == FileType.PHOTO:
+                location = raw.types.InputPhotoFileLocation(
+                    id=file_id.media_id,
+                    access_hash=file_id.access_hash,
+                    file_reference=file_id.file_reference,
+                    thumb_size=file_id.thumbnail_size
+                )
+            else:
+                location = raw.types.InputDocumentFileLocation(
+                    id=file_id.media_id,
+                    access_hash=file_id.access_hash,
+                    file_reference=file_id.file_reference,
+                    thumb_size=file_id.thumbnail_size
+                )
+
+            current = 0
+            total = abs(limit) or (1 << 31) - 1
+            chunk_size = 1024 * 1024
+            offset_bytes = abs(offset) * chunk_size
+
+            dc_id = file_id.dc_id
+
+            session = Session(
+                self, dc_id,
+                await Auth(self, dc_id, await self.storage.test_mode()).create()
+                if dc_id != await self.storage.dc_id()
+                else await self.storage.auth_key(),
+                await self.storage.test_mode(),
+                is_media=True
+            )
+
+            try:
+                await session.start()
+
+                if dc_id != await self.storage.dc_id():
+                    exported_auth = await self.invoke(
+                        raw.functions.auth.ExportAuthorization(
+                            dc_id=dc_id
+                        )
+                    )
+
+                    await session.invoke(
+                        raw.functions.auth.ImportAuthorization(
+                            id=exported_auth.id,
+                            bytes=exported_auth.bytes
+                        )
+                    )
+
+                r = await session.invoke(
+                    raw.functions.upload.GetFile(
+                        location=location,
+                        offset=offset_bytes,
+                        limit=chunk_size
+                    ),
+                    sleep_threshold=30
+                )
+
+                if isinstance(r, raw.types.upload.File):
+                    while True:
+                        chunk = r.bytes
+
+                        yield chunk
+
                         current += 1
                         offset_bytes += chunk_size
@@ -1000,7 +892,9 @@ class Client(Methods):
                         if progress:
                             func = functools.partial(
                                 progress,
-                                min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
+                                min(offset_bytes, file_size)
+                                if file_size != 0
+                                else offset_bytes,
                                 file_size,
                                 *progress_args
                             )
@@ -1012,12 +906,104 @@ class Client(Methods):
 
-                        if len(chunk) < chunk_size or current >= total:
-                            break
-                except Exception as e:
-                    raise e
-        except pyrogram.StopTransmission:
-            raise
-        except Exception as e:
-            log.exception(e)
+                        if len(chunk) < chunk_size or current >= total:
+                            break
+
+                        r = await session.invoke(
+                            raw.functions.upload.GetFile(
+                                location=location,
+                                offset=offset_bytes,
+                                limit=chunk_size
+                            ),
+                            sleep_threshold=30
+                        )
+
+                elif isinstance(r, raw.types.upload.FileCdnRedirect):
+                    cdn_session = Session(
+                        self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
+                        await self.storage.test_mode(), is_media=True, is_cdn=True
+                    )
+
+                    try:
+                        await cdn_session.start()
+
+                        while True:
+                            r2 = await cdn_session.invoke(
+                                raw.functions.upload.GetCdnFile(
+                                    file_token=r.file_token,
+                                    offset=offset_bytes,
+                                    limit=chunk_size
+                                )
+                            )
+
+                            if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
+                                try:
+                                    await session.invoke(
+                                        raw.functions.upload.ReuploadCdnFile(
+                                            file_token=r.file_token,
+                                            request_token=r2.request_token
+                                        )
+                                    )
+                                except VolumeLocNotFound:
+                                    break
+                                else:
+                                    continue
+
+                            chunk = r2.bytes
+
+                            # https://core.telegram.org/cdn#decrypting-files
+                            decrypted_chunk = aes.ctr256_decrypt(
+                                chunk,
+                                r.encryption_key,
+                                bytearray(
+                                    r.encryption_iv[:-4]
+                                    + (offset_bytes // 16).to_bytes(4, "big")
+                                )
+                            )
+
+                            hashes = await session.invoke(
+                                raw.functions.upload.GetCdnFileHashes(
+                                    file_token=r.file_token,
+                                    offset=offset_bytes
+                                )
+                            )
+
+                            # https://core.telegram.org/cdn#verifying-files
+                            for i, h in enumerate(hashes):
+                                cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
+                                CDNFileHashMismatch.check(
+                                    h.hash == sha256(cdn_chunk).digest(),
+                                    "h.hash == sha256(cdn_chunk).digest()"
+                                )
+
+                            yield decrypted_chunk
+
+                            current += 1
+                            offset_bytes += chunk_size
+
+                            if progress:
+                                func = functools.partial(
+                                    progress,
+                                    min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
+                                    file_size,
+                                    *progress_args
+                                )
+
+                                if inspect.iscoroutinefunction(progress):
+                                    await func()
+                                else:
+                                    await self.loop.run_in_executor(self.executor, func)
+
+                            if len(chunk) < chunk_size or current >= total:
+                                break
+                    except Exception as e:
+                        raise e
+                    finally:
+                        await cdn_session.stop()
+            except pyrogram.StopTransmission:
+                raise
+            except Exception as e:
+                log.exception(e)
+            finally:
+                await session.stop()
 
     def guess_mime_type(self, filename: str) -> Optional[str]:
         return self.mimetypes.guess_type(filename)[0]
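A side note on the CDN branch that survives in the new code: CDN-served files arrive AES-256-CTR encrypted, the counter half of the IV is derived from the byte offset, and every downloaded range must match the SHA-256 hashes returned by GetCdnFileHashes. A rough standalone sketch of that arithmetic, with hypothetical helper names (pyrogram itself does this via aes.ctr256_decrypt and CDNFileHashMismatch.check, as shown above):

from hashlib import sha256

def cdn_ctr_iv(encryption_iv: bytes, offset_bytes: int) -> bytes:
    # https://core.telegram.org/cdn#decrypting-files
    # The last 4 bytes of the 16-byte IV hold the big-endian AES block counter,
    # i.e. the file offset divided by the 16-byte block size.
    assert len(encryption_iv) == 16 and offset_bytes % 16 == 0
    return encryption_iv[:-4] + (offset_bytes // 16).to_bytes(4, "big")

def verify_cdn_chunk(decrypted_chunk: bytes, hashes) -> None:
    # https://core.telegram.org/cdn#verifying-files
    # `hashes` is a sequence of records with `limit` (range size) and `hash`
    # fields; each `limit`-sized slice of the chunk must match its SHA-256.
    for i, h in enumerate(hashes):
        part = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
        if sha256(part).digest() != h.hash:
            raise ValueError(f"CDN file hash mismatch at slice {i}")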
@@ -94,132 +94,133 @@ class SaveFile:
         Raises:
             RPCError: In case of a Telegram RPC error.
         """
-        if path is None:
-            return None
-
-        async def worker(session):
-            while True:
-                data = await queue.get()
-
-                if data is None:
-                    return
-
-                try:
-                    await session.invoke(data)
-                except Exception as e:
-                    log.exception(e)
-
-        part_size = 512 * 1024
-
-        if isinstance(path, (str, PurePath)):
-            fp = open(path, "rb")
-        elif isinstance(path, io.IOBase):
-            fp = path
-        else:
-            raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer")
-
-        file_name = getattr(fp, "name", "file.jpg")
-
-        fp.seek(0, os.SEEK_END)
-        file_size = fp.tell()
-        fp.seek(0)
-
-        if file_size == 0:
-            raise ValueError("File size equals to 0 B")
-
-        file_size_limit_mib = 4000 if self.me.is_premium else 2000
-
-        if file_size > file_size_limit_mib * 1024 * 1024:
-            raise ValueError(f"Can't upload files bigger than {file_size_limit_mib} MiB")
-
-        file_total_parts = int(math.ceil(file_size / part_size))
-        is_big = file_size > 10 * 1024 * 1024
-        workers_count = 4 if is_big else 1
-        is_missing_part = file_id is not None
-        file_id = file_id or self.rnd_id()
-        md5_sum = md5() if not is_big and not is_missing_part else None
-        session = Session(
-            self, await self.storage.dc_id(), await self.storage.auth_key(),
-            await self.storage.test_mode(), is_media=True
-        )
-        workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)]
-        queue = asyncio.Queue(1)
-
-        try:
-            await session.start()
-
-            fp.seek(part_size * file_part)
-
-            while True:
-                chunk = fp.read(part_size)
-
-                if not chunk:
-                    if not is_big and not is_missing_part:
-                        md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
-                    break
-
-                if is_big:
-                    rpc = raw.functions.upload.SaveBigFilePart(
-                        file_id=file_id,
-                        file_part=file_part,
-                        file_total_parts=file_total_parts,
-                        bytes=chunk
-                    )
-                else:
-                    rpc = raw.functions.upload.SaveFilePart(
-                        file_id=file_id,
-                        file_part=file_part,
-                        bytes=chunk
-                    )
-
-                await queue.put(rpc)
-
-                if is_missing_part:
-                    return
-
-                if not is_big and not is_missing_part:
-                    md5_sum.update(chunk)
-
-                file_part += 1
-
-                if progress:
-                    func = functools.partial(
-                        progress,
-                        min(file_part * part_size, file_size),
-                        file_size,
-                        *progress_args
-                    )
-
-                    if inspect.iscoroutinefunction(progress):
-                        await func()
-                    else:
-                        await self.loop.run_in_executor(self.executor, func)
-        except StopTransmission:
-            raise
-        except Exception as e:
-            log.exception(e)
-        else:
-            if is_big:
-                return raw.types.InputFileBig(
-                    id=file_id,
-                    parts=file_total_parts,
-                    name=file_name,
-
-                )
-            else:
-                return raw.types.InputFile(
-                    id=file_id,
-                    parts=file_total_parts,
-                    name=file_name,
-                    md5_checksum=md5_sum
-                )
-        finally:
-            for _ in workers:
-                await queue.put(None)
-
-            await asyncio.gather(*workers)
-
-            await session.stop()
-
-            if isinstance(path, (str, PurePath)):
-                fp.close()
+        async with self.save_file_lock:
+            if path is None:
+                return None
+
+            async def worker(session):
+                while True:
+                    data = await queue.get()
+
+                    if data is None:
+                        return
+
+                    try:
+                        await session.invoke(data)
+                    except Exception as e:
+                        log.exception(e)
+
+            part_size = 512 * 1024
+
+            if isinstance(path, (str, PurePath)):
+                fp = open(path, "rb")
+            elif isinstance(path, io.IOBase):
+                fp = path
+            else:
+                raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer")
+
+            file_name = getattr(fp, "name", "file.jpg")
+
+            fp.seek(0, os.SEEK_END)
+            file_size = fp.tell()
+            fp.seek(0)
+
+            if file_size == 0:
+                raise ValueError("File size equals to 0 B")
+
+            file_size_limit_mib = 4000 if self.me.is_premium else 2000
+
+            if file_size > file_size_limit_mib * 1024 * 1024:
+                raise ValueError(f"Can't upload files bigger than {file_size_limit_mib} MiB")
+
+            file_total_parts = int(math.ceil(file_size / part_size))
+            is_big = file_size > 10 * 1024 * 1024
+            workers_count = 4 if is_big else 1
+            is_missing_part = file_id is not None
+            file_id = file_id or self.rnd_id()
+            md5_sum = md5() if not is_big and not is_missing_part else None
+            session = Session(
+                self, await self.storage.dc_id(), await self.storage.auth_key(),
+                await self.storage.test_mode(), is_media=True
+            )
+            workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)]
+            queue = asyncio.Queue(1)
+
+            try:
+                await session.start()
+
+                fp.seek(part_size * file_part)
+
+                while True:
+                    chunk = fp.read(part_size)
+
+                    if not chunk:
+                        if not is_big and not is_missing_part:
+                            md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()])
+                        break
+
+                    if is_big:
+                        rpc = raw.functions.upload.SaveBigFilePart(
+                            file_id=file_id,
+                            file_part=file_part,
+                            file_total_parts=file_total_parts,
+                            bytes=chunk
+                        )
+                    else:
+                        rpc = raw.functions.upload.SaveFilePart(
+                            file_id=file_id,
+                            file_part=file_part,
+                            bytes=chunk
+                        )
+
+                    await queue.put(rpc)
+
+                    if is_missing_part:
+                        return
+
+                    if not is_big and not is_missing_part:
+                        md5_sum.update(chunk)
+
+                    file_part += 1
+
+                    if progress:
+                        func = functools.partial(
+                            progress,
+                            min(file_part * part_size, file_size),
+                            file_size,
+                            *progress_args
+                        )
+
+                        if inspect.iscoroutinefunction(progress):
+                            await func()
+                        else:
+                            await self.loop.run_in_executor(self.executor, func)
+            except StopTransmission:
+                raise
+            except Exception as e:
+                log.exception(e)
+            else:
+                if is_big:
+                    return raw.types.InputFileBig(
+                        id=file_id,
+                        parts=file_total_parts,
+                        name=file_name,
+
+                    )
+                else:
+                    return raw.types.InputFile(
+                        id=file_id,
+                        parts=file_total_parts,
+                        name=file_name,
+                        md5_checksum=md5_sum
+                    )
+            finally:
+                for _ in workers:
+                    await queue.put(None)
+
+                await asyncio.gather(*workers)
+
+                await session.stop()
+
+                if isinstance(path, (str, PurePath)):
+                    fp.close()
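The upload pipeline wrapped by save_file_lock is otherwise unchanged: a bounded queue of size 1 feeds workers_count workers (4 for files over 10 MiB, else 1), so the file reader stays at most one part ahead of the network. A generic sketch of that producer/consumer shape, with illustrative names and a sleep standing in for the SaveFilePart call:

import asyncio

async def worker(queue: asyncio.Queue) -> None:
    while True:
        part = await queue.get()
        if part is None:  # sentinel: shut down
            return
        await asyncio.sleep(0.1)  # stand-in for one SaveFilePart round trip

async def upload(parts: list[bytes], workers_count: int = 4) -> None:
    queue: asyncio.Queue = asyncio.Queue(1)  # bounded: reader stays one part ahead
    workers = [asyncio.create_task(worker(queue)) for _ in range(workers_count)]
    try:
        for part in parts:
            await queue.put(part)  # waits while every worker is busy
    finally:
        for _ in workers:
            await queue.put(None)  # one sentinel per worker
        await asyncio.gather(*workers)

asyncio.run(upload([b"\0" * (512 * 1024)] * 8))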