From dc6c816c80934a4f3e4953e7faca327b2585ac55 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Thu, 10 Feb 2022 06:44:42 +0100 Subject: [PATCH] Revert some of the last changes --- pyrogram/connection/connection.py | 31 +++++----- pyrogram/connection/transport/tcp/tcp.py | 2 +- pyrogram/session/auth.py | 2 +- pyrogram/session/session.py | 74 ++++++++++-------------- 4 files changed, 48 insertions(+), 61 deletions(-) diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 04578b0e..2173c70b 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -27,6 +27,8 @@ log = logging.getLogger(__name__) class Connection: + MAX_RETRIES = 3 + MODES = { 0: TCPFull, 1: TCPAbridged, @@ -35,7 +37,7 @@ class Connection: 4: TCPIntermediateO } - def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1): + def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3): self.dc_id = dc_id self.test_mode = test_mode self.ipv6 = ipv6 @@ -45,18 +47,17 @@ class Connection: self.mode = self.MODES.get(mode, TCPAbridged) self.protocol = None # type: TCP - self.is_connected = asyncio.Event() async def connect(self): - while True: + for i in range(Connection.MAX_RETRIES): self.protocol = self.mode(self.ipv6, self.proxy) try: log.info("Connecting...") await self.protocol.connect(self.address) except OSError as e: - log.warning(f"Connection failed due to network issues: {e}") - await self.protocol.close() + log.warning(f"Unable to connect due to network issues: {e}") + self.protocol.close() await asyncio.sleep(1) else: log.info("Connected! {} DC{}{} - IPv{} - {}".format( @@ -67,21 +68,19 @@ class Connection: self.mode.__name__, )) break + else: + log.warning("Connection failed! Trying again...") + raise TimeoutError - self.is_connected.set() - - async def close(self): - await self.protocol.close() - self.is_connected.clear() + def close(self): + self.protocol.close() log.info("Disconnected") - async def reconnect(self): - await self.close() - await self.connect() - async def send(self, data: bytes): - await self.is_connected.wait() - await self.protocol.send(data) + try: + await self.protocol.send(data) + except Exception: + raise OSError async def recv(self) -> Optional[bytes]: return await self.protocol.recv() diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 386ca329..0b858c02 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -82,7 +82,7 @@ class TCP: self.socket.connect(address) self.reader, self.writer = await asyncio.open_connection(sock=self.socket) - async def close(self): + def close(self): try: self.writer.close() except AttributeError: diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index f33ffba9..d4083b21 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -258,4 +258,4 @@ class Auth: else: return auth_key finally: - await self.connection.close() + self.connection.close() diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 89171ed4..05d1fd4a 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -83,6 +83,7 @@ class Session: self.stored_msg_ids = [] self.ping_task = None + self.ping_task_event = asyncio.Event() self.network_task = None @@ -149,23 +150,17 @@ class Session: async def stop(self): self.is_connected.clear() - if self.ping_task: - self.ping_task.cancel() + self.ping_task_event.set() - try: - await self.ping_task - except asyncio.CancelledError: - pass + if self.ping_task is not None: + await self.ping_task + + self.ping_task_event.clear() + + self.connection.close() if self.network_task: - self.network_task.cancel() - - try: - await self.network_task - except asyncio.CancelledError: - pass - - await self.connection.close() + await self.network_task for i in self.results.values(): i.event.set() @@ -194,7 +189,7 @@ class Session: self.stored_msg_ids ) except SecurityCheckMismatch: - await self.connection.close() + self.connection.close() return messages = ( @@ -252,53 +247,46 @@ class Session: self.pending_acks.clear() async def ping_worker(self): + log.info("PingTask started") + while True: + try: + await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL) + except asyncio.TimeoutError: + pass + else: + break + try: await self._send( raw.functions.PingDelayDisconnect( - ping_id=0, - disconnect_delay=self.WAIT_TIMEOUT + 10 + ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 ), False ) except (OSError, TimeoutError, RPCError): pass - await asyncio.sleep(self.PING_INTERVAL) + log.info("PingTask stopped") async def network_worker(self): + log.info("NetworkTask started") + while True: packet = await self.connection.recv() - if not packet: - await self.connection.reconnect() + if packet is None or len(packet) == 4: + if packet: + log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') - try: - await self._send( - raw.functions.InvokeWithLayer( - layer=layer, - query=raw.functions.InitConnection( - api_id=self.client.api_id, - app_version=self.client.app_version, - device_model=self.client.device_model, - system_version=self.client.system_version, - system_lang_code=self.client.lang_code, - lang_code=self.client.lang_code, - lang_pack="", - query=raw.functions.help.GetConfig(), - ) - ), - wait_response=False - ) - except (OSError, TimeoutError, RPCError): - pass + if self.is_connected.is_set(): + self.loop.create_task(self.restart()) - continue - - if len(packet) == 4: - log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') + break self.loop.create_task(self.handle_packet(packet)) + log.info("NetworkTask stopped") + async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): message = self.msg_factory(data) msg_id = message.msg_id