2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 04:48:06 +00:00

Improve stability in case of connection failures

This commit is contained in:
Dan 2022-02-10 01:05:36 +01:00
parent 89c49111b0
commit 462e5d11a5
3 changed files with 40 additions and 43 deletions

View File

@ -37,7 +37,7 @@ class Connection:
4: TCPIntermediateO
}
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3):
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1):
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = ipv6
@ -47,6 +47,7 @@ class Connection:
self.mode = self.MODES.get(mode, TCPAbridged)
self.protocol = None # type: TCP
self.is_connected = asyncio.Event()
async def connect(self):
for i in range(Connection.MAX_RETRIES):
@ -56,8 +57,8 @@ class Connection:
log.info("Connecting...")
await self.protocol.connect(self.address)
except OSError as e:
log.warning(f"Unable to connect due to network issues: {e}")
self.protocol.close()
log.warning(f"Connection failed due to network issues: {e}")
await self.protocol.close()
await asyncio.sleep(1)
else:
log.info("Connected! {} DC{}{} - IPv{} - {}".format(
@ -69,18 +70,23 @@ class Connection:
))
break
else:
log.warning("Connection failed! Trying again...")
log.warning("Couldn't connect. Trying again...")
raise TimeoutError
def close(self):
self.protocol.close()
self.is_connected.set()
async def close(self):
await self.protocol.close()
self.is_connected.clear()
log.info("Disconnected")
async def reconnect(self):
await self.close()
await self.connect()
async def send(self, data: bytes):
try:
await self.protocol.send(data)
except Exception:
raise OSError
await self.is_connected.wait()
await self.protocol.send(data)
async def recv(self) -> Optional[bytes]:
return await self.protocol.recv()

View File

@ -82,7 +82,7 @@ class TCP:
self.socket.connect(address)
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
def close(self):
async def close(self):
try:
self.writer.close()
except AttributeError:

View File

@ -83,7 +83,6 @@ class Session:
self.stored_msg_ids = []
self.ping_task = None
self.ping_task_event = asyncio.Event()
self.network_task = None
@ -150,17 +149,23 @@ class Session:
async def stop(self):
self.is_connected.clear()
self.ping_task_event.set()
if self.ping_task:
self.ping_task.cancel()
if self.ping_task is not None:
await self.ping_task
self.ping_task_event.clear()
self.connection.close()
try:
await self.ping_task
except asyncio.CancelledError:
pass
if self.network_task:
await self.network_task
self.network_task.cancel()
try:
await self.network_task
except asyncio.CancelledError:
pass
await self.connection.close()
for i in self.results.values():
i.event.set()
@ -189,7 +194,7 @@ class Session:
self.stored_msg_ids
)
except SecurityCheckMismatch:
self.connection.close()
await self.connection.close()
return
messages = (
@ -247,46 +252,32 @@ class Session:
self.pending_acks.clear()
async def ping_worker(self):
log.info("PingTask started")
while True:
try:
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
except asyncio.TimeoutError:
pass
else:
break
try:
await self._send(
raw.functions.PingDelayDisconnect(
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
ping_id=0,
disconnect_delay=self.WAIT_TIMEOUT + 10
), False
)
except (OSError, TimeoutError, RPCError):
pass
log.info("PingTask stopped")
await asyncio.sleep(self.PING_INTERVAL)
async def network_worker(self):
log.info("NetworkTask started")
while True:
packet = await self.connection.recv()
if packet is None or len(packet) == 4:
if packet:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
if not packet:
await self.connection.reconnect()
continue
if self.is_connected.is_set():
self.loop.create_task(self.restart())
break
if len(packet) == 4:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
self.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped")
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data)
msg_id = message.msg_id