mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-28 12:57:52 +00:00
Improve stability in case of connection failures
This commit is contained in:
parent
89c49111b0
commit
462e5d11a5
@ -37,7 +37,7 @@ class Connection:
|
|||||||
4: TCPIntermediateO
|
4: TCPIntermediateO
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3):
|
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1):
|
||||||
self.dc_id = dc_id
|
self.dc_id = dc_id
|
||||||
self.test_mode = test_mode
|
self.test_mode = test_mode
|
||||||
self.ipv6 = ipv6
|
self.ipv6 = ipv6
|
||||||
@ -47,6 +47,7 @@ class Connection:
|
|||||||
self.mode = self.MODES.get(mode, TCPAbridged)
|
self.mode = self.MODES.get(mode, TCPAbridged)
|
||||||
|
|
||||||
self.protocol = None # type: TCP
|
self.protocol = None # type: TCP
|
||||||
|
self.is_connected = asyncio.Event()
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
for i in range(Connection.MAX_RETRIES):
|
for i in range(Connection.MAX_RETRIES):
|
||||||
@ -56,8 +57,8 @@ class Connection:
|
|||||||
log.info("Connecting...")
|
log.info("Connecting...")
|
||||||
await self.protocol.connect(self.address)
|
await self.protocol.connect(self.address)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
log.warning(f"Unable to connect due to network issues: {e}")
|
log.warning(f"Connection failed due to network issues: {e}")
|
||||||
self.protocol.close()
|
await self.protocol.close()
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
else:
|
else:
|
||||||
log.info("Connected! {} DC{}{} - IPv{} - {}".format(
|
log.info("Connected! {} DC{}{} - IPv{} - {}".format(
|
||||||
@ -69,18 +70,23 @@ class Connection:
|
|||||||
))
|
))
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
log.warning("Connection failed! Trying again...")
|
log.warning("Couldn't connect. Trying again...")
|
||||||
raise TimeoutError
|
raise TimeoutError
|
||||||
|
|
||||||
def close(self):
|
self.is_connected.set()
|
||||||
self.protocol.close()
|
|
||||||
|
async def close(self):
|
||||||
|
await self.protocol.close()
|
||||||
|
self.is_connected.clear()
|
||||||
log.info("Disconnected")
|
log.info("Disconnected")
|
||||||
|
|
||||||
|
async def reconnect(self):
|
||||||
|
await self.close()
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
async def send(self, data: bytes):
|
async def send(self, data: bytes):
|
||||||
try:
|
await self.is_connected.wait()
|
||||||
await self.protocol.send(data)
|
await self.protocol.send(data)
|
||||||
except Exception:
|
|
||||||
raise OSError
|
|
||||||
|
|
||||||
async def recv(self) -> Optional[bytes]:
|
async def recv(self) -> Optional[bytes]:
|
||||||
return await self.protocol.recv()
|
return await self.protocol.recv()
|
||||||
|
@ -82,7 +82,7 @@ class TCP:
|
|||||||
self.socket.connect(address)
|
self.socket.connect(address)
|
||||||
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
|
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
|
||||||
|
|
||||||
def close(self):
|
async def close(self):
|
||||||
try:
|
try:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
@ -83,7 +83,6 @@ class Session:
|
|||||||
self.stored_msg_ids = []
|
self.stored_msg_ids = []
|
||||||
|
|
||||||
self.ping_task = None
|
self.ping_task = None
|
||||||
self.ping_task_event = asyncio.Event()
|
|
||||||
|
|
||||||
self.network_task = None
|
self.network_task = None
|
||||||
|
|
||||||
@ -150,17 +149,23 @@ class Session:
|
|||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.is_connected.clear()
|
self.is_connected.clear()
|
||||||
|
|
||||||
self.ping_task_event.set()
|
if self.ping_task:
|
||||||
|
self.ping_task.cancel()
|
||||||
|
|
||||||
if self.ping_task is not None:
|
try:
|
||||||
await self.ping_task
|
await self.ping_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
self.ping_task_event.clear()
|
pass
|
||||||
|
|
||||||
self.connection.close()
|
|
||||||
|
|
||||||
if self.network_task:
|
if self.network_task:
|
||||||
await self.network_task
|
self.network_task.cancel()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.network_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await self.connection.close()
|
||||||
|
|
||||||
for i in self.results.values():
|
for i in self.results.values():
|
||||||
i.event.set()
|
i.event.set()
|
||||||
@ -189,7 +194,7 @@ class Session:
|
|||||||
self.stored_msg_ids
|
self.stored_msg_ids
|
||||||
)
|
)
|
||||||
except SecurityCheckMismatch:
|
except SecurityCheckMismatch:
|
||||||
self.connection.close()
|
await self.connection.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
messages = (
|
messages = (
|
||||||
@ -247,46 +252,32 @@ class Session:
|
|||||||
self.pending_acks.clear()
|
self.pending_acks.clear()
|
||||||
|
|
||||||
async def ping_worker(self):
|
async def ping_worker(self):
|
||||||
log.info("PingTask started")
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send(
|
await self._send(
|
||||||
raw.functions.PingDelayDisconnect(
|
raw.functions.PingDelayDisconnect(
|
||||||
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
|
ping_id=0,
|
||||||
|
disconnect_delay=self.WAIT_TIMEOUT + 10
|
||||||
), False
|
), False
|
||||||
)
|
)
|
||||||
except (OSError, TimeoutError, RPCError):
|
except (OSError, TimeoutError, RPCError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log.info("PingTask stopped")
|
await asyncio.sleep(self.PING_INTERVAL)
|
||||||
|
|
||||||
async def network_worker(self):
|
async def network_worker(self):
|
||||||
log.info("NetworkTask started")
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
packet = await self.connection.recv()
|
packet = await self.connection.recv()
|
||||||
|
|
||||||
if packet is None or len(packet) == 4:
|
if not packet:
|
||||||
if packet:
|
await self.connection.reconnect()
|
||||||
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
|
continue
|
||||||
|
|
||||||
if self.is_connected.is_set():
|
if len(packet) == 4:
|
||||||
self.loop.create_task(self.restart())
|
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
self.loop.create_task(self.handle_packet(packet))
|
self.loop.create_task(self.handle_packet(packet))
|
||||||
|
|
||||||
log.info("NetworkTask stopped")
|
|
||||||
|
|
||||||
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
|
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
|
||||||
message = self.msg_factory(data)
|
message = self.msg_factory(data)
|
||||||
msg_id = message.msg_id
|
msg_id = message.msg_id
|
||||||
|
Loading…
x
Reference in New Issue
Block a user