2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Revert some of the last changes

This commit is contained in:
Dan 2022-02-10 06:44:42 +01:00
parent 0d11240740
commit dc6c816c80
4 changed files with 48 additions and 61 deletions

View File

@ -27,6 +27,8 @@ log = logging.getLogger(__name__)
class Connection: class Connection:
MAX_RETRIES = 3
MODES = { MODES = {
0: TCPFull, 0: TCPFull,
1: TCPAbridged, 1: TCPAbridged,
@ -35,7 +37,7 @@ class Connection:
4: TCPIntermediateO 4: TCPIntermediateO
} }
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1): def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3):
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.ipv6 = ipv6 self.ipv6 = ipv6
@ -45,18 +47,17 @@ class Connection:
self.mode = self.MODES.get(mode, TCPAbridged) self.mode = self.MODES.get(mode, TCPAbridged)
self.protocol = None # type: TCP self.protocol = None # type: TCP
self.is_connected = asyncio.Event()
async def connect(self): async def connect(self):
while True: for i in range(Connection.MAX_RETRIES):
self.protocol = self.mode(self.ipv6, self.proxy) self.protocol = self.mode(self.ipv6, self.proxy)
try: try:
log.info("Connecting...") log.info("Connecting...")
await self.protocol.connect(self.address) await self.protocol.connect(self.address)
except OSError as e: except OSError as e:
log.warning(f"Connection failed due to network issues: {e}") log.warning(f"Unable to connect due to network issues: {e}")
await self.protocol.close() self.protocol.close()
await asyncio.sleep(1) await asyncio.sleep(1)
else: else:
log.info("Connected! {} DC{}{} - IPv{} - {}".format( log.info("Connected! {} DC{}{} - IPv{} - {}".format(
@ -67,21 +68,19 @@ class Connection:
self.mode.__name__, self.mode.__name__,
)) ))
break break
else:
log.warning("Connection failed! Trying again...")
raise TimeoutError
self.is_connected.set() def close(self):
self.protocol.close()
async def close(self):
await self.protocol.close()
self.is_connected.clear()
log.info("Disconnected") log.info("Disconnected")
async def reconnect(self):
await self.close()
await self.connect()
async def send(self, data: bytes): async def send(self, data: bytes):
await self.is_connected.wait() try:
await self.protocol.send(data) await self.protocol.send(data)
except Exception:
raise OSError
async def recv(self) -> Optional[bytes]: async def recv(self) -> Optional[bytes]:
return await self.protocol.recv() return await self.protocol.recv()

View File

@ -82,7 +82,7 @@ class TCP:
self.socket.connect(address) self.socket.connect(address)
self.reader, self.writer = await asyncio.open_connection(sock=self.socket) self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
async def close(self): def close(self):
try: try:
self.writer.close() self.writer.close()
except AttributeError: except AttributeError:

View File

@ -258,4 +258,4 @@ class Auth:
else: else:
return auth_key return auth_key
finally: finally:
await self.connection.close() self.connection.close()

View File

@ -83,6 +83,7 @@ class Session:
self.stored_msg_ids = [] self.stored_msg_ids = []
self.ping_task = None self.ping_task = None
self.ping_task_event = asyncio.Event()
self.network_task = None self.network_task = None
@ -149,23 +150,17 @@ class Session:
async def stop(self): async def stop(self):
self.is_connected.clear() self.is_connected.clear()
if self.ping_task: self.ping_task_event.set()
self.ping_task.cancel()
try: if self.ping_task is not None:
await self.ping_task await self.ping_task
except asyncio.CancelledError:
pass self.ping_task_event.clear()
self.connection.close()
if self.network_task: if self.network_task:
self.network_task.cancel() await self.network_task
try:
await self.network_task
except asyncio.CancelledError:
pass
await self.connection.close()
for i in self.results.values(): for i in self.results.values():
i.event.set() i.event.set()
@ -194,7 +189,7 @@ class Session:
self.stored_msg_ids self.stored_msg_ids
) )
except SecurityCheckMismatch: except SecurityCheckMismatch:
await self.connection.close() self.connection.close()
return return
messages = ( messages = (
@ -252,53 +247,46 @@ class Session:
self.pending_acks.clear() self.pending_acks.clear()
async def ping_worker(self): async def ping_worker(self):
log.info("PingTask started")
while True: while True:
try:
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
except asyncio.TimeoutError:
pass
else:
break
try: try:
await self._send( await self._send(
raw.functions.PingDelayDisconnect( raw.functions.PingDelayDisconnect(
ping_id=0, ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
disconnect_delay=self.WAIT_TIMEOUT + 10
), False ), False
) )
except (OSError, TimeoutError, RPCError): except (OSError, TimeoutError, RPCError):
pass pass
await asyncio.sleep(self.PING_INTERVAL) log.info("PingTask stopped")
async def network_worker(self): async def network_worker(self):
log.info("NetworkTask started")
while True: while True:
packet = await self.connection.recv() packet = await self.connection.recv()
if not packet: if packet is None or len(packet) == 4:
await self.connection.reconnect() if packet:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
try: if self.is_connected.is_set():
await self._send( self.loop.create_task(self.restart())
raw.functions.InvokeWithLayer(
layer=layer,
query=raw.functions.InitConnection(
api_id=self.client.api_id,
app_version=self.client.app_version,
device_model=self.client.device_model,
system_version=self.client.system_version,
system_lang_code=self.client.lang_code,
lang_code=self.client.lang_code,
lang_pack="",
query=raw.functions.help.GetConfig(),
)
),
wait_response=False
)
except (OSError, TimeoutError, RPCError):
pass
continue break
if len(packet) == 4:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
self.loop.create_task(self.handle_packet(packet)) self.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped")
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data) message = self.msg_factory(data)
msg_id = message.msg_id msg_id = message.msg_id