2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Revert some of the last changes

This commit is contained in:
Dan 2022-02-10 06:44:42 +01:00
parent 0d11240740
commit dc6c816c80
4 changed files with 48 additions and 61 deletions

View File

@ -27,6 +27,8 @@ log = logging.getLogger(__name__)
class Connection:
MAX_RETRIES = 3
MODES = {
0: TCPFull,
1: TCPAbridged,
@ -35,7 +37,7 @@ class Connection:
4: TCPIntermediateO
}
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1):
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3):
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = ipv6
@ -45,18 +47,17 @@ class Connection:
self.mode = self.MODES.get(mode, TCPAbridged)
self.protocol = None # type: TCP
self.is_connected = asyncio.Event()
async def connect(self):
while True:
for i in range(Connection.MAX_RETRIES):
self.protocol = self.mode(self.ipv6, self.proxy)
try:
log.info("Connecting...")
await self.protocol.connect(self.address)
except OSError as e:
log.warning(f"Connection failed due to network issues: {e}")
await self.protocol.close()
log.warning(f"Unable to connect due to network issues: {e}")
self.protocol.close()
await asyncio.sleep(1)
else:
log.info("Connected! {} DC{}{} - IPv{} - {}".format(
@ -67,21 +68,19 @@ class Connection:
self.mode.__name__,
))
break
else:
log.warning("Connection failed! Trying again...")
raise TimeoutError
self.is_connected.set()
async def close(self):
await self.protocol.close()
self.is_connected.clear()
def close(self):
self.protocol.close()
log.info("Disconnected")
async def reconnect(self):
await self.close()
await self.connect()
async def send(self, data: bytes):
await self.is_connected.wait()
await self.protocol.send(data)
try:
await self.protocol.send(data)
except Exception:
raise OSError
async def recv(self) -> Optional[bytes]:
return await self.protocol.recv()

View File

@ -82,7 +82,7 @@ class TCP:
self.socket.connect(address)
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
async def close(self):
def close(self):
try:
self.writer.close()
except AttributeError:

View File

@ -258,4 +258,4 @@ class Auth:
else:
return auth_key
finally:
await self.connection.close()
self.connection.close()

View File

@ -83,6 +83,7 @@ class Session:
self.stored_msg_ids = []
self.ping_task = None
self.ping_task_event = asyncio.Event()
self.network_task = None
@ -149,23 +150,17 @@ class Session:
async def stop(self):
self.is_connected.clear()
if self.ping_task:
self.ping_task.cancel()
self.ping_task_event.set()
try:
await self.ping_task
except asyncio.CancelledError:
pass
if self.ping_task is not None:
await self.ping_task
self.ping_task_event.clear()
self.connection.close()
if self.network_task:
self.network_task.cancel()
try:
await self.network_task
except asyncio.CancelledError:
pass
await self.connection.close()
await self.network_task
for i in self.results.values():
i.event.set()
@ -194,7 +189,7 @@ class Session:
self.stored_msg_ids
)
except SecurityCheckMismatch:
await self.connection.close()
self.connection.close()
return
messages = (
@ -252,53 +247,46 @@ class Session:
self.pending_acks.clear()
async def ping_worker(self):
log.info("PingTask started")
while True:
try:
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
except asyncio.TimeoutError:
pass
else:
break
try:
await self._send(
raw.functions.PingDelayDisconnect(
ping_id=0,
disconnect_delay=self.WAIT_TIMEOUT + 10
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
), False
)
except (OSError, TimeoutError, RPCError):
pass
await asyncio.sleep(self.PING_INTERVAL)
log.info("PingTask stopped")
async def network_worker(self):
log.info("NetworkTask started")
while True:
packet = await self.connection.recv()
if not packet:
await self.connection.reconnect()
if packet is None or len(packet) == 4:
if packet:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
try:
await self._send(
raw.functions.InvokeWithLayer(
layer=layer,
query=raw.functions.InitConnection(
api_id=self.client.api_id,
app_version=self.client.app_version,
device_model=self.client.device_model,
system_version=self.client.system_version,
system_lang_code=self.client.lang_code,
lang_code=self.client.lang_code,
lang_pack="",
query=raw.functions.help.GetConfig(),
)
),
wait_response=False
)
except (OSError, TimeoutError, RPCError):
pass
if self.is_connected.is_set():
self.loop.create_task(self.restart())
continue
if len(packet) == 4:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
break
self.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped")
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
message = self.msg_factory(data)
msg_id = message.msg_id