diff --git a/pyrogram/client.py b/pyrogram/client.py
index 7848c1f5..81828f5a 100644
--- a/pyrogram/client.py
+++ b/pyrogram/client.py
@@ -44,7 +44,7 @@ from pyrogram.errors import CDNFileHashMismatch
 from pyrogram.errors import (
     SessionPasswordNeeded, VolumeLocNotFound, ChannelPrivate,
-    AuthBytesInvalid, BadRequest
+    BadRequest
 )
 from pyrogram.handlers.handler import Handler
 from pyrogram.methods import Methods
@@ -266,6 +266,9 @@ class Client(Methods):
         self.media_sessions = {}
         self.media_sessions_lock = asyncio.Lock()
 
+        self.save_file_lock = asyncio.Lock()
+        self.get_file_lock = asyncio.Lock()
+
         self.is_connected = None
         self.is_initialized = None
 
@@ -795,204 +798,93 @@ class Client(Methods):
         progress: Callable = None,
         progress_args: tuple = ()
     ) -> Optional[AsyncGenerator[bytes, None]]:
-        dc_id = file_id.dc_id
-
-        async with self.media_sessions_lock:
-            session = self.media_sessions.get(dc_id, None)
-
-            if session is None:
-                if dc_id != await self.storage.dc_id():
-                    session = Session(
-                        self, dc_id, await Auth(self, dc_id, await self.storage.test_mode()).create(),
-                        await self.storage.test_mode(), is_media=True
-                    )
-                    await session.start()
-
-                    for _ in range(3):
-                        exported_auth = await self.invoke(
-                            raw.functions.auth.ExportAuthorization(
-                                dc_id=dc_id
-                            )
-                        )
-
-                        try:
-                            await session.invoke(
-                                raw.functions.auth.ImportAuthorization(
-                                    id=exported_auth.id,
-                                    bytes=exported_auth.bytes
-                                )
-                            )
-                        except AuthBytesInvalid:
-                            continue
-                        else:
-                            break
-                    else:
-                        await session.stop()
-                        raise AuthBytesInvalid
-                else:
-                    session = Session(
-                        self, dc_id, await self.storage.auth_key(),
-                        await self.storage.test_mode(), is_media=True
-                    )
-                    await session.start()
-
-                self.media_sessions[dc_id] = session
-
-        file_type = file_id.file_type
-
-        if file_type == FileType.CHAT_PHOTO:
-            if file_id.chat_id > 0:
-                peer = raw.types.InputPeerUser(
-                    user_id=file_id.chat_id,
-                    access_hash=file_id.chat_access_hash
-                )
-            else:
-                if file_id.chat_access_hash == 0:
-                    peer = raw.types.InputPeerChat(
-                        chat_id=-file_id.chat_id
-                    )
-                else:
-                    peer = raw.types.InputPeerChannel(
-                        channel_id=utils.get_channel_id(file_id.chat_id),
-                        access_hash=file_id.chat_access_hash
-                    )
-
-            location = raw.types.InputPeerPhotoFileLocation(
-                peer=peer,
-                photo_id=file_id.media_id,
-                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
-            )
-        elif file_type == FileType.PHOTO:
-            location = raw.types.InputPhotoFileLocation(
-                id=file_id.media_id,
-                access_hash=file_id.access_hash,
-                file_reference=file_id.file_reference,
-                thumb_size=file_id.thumbnail_size
-            )
-        else:
-            location = raw.types.InputDocumentFileLocation(
-                id=file_id.media_id,
-                access_hash=file_id.access_hash,
-                file_reference=file_id.file_reference,
-                thumb_size=file_id.thumbnail_size
-            )
-
-        current = 0
-        total = abs(limit) or (1 << 31) - 1
-        chunk_size = 1024 * 1024
-        offset_bytes = abs(offset) * chunk_size
-
-        try:
-            r = await session.invoke(
-                raw.functions.upload.GetFile(
-                    location=location,
-                    offset=offset_bytes,
-                    limit=chunk_size
-                ),
-                sleep_threshold=30
-            )
-
-            if isinstance(r, raw.types.upload.File):
-                while True:
-                    chunk = r.bytes
-
-                    yield chunk
-
-                    current += 1
-                    offset_bytes += chunk_size
-
-                    if progress:
-                        func = functools.partial(
-                            progress,
-                            min(offset_bytes, file_size)
-                            if file_size != 0
-                            else offset_bytes,
-                            file_size,
-                            *progress_args
-                        )
-
-                        if inspect.iscoroutinefunction(progress):
-                            await func()
-                        else:
-                            await self.loop.run_in_executor(self.executor, func)
-
-                    if len(chunk) < chunk_size or current >= total:
-                        break
-
-                    r = await session.invoke(
-                        raw.functions.upload.GetFile(
-                            location=location,
-                            offset=offset_bytes,
-                            limit=chunk_size
-                        ),
-                        sleep_threshold=30
-                    )
-            elif isinstance(r, raw.types.upload.FileCdnRedirect):
-                async with self.media_sessions_lock:
-                    cdn_session = self.media_sessions.get(r.dc_id, None)
-
-                    if cdn_session is None:
-                        cdn_session = Session(
-                            self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
-                            await self.storage.test_mode(), is_media=True, is_cdn=True
-                        )
-
-                        await cdn_session.start()
-
-                        self.media_sessions[r.dc_id] = cdn_session
-
-                try:
-                    while True:
-                        r2 = await cdn_session.invoke(
-                            raw.functions.upload.GetCdnFile(
-                                file_token=r.file_token,
-                                offset=offset_bytes,
-                                limit=chunk_size
-                            )
-                        )
-
-                        if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
-                            try:
-                                await session.invoke(
-                                    raw.functions.upload.ReuploadCdnFile(
-                                        file_token=r.file_token,
-                                        request_token=r2.request_token
-                                    )
-                                )
-                            except VolumeLocNotFound:
-                                break
-                            else:
-                                continue
-
-                        chunk = r2.bytes
-
-                        # https://core.telegram.org/cdn#decrypting-files
-                        decrypted_chunk = aes.ctr256_decrypt(
-                            chunk,
-                            r.encryption_key,
-                            bytearray(
-                                r.encryption_iv[:-4]
-                                + (offset_bytes // 16).to_bytes(4, "big")
-                            )
-                        )
-
-                        hashes = await session.invoke(
-                            raw.functions.upload.GetCdnFileHashes(
-                                file_token=r.file_token,
-                                offset=offset_bytes
-                            )
-                        )
-
-                        # https://core.telegram.org/cdn#verifying-files
-                        for i, h in enumerate(hashes):
-                            cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
-                            CDNFileHashMismatch.check(
-                                h.hash == sha256(cdn_chunk).digest(),
-                                "h.hash == sha256(cdn_chunk).digest()"
-                            )
-
-                        yield decrypted_chunk
+        async with self.get_file_lock:
+            file_type = file_id.file_type
+
+            if file_type == FileType.CHAT_PHOTO:
+                if file_id.chat_id > 0:
+                    peer = raw.types.InputPeerUser(
+                        user_id=file_id.chat_id,
+                        access_hash=file_id.chat_access_hash
+                    )
+                else:
+                    if file_id.chat_access_hash == 0:
+                        peer = raw.types.InputPeerChat(
+                            chat_id=-file_id.chat_id
+                        )
+                    else:
+                        peer = raw.types.InputPeerChannel(
+                            channel_id=utils.get_channel_id(file_id.chat_id),
+                            access_hash=file_id.chat_access_hash
+                        )
+
+                location = raw.types.InputPeerPhotoFileLocation(
+                    peer=peer,
+                    photo_id=file_id.media_id,
+                    big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
+                )
+            elif file_type == FileType.PHOTO:
+                location = raw.types.InputPhotoFileLocation(
+                    id=file_id.media_id,
+                    access_hash=file_id.access_hash,
+                    file_reference=file_id.file_reference,
+                    thumb_size=file_id.thumbnail_size
+                )
+            else:
+                location = raw.types.InputDocumentFileLocation(
+                    id=file_id.media_id,
+                    access_hash=file_id.access_hash,
+                    file_reference=file_id.file_reference,
+                    thumb_size=file_id.thumbnail_size
+                )
+
+            current = 0
+            total = abs(limit) or (1 << 31) - 1
+            chunk_size = 1024 * 1024
+            offset_bytes = abs(offset) * chunk_size
+
+            dc_id = file_id.dc_id
+
+            session = Session(
+                self, dc_id,
+                await Auth(self, dc_id, await self.storage.test_mode()).create()
+                if dc_id != await self.storage.dc_id()
+                else await self.storage.auth_key(),
+                await self.storage.test_mode(),
+                is_media=True
+            )
+
+            try:
+                await session.start()
+
+                if dc_id != await self.storage.dc_id():
+                    exported_auth = await self.invoke(
+                        raw.functions.auth.ExportAuthorization(
+                            dc_id=dc_id
+                        )
+                    )
+
+                    await session.invoke(
+                        raw.functions.auth.ImportAuthorization(
+                            id=exported_auth.id,
+                            bytes=exported_auth.bytes
+                        )
+                    )
+
+                r = await session.invoke(
+                    raw.functions.upload.GetFile(
+                        location=location,
+                        offset=offset_bytes,
+                        limit=chunk_size
+                    ),
+                    sleep_threshold=30
+                )
+
+                if isinstance(r, raw.types.upload.File):
+                    while True:
+                        chunk = r.bytes
+
+                        yield chunk
 
                         current += 1
                         offset_bytes += chunk_size
 
@@ -1000,7 +892,9 @@ class Client(Methods):
                         if progress:
                             func = functools.partial(
                                 progress,
-                                min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
+                                min(offset_bytes, file_size)
+                                if file_size != 0
+                                else offset_bytes,
                                 file_size,
                                 *progress_args
                             )
@@ -1012,12 +906,104 @@ class Client(Methods):
                         if len(chunk) < chunk_size or current >= total:
                             break
-                except Exception as e:
-                    raise e
-        except pyrogram.StopTransmission:
-            raise
-        except Exception as e:
-            log.exception(e)
+
+                        r = await session.invoke(
+                            raw.functions.upload.GetFile(
+                                location=location,
+                                offset=offset_bytes,
+                                limit=chunk_size
+                            ),
+                            sleep_threshold=30
+                        )
+
+                elif isinstance(r, raw.types.upload.FileCdnRedirect):
+                    cdn_session = Session(
+                        self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
+                        await self.storage.test_mode(), is_media=True, is_cdn=True
+                    )
+
+                    try:
+                        await cdn_session.start()
+
+                        while True:
+                            r2 = await cdn_session.invoke(
+                                raw.functions.upload.GetCdnFile(
+                                    file_token=r.file_token,
+                                    offset=offset_bytes,
+                                    limit=chunk_size
+                                )
+                            )
+
+                            if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
+                                try:
+                                    await session.invoke(
+                                        raw.functions.upload.ReuploadCdnFile(
+                                            file_token=r.file_token,
+                                            request_token=r2.request_token
+                                        )
+                                    )
+                                except VolumeLocNotFound:
+                                    break
+                                else:
+                                    continue
+
+                            chunk = r2.bytes
+
+                            # https://core.telegram.org/cdn#decrypting-files
+                            decrypted_chunk = aes.ctr256_decrypt(
+                                chunk,
+                                r.encryption_key,
+                                bytearray(
+                                    r.encryption_iv[:-4]
+                                    + (offset_bytes // 16).to_bytes(4, "big")
+                                )
+                            )
+
+                            hashes = await session.invoke(
+                                raw.functions.upload.GetCdnFileHashes(
+                                    file_token=r.file_token,
+                                    offset=offset_bytes
+                                )
+                            )
+
+                            # https://core.telegram.org/cdn#verifying-files
+                            for i, h in enumerate(hashes):
+                                cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
+                                CDNFileHashMismatch.check(
+                                    h.hash == sha256(cdn_chunk).digest(),
+                                    "h.hash == sha256(cdn_chunk).digest()"
+                                )
+
+                            yield decrypted_chunk
+
+                            current += 1
+                            offset_bytes += chunk_size
+
+                            if progress:
+                                func = functools.partial(
+                                    progress,
+                                    min(offset_bytes, file_size) if file_size != 0 else offset_bytes,
+                                    file_size,
+                                    *progress_args
+                                )
+
+                                if inspect.iscoroutinefunction(progress):
+                                    await func()
+                                else:
+                                    await self.loop.run_in_executor(self.executor, func)
+
+                            if len(chunk) < chunk_size or current >= total:
+                                break
+                    except Exception as e:
+                        raise e
+                    finally:
+                        await cdn_session.stop()
+            except pyrogram.StopTransmission:
+                raise
+            except Exception as e:
+                log.exception(e)
+            finally:
+                await session.stop()
 
     def guess_mime_type(self, filename: str) -> Optional[str]:
         return self.mimetypes.guess_type(filename)[0]
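
Note on the client.py hunks above: get_file() no longer caches per-DC media sessions in self.media_sessions. Every call now builds a throwaway Session (exporting and importing authorization once when the target DC differs from the stored DC) and always stops it in the new finally block; the old three-attempt AuthBytesInvalid retry is gone, which is why that import is dropped in the first hunk. Because `async with self.get_file_lock` wraps the entire generator body, concurrent downloads on one client are now fully serialized. The toy below (plain asyncio, illustrative names only, not part of the patch) shows why: a lock acquired inside an async generator stays held across every yield until the generator finishes.

    import asyncio

    lock = asyncio.Lock()

    async def get_chunks(name: str, n: int = 3):
        # Mirrors the patched get_file(): the lock is held for the whole
        # life of the generator, across every yield, not per chunk.
        async with lock:
            for i in range(n):
                await asyncio.sleep(0)  # stand-in for an upload.GetFile round trip
                yield f"{name} chunk {i}"

    async def consume(name: str) -> None:
        async for chunk in get_chunks(name):
            print(chunk)

    async def main() -> None:
        # Output is not interleaved: every "a" chunk prints before any
        # "b" chunk, because "b" blocks on the lock until "a" is exhausted.
        await asyncio.gather(consume("a"), consume("b"))

    asyncio.run(main())
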
""" - if path is None: - return None + async with self.save_file_lock: + if path is None: + return None - async def worker(session): - while True: - data = await queue.get() + async def worker(session): + while True: + data = await queue.get() - if data is None: - return + if data is None: + return - try: - await session.invoke(data) - except Exception as e: - log.exception(e) + try: + await session.invoke(data) + except Exception as e: + log.exception(e) - part_size = 512 * 1024 - - if isinstance(path, (str, PurePath)): - fp = open(path, "rb") - elif isinstance(path, io.IOBase): - fp = path - else: - raise ValueError("Invalid file. Expected a file path as string or a binary (not text) file pointer") - - file_name = getattr(fp, "name", "file.jpg") - - fp.seek(0, os.SEEK_END) - file_size = fp.tell() - fp.seek(0) - - if file_size == 0: - raise ValueError("File size equals to 0 B") - - file_size_limit_mib = 4000 if self.me.is_premium else 2000 - - if file_size > file_size_limit_mib * 1024 * 1024: - raise ValueError(f"Can't upload files bigger than {file_size_limit_mib} MiB") - - file_total_parts = int(math.ceil(file_size / part_size)) - is_big = file_size > 10 * 1024 * 1024 - workers_count = 4 if is_big else 1 - is_missing_part = file_id is not None - file_id = file_id or self.rnd_id() - md5_sum = md5() if not is_big and not is_missing_part else None - session = Session( - self, await self.storage.dc_id(), await self.storage.auth_key(), - await self.storage.test_mode(), is_media=True - ) - workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)] - queue = asyncio.Queue(1) - - try: - await session.start() - - fp.seek(part_size * file_part) - - while True: - chunk = fp.read(part_size) - - if not chunk: - if not is_big and not is_missing_part: - md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()]) - break - - if is_big: - rpc = raw.functions.upload.SaveBigFilePart( - file_id=file_id, - file_part=file_part, - file_total_parts=file_total_parts, - bytes=chunk - ) - else: - rpc = raw.functions.upload.SaveFilePart( - file_id=file_id, - file_part=file_part, - bytes=chunk - ) - - await queue.put(rpc) - - if is_missing_part: - return - - if not is_big and not is_missing_part: - md5_sum.update(chunk) - - file_part += 1 - - if progress: - func = functools.partial( - progress, - min(file_part * part_size, file_size), - file_size, - *progress_args - ) - - if inspect.iscoroutinefunction(progress): - await func() - else: - await self.loop.run_in_executor(self.executor, func) - except StopTransmission: - raise - except Exception as e: - log.exception(e) - else: - if is_big: - return raw.types.InputFileBig( - id=file_id, - parts=file_total_parts, - name=file_name, - - ) - else: - return raw.types.InputFile( - id=file_id, - parts=file_total_parts, - name=file_name, - md5_checksum=md5_sum - ) - finally: - for _ in workers: - await queue.put(None) - - await asyncio.gather(*workers) - - await session.stop() + part_size = 512 * 1024 if isinstance(path, (str, PurePath)): - fp.close() + fp = open(path, "rb") + elif isinstance(path, io.IOBase): + fp = path + else: + raise ValueError("Invalid file. 
Expected a file path as string or a binary (not text) file pointer") + + file_name = getattr(fp, "name", "file.jpg") + + fp.seek(0, os.SEEK_END) + file_size = fp.tell() + fp.seek(0) + + if file_size == 0: + raise ValueError("File size equals to 0 B") + + file_size_limit_mib = 4000 if self.me.is_premium else 2000 + + if file_size > file_size_limit_mib * 1024 * 1024: + raise ValueError(f"Can't upload files bigger than {file_size_limit_mib} MiB") + + file_total_parts = int(math.ceil(file_size / part_size)) + is_big = file_size > 10 * 1024 * 1024 + workers_count = 4 if is_big else 1 + is_missing_part = file_id is not None + file_id = file_id or self.rnd_id() + md5_sum = md5() if not is_big and not is_missing_part else None + session = Session( + self, await self.storage.dc_id(), await self.storage.auth_key(), + await self.storage.test_mode(), is_media=True + ) + workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)] + queue = asyncio.Queue(1) + + try: + await session.start() + + fp.seek(part_size * file_part) + + while True: + chunk = fp.read(part_size) + + if not chunk: + if not is_big and not is_missing_part: + md5_sum = "".join([hex(i)[2:].zfill(2) for i in md5_sum.digest()]) + break + + if is_big: + rpc = raw.functions.upload.SaveBigFilePart( + file_id=file_id, + file_part=file_part, + file_total_parts=file_total_parts, + bytes=chunk + ) + else: + rpc = raw.functions.upload.SaveFilePart( + file_id=file_id, + file_part=file_part, + bytes=chunk + ) + + await queue.put(rpc) + + if is_missing_part: + return + + if not is_big and not is_missing_part: + md5_sum.update(chunk) + + file_part += 1 + + if progress: + func = functools.partial( + progress, + min(file_part * part_size, file_size), + file_size, + *progress_args + ) + + if inspect.iscoroutinefunction(progress): + await func() + else: + await self.loop.run_in_executor(self.executor, func) + except StopTransmission: + raise + except Exception as e: + log.exception(e) + else: + if is_big: + return raw.types.InputFileBig( + id=file_id, + parts=file_total_parts, + name=file_name, + + ) + else: + return raw.types.InputFile( + id=file_id, + parts=file_total_parts, + name=file_name, + md5_checksum=md5_sum + ) + finally: + for _ in workers: + await queue.put(None) + + await asyncio.gather(*workers) + + await session.stop() + + if isinstance(path, (str, PurePath)): + fp.close()
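
Note on the save_file.py hunk above: the method body is essentially unchanged apart from being re-indented under `async with self.save_file_lock`, so uploads are serialized the same way downloads now are, and the upload session is likewise created and stopped per call. The queue/worker shutdown protocol is kept as-is: a size-1 queue feeds up to four workers, and the finally block pushes one None sentinel per worker before gathering them. A stripped-down model of that protocol (plain asyncio, names invented for illustration, not part of the patch):

    import asyncio

    async def worker(queue: asyncio.Queue) -> None:
        # Same shape as save_file()'s worker: drain the queue until the
        # None sentinel arrives, then exit.
        while True:
            part = await queue.get()
            if part is None:
                return
            await asyncio.sleep(0)  # stand-in for session.invoke(SaveFilePart)

    async def upload(parts: list) -> None:
        queue: asyncio.Queue = asyncio.Queue(1)  # capacity 1, as in the patch
        workers = [asyncio.create_task(worker(queue)) for _ in range(4)]

        try:
            for part in parts:
                await queue.put(part)
        finally:
            # Same shutdown order as the patch's finally block:
            # one sentinel per worker, then wait for all of them.
            for _ in workers:
                await queue.put(None)
            await asyncio.gather(*workers)

    asyncio.run(upload([b"part-0", b"part-1"]))
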