diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 2173c70b..62e68751 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -37,7 +37,7 @@ class Connection: 4: TCPIntermediateO } - def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3): + def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1): self.dc_id = dc_id self.test_mode = test_mode self.ipv6 = ipv6 @@ -47,6 +47,7 @@ class Connection: self.mode = self.MODES.get(mode, TCPAbridged) self.protocol = None # type: TCP + self.is_connected = asyncio.Event() async def connect(self): for i in range(Connection.MAX_RETRIES): @@ -56,8 +57,8 @@ class Connection: log.info("Connecting...") await self.protocol.connect(self.address) except OSError as e: - log.warning(f"Unable to connect due to network issues: {e}") - self.protocol.close() + log.warning(f"Connection failed due to network issues: {e}") + await self.protocol.close() await asyncio.sleep(1) else: log.info("Connected! {} DC{}{} - IPv{} - {}".format( @@ -69,18 +70,23 @@ class Connection: )) break else: - log.warning("Connection failed! Trying again...") + log.warning("Couldn't connect. Trying again...") raise TimeoutError - def close(self): - self.protocol.close() + self.is_connected.set() + + async def close(self): + await self.protocol.close() + self.is_connected.clear() log.info("Disconnected") + async def reconnect(self): + await self.close() + await self.connect() + async def send(self, data: bytes): - try: - await self.protocol.send(data) - except Exception: - raise OSError + await self.is_connected.wait() + await self.protocol.send(data) async def recv(self) -> Optional[bytes]: return await self.protocol.recv() diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 0b858c02..386ca329 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -82,7 +82,7 @@ class TCP: self.socket.connect(address) self.reader, self.writer = await asyncio.open_connection(sock=self.socket) - def close(self): + async def close(self): try: self.writer.close() except AttributeError: diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 05d1fd4a..7ee5d6ba 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -83,7 +83,6 @@ class Session: self.stored_msg_ids = [] self.ping_task = None - self.ping_task_event = asyncio.Event() self.network_task = None @@ -150,17 +149,23 @@ class Session: async def stop(self): self.is_connected.clear() - self.ping_task_event.set() + if self.ping_task: + self.ping_task.cancel() - if self.ping_task is not None: - await self.ping_task - - self.ping_task_event.clear() - - self.connection.close() + try: + await self.ping_task + except asyncio.CancelledError: + pass if self.network_task: - await self.network_task + self.network_task.cancel() + + try: + await self.network_task + except asyncio.CancelledError: + pass + + await self.connection.close() for i in self.results.values(): i.event.set() @@ -189,7 +194,7 @@ class Session: self.stored_msg_ids ) except SecurityCheckMismatch: - self.connection.close() + await self.connection.close() return messages = ( @@ -247,46 +252,32 @@ class Session: self.pending_acks.clear() async def ping_worker(self): - log.info("PingTask started") - while True: - try: - await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL) - except asyncio.TimeoutError: - pass - else: - break - try: await self._send( raw.functions.PingDelayDisconnect( - ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 + ping_id=0, + disconnect_delay=self.WAIT_TIMEOUT + 10 ), False ) except (OSError, TimeoutError, RPCError): pass - log.info("PingTask stopped") + await asyncio.sleep(self.PING_INTERVAL) async def network_worker(self): - log.info("NetworkTask started") - while True: packet = await self.connection.recv() - if packet is None or len(packet) == 4: - if packet: - log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') + if not packet: + await self.connection.reconnect() + continue - if self.is_connected.is_set(): - self.loop.create_task(self.restart()) - - break + if len(packet) == 4: + log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') self.loop.create_task(self.handle_packet(packet)) - log.info("NetworkTask stopped") - async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): message = self.msg_factory(data) msg_id = message.msg_id