mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-28 21:07:59 +00:00
Improve stability in case of connection failures
This commit is contained in:
parent
89c49111b0
commit
462e5d11a5
@ -37,7 +37,7 @@ class Connection:
|
||||
4: TCPIntermediateO
|
||||
}
|
||||
|
||||
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3):
|
||||
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1):
|
||||
self.dc_id = dc_id
|
||||
self.test_mode = test_mode
|
||||
self.ipv6 = ipv6
|
||||
@ -47,6 +47,7 @@ class Connection:
|
||||
self.mode = self.MODES.get(mode, TCPAbridged)
|
||||
|
||||
self.protocol = None # type: TCP
|
||||
self.is_connected = asyncio.Event()
|
||||
|
||||
async def connect(self):
|
||||
for i in range(Connection.MAX_RETRIES):
|
||||
@ -56,8 +57,8 @@ class Connection:
|
||||
log.info("Connecting...")
|
||||
await self.protocol.connect(self.address)
|
||||
except OSError as e:
|
||||
log.warning(f"Unable to connect due to network issues: {e}")
|
||||
self.protocol.close()
|
||||
log.warning(f"Connection failed due to network issues: {e}")
|
||||
await self.protocol.close()
|
||||
await asyncio.sleep(1)
|
||||
else:
|
||||
log.info("Connected! {} DC{}{} - IPv{} - {}".format(
|
||||
@ -69,18 +70,23 @@ class Connection:
|
||||
))
|
||||
break
|
||||
else:
|
||||
log.warning("Connection failed! Trying again...")
|
||||
log.warning("Couldn't connect. Trying again...")
|
||||
raise TimeoutError
|
||||
|
||||
def close(self):
|
||||
self.protocol.close()
|
||||
self.is_connected.set()
|
||||
|
||||
async def close(self):
|
||||
await self.protocol.close()
|
||||
self.is_connected.clear()
|
||||
log.info("Disconnected")
|
||||
|
||||
async def reconnect(self):
|
||||
await self.close()
|
||||
await self.connect()
|
||||
|
||||
async def send(self, data: bytes):
|
||||
try:
|
||||
await self.protocol.send(data)
|
||||
except Exception:
|
||||
raise OSError
|
||||
await self.is_connected.wait()
|
||||
await self.protocol.send(data)
|
||||
|
||||
async def recv(self) -> Optional[bytes]:
|
||||
return await self.protocol.recv()
|
||||
|
@ -82,7 +82,7 @@ class TCP:
|
||||
self.socket.connect(address)
|
||||
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
|
||||
|
||||
def close(self):
|
||||
async def close(self):
|
||||
try:
|
||||
self.writer.close()
|
||||
except AttributeError:
|
||||
|
@ -83,7 +83,6 @@ class Session:
|
||||
self.stored_msg_ids = []
|
||||
|
||||
self.ping_task = None
|
||||
self.ping_task_event = asyncio.Event()
|
||||
|
||||
self.network_task = None
|
||||
|
||||
@ -150,17 +149,23 @@ class Session:
|
||||
async def stop(self):
|
||||
self.is_connected.clear()
|
||||
|
||||
self.ping_task_event.set()
|
||||
if self.ping_task:
|
||||
self.ping_task.cancel()
|
||||
|
||||
if self.ping_task is not None:
|
||||
await self.ping_task
|
||||
|
||||
self.ping_task_event.clear()
|
||||
|
||||
self.connection.close()
|
||||
try:
|
||||
await self.ping_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.network_task:
|
||||
await self.network_task
|
||||
self.network_task.cancel()
|
||||
|
||||
try:
|
||||
await self.network_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.connection.close()
|
||||
|
||||
for i in self.results.values():
|
||||
i.event.set()
|
||||
@ -189,7 +194,7 @@ class Session:
|
||||
self.stored_msg_ids
|
||||
)
|
||||
except SecurityCheckMismatch:
|
||||
self.connection.close()
|
||||
await self.connection.close()
|
||||
return
|
||||
|
||||
messages = (
|
||||
@ -247,46 +252,32 @@ class Session:
|
||||
self.pending_acks.clear()
|
||||
|
||||
async def ping_worker(self):
|
||||
log.info("PingTask started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
|
||||
try:
|
||||
await self._send(
|
||||
raw.functions.PingDelayDisconnect(
|
||||
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
|
||||
ping_id=0,
|
||||
disconnect_delay=self.WAIT_TIMEOUT + 10
|
||||
), False
|
||||
)
|
||||
except (OSError, TimeoutError, RPCError):
|
||||
pass
|
||||
|
||||
log.info("PingTask stopped")
|
||||
await asyncio.sleep(self.PING_INTERVAL)
|
||||
|
||||
async def network_worker(self):
|
||||
log.info("NetworkTask started")
|
||||
|
||||
while True:
|
||||
packet = await self.connection.recv()
|
||||
|
||||
if packet is None or len(packet) == 4:
|
||||
if packet:
|
||||
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
|
||||
if not packet:
|
||||
await self.connection.reconnect()
|
||||
continue
|
||||
|
||||
if self.is_connected.is_set():
|
||||
self.loop.create_task(self.restart())
|
||||
|
||||
break
|
||||
if len(packet) == 4:
|
||||
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
|
||||
|
||||
self.loop.create_task(self.handle_packet(packet))
|
||||
|
||||
log.info("NetworkTask stopped")
|
||||
|
||||
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
|
||||
message = self.msg_factory(data)
|
||||
msg_id = message.msg_id
|
||||
|
Loading…
x
Reference in New Issue
Block a user