2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Improve stability in case of connection failures

This commit is contained in:
Dan 2022-02-10 01:05:36 +01:00
parent 89c49111b0
commit 462e5d11a5
3 changed files with 40 additions and 43 deletions

View File

@ -37,7 +37,7 @@ class Connection:
4: TCPIntermediateO 4: TCPIntermediateO
} }
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3): def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1):
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.ipv6 = ipv6 self.ipv6 = ipv6
@ -47,6 +47,7 @@ class Connection:
self.mode = self.MODES.get(mode, TCPAbridged) self.mode = self.MODES.get(mode, TCPAbridged)
self.protocol = None # type: TCP self.protocol = None # type: TCP
self.is_connected = asyncio.Event()
async def connect(self): async def connect(self):
for i in range(Connection.MAX_RETRIES): for i in range(Connection.MAX_RETRIES):
@ -56,8 +57,8 @@ class Connection:
log.info("Connecting...") log.info("Connecting...")
await self.protocol.connect(self.address) await self.protocol.connect(self.address)
except OSError as e: except OSError as e:
log.warning(f"Unable to connect due to network issues: {e}") log.warning(f"Connection failed due to network issues: {e}")
self.protocol.close() await self.protocol.close()
await asyncio.sleep(1) await asyncio.sleep(1)
else: else:
log.info("Connected! {} DC{}{} - IPv{} - {}".format( log.info("Connected! {} DC{}{} - IPv{} - {}".format(
@ -69,18 +70,23 @@ class Connection:
)) ))
break break
else: else:
log.warning("Connection failed! Trying again...") log.warning("Couldn't connect. Trying again...")
raise TimeoutError raise TimeoutError
def close(self): self.is_connected.set()
self.protocol.close()
async def close(self):
await self.protocol.close()
self.is_connected.clear()
log.info("Disconnected") log.info("Disconnected")
async def reconnect(self):
await self.close()
await self.connect()
async def send(self, data: bytes): async def send(self, data: bytes):
try: await self.is_connected.wait()
await self.protocol.send(data) await self.protocol.send(data)
except Exception:
raise OSError
async def recv(self) -> Optional[bytes]: async def recv(self) -> Optional[bytes]:
return await self.protocol.recv() return await self.protocol.recv()

View File

@ -82,7 +82,7 @@ class TCP:
self.socket.connect(address) self.socket.connect(address)
self.reader, self.writer = await asyncio.open_connection(sock=self.socket) self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
def close(self): async def close(self):
try: try:
self.writer.close() self.writer.close()
except AttributeError: except AttributeError:

View File

@ -83,7 +83,6 @@ class Session:
self.stored_msg_ids = [] self.stored_msg_ids = []
self.ping_task = None self.ping_task = None
self.ping_task_event = asyncio.Event()
self.network_task = None self.network_task = None
@ -150,17 +149,23 @@ class Session:
async def stop(self): async def stop(self):
self.is_connected.clear() self.is_connected.clear()
self.ping_task_event.set() if self.ping_task:
self.ping_task.cancel()
if self.ping_task is not None: try:
await self.ping_task await self.ping_task
except asyncio.CancelledError:
self.ping_task_event.clear() pass
self.connection.close()
if self.network_task: if self.network_task:
self.network_task.cancel()
try:
await self.network_task await self.network_task
except asyncio.CancelledError:
pass
await self.connection.close()
for i in self.results.values(): for i in self.results.values():
i.event.set() i.event.set()
@ -189,7 +194,7 @@ class Session:
self.stored_msg_ids self.stored_msg_ids
) )
except SecurityCheckMismatch: except SecurityCheckMismatch:
self.connection.close() await self.connection.close()
return return
messages = ( messages = (
@ -247,46 +252,32 @@ class Session:
self.pending_acks.clear() self.pending_acks.clear()
async def ping_worker(self): async def ping_worker(self):
log.info("PingTask started")
while True: while True:
try:
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
except asyncio.TimeoutError:
pass
else:
break
try: try:
await self._send( await self._send(
raw.functions.PingDelayDisconnect( raw.functions.PingDelayDisconnect(
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 ping_id=0,
disconnect_delay=self.WAIT_TIMEOUT + 10
), False ), False
) )
except (OSError, TimeoutError, RPCError): except (OSError, TimeoutError, RPCError):
pass pass
log.info("PingTask stopped") await asyncio.sleep(self.PING_INTERVAL)
async def network_worker(self): async def network_worker(self):
log.info("NetworkTask started")
while True: while True:
packet = await self.connection.recv() packet = await self.connection.recv()
if packet is None or len(packet) == 4: if not packet:
if packet: await self.connection.reconnect()
continue
if len(packet) == 4:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
if self.is_connected.is_set():
self.loop.create_task(self.restart())
break
self.loop.create_task(self.handle_packet(packet)) self.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped")
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data) message = self.msg_factory(data)
msg_id = message.msg_id msg_id = message.msg_id