diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 73c2312f..69cbb813 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -20,67 +20,53 @@ import asyncio import logging from typing import Optional -from .transport import * +from .transport import TCP, TCPAbridgedO from ..session.internals import DataCenter log = logging.getLogger(__name__) class Connection: - MAX_RETRIES = 3 + MAX_CONNECTION_ATTEMPTS = 3 - MODES = { - 0: TCPFull, - 1: TCPAbridged, - 2: TCPIntermediate, - 3: TCPAbridgedO, - 4: TCPIntermediateO - } - - def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3): + def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False): self.dc_id = dc_id self.test_mode = test_mode self.ipv6 = ipv6 self.proxy = proxy self.media = media - self.address = DataCenter(dc_id, test_mode, ipv6, media) - self.mode = self.MODES.get(mode, TCPAbridged) - self.protocol = None # type: TCP + self.address = DataCenter(dc_id, test_mode, ipv6, media) + self.protocol: TCP = None async def connect(self): - for i in range(Connection.MAX_RETRIES): - self.protocol = self.mode(self.ipv6, self.proxy) + for i in range(Connection.MAX_CONNECTION_ATTEMPTS): + self.protocol = TCPAbridgedO(self.ipv6, self.proxy) try: log.info("Connecting...") await self.protocol.connect(self.address) except OSError as e: - log.warning(f"Unable to connect due to network issues: {e}") + log.warning("Unable to connect due to network issues: %s", e) await self.protocol.close() await asyncio.sleep(1) else: - log.info("Connected! {} DC{}{} - IPv{} - {}".format( - "Test" if self.test_mode else "Production", - self.dc_id, - " (media)" if self.media else "", - "6" if self.ipv6 else "4", - self.mode.__name__, - )) + log.info("Connected! %s DC%s%s - IPv%s", + "Test" if self.test_mode else "Production", + self.dc_id, + " (media)" if self.media else "", + "6" if self.ipv6 else "4") break else: log.warning("Connection failed! Trying again...") - raise TimeoutError + raise ConnectionError async def close(self): await self.protocol.close() log.info("Disconnected") async def send(self, data: bytes): - try: - await self.protocol.send(data) - except Exception as e: - raise OSError(e) + await self.protocol.send(data) async def recv(self) -> Optional[bytes]: return await self.protocol.recv()