From bff583ed7562c20c3f1818a3cf05bf6051987320 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Mon, 26 Dec 2022 22:26:55 +0100 Subject: [PATCH] Revert some of the latest changes --- pyrogram/client.py | 25 --------- pyrogram/connection/connection.py | 13 +++-- pyrogram/connection/transport/tcp/tcp.py | 64 ++++++++++-------------- pyrogram/methods/auth/initialize.py | 3 -- pyrogram/methods/auth/terminate.py | 7 --- pyrogram/session/auth.py | 2 +- pyrogram/session/internals/seq_no.py | 12 +++-- pyrogram/session/session.py | 36 +++++++------ pyrogram/storage/file_storage.py | 9 ++-- pyrogram/storage/sqlite_storage.py | 26 ++++++---- 10 files changed, 86 insertions(+), 111 deletions(-) diff --git a/pyrogram/client.py b/pyrogram/client.py index 36ab4e4c..63e4b472 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -26,7 +26,6 @@ import re import shutil import sys from concurrent.futures.thread import ThreadPoolExecutor -from datetime import datetime, timedelta from hashlib import sha256 from importlib import import_module from io import StringIO, BytesIO @@ -186,9 +185,6 @@ class Client(Methods): WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None WORKDIR = PARENT_DIR - # Interval of seconds in which the updates watchdog will kick in - UPDATES_WATCHDOG_INTERVAL = 5 * 60 - mimetypes = MimeTypes() mimetypes.readfp(StringIO(mime_types)) @@ -277,13 +273,6 @@ class Client(Methods): self.message_cache = Cache(10000) - # Sometimes, for some reason, the server will stop sending updates and will only respond to pings. - # This watchdog will invoke updates.GetState in order to wake up the server and enable it sending updates again - # after some idle time has been detected. - self.updates_watchdog_task = None - self.updates_watchdog_event = asyncio.Event() - self.last_update_time = datetime.now() - self.loop = asyncio.get_event_loop() def __enter__(self): @@ -304,18 +293,6 @@ class Client(Methods): except ConnectionError: pass - async def updates_watchdog(self): - while True: - try: - await asyncio.wait_for(self.updates_watchdog_event.wait(), self.UPDATES_WATCHDOG_INTERVAL) - except asyncio.TimeoutError: - pass - else: - break - - if datetime.now() - self.last_update_time > timedelta(seconds=self.UPDATES_WATCHDOG_INTERVAL): - await self.invoke(raw.functions.updates.GetState()) - async def authorize(self) -> User: if self.bot_token: return await self.sign_in_bot(self.bot_token) @@ -508,8 +485,6 @@ class Client(Methods): return is_min async def handle_updates(self, updates): - self.last_update_time = datetime.now() - if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)): is_min = any(( await self.fetch_peers(updates.users), diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 69cbb813..051d3c52 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -48,7 +48,7 @@ class Connection: await self.protocol.connect(self.address) except OSError as e: log.warning("Unable to connect due to network issues: %s", e) - await self.protocol.close() + self.protocol.close() await asyncio.sleep(1) else: log.info("Connected! %s DC%s%s - IPv%s", @@ -59,14 +59,17 @@ class Connection: break else: log.warning("Connection failed! Trying again...") - raise ConnectionError + raise TimeoutError - async def close(self): - await self.protocol.close() + def close(self): + self.protocol.close() log.info("Disconnected") async def send(self, data: bytes): - await self.protocol.send(data) + try: + await self.protocol.send(data) + except Exception as e: + raise OSError(e) async def recv(self) -> Optional[bytes]: return await self.protocol.recv() diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 13b6e7de..beb2e58a 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -20,6 +20,9 @@ import asyncio import ipaddress import logging import socket +import time +from concurrent.futures import ThreadPoolExecutor + import socks log = logging.getLogger(__name__) @@ -31,12 +34,10 @@ class TCP: def __init__(self, ipv6: bool, proxy: dict): self.socket = None - self.reader = None - self.writer = None - - self.send_queue = asyncio.Queue() - self.send_task = None + self.reader = None # type: asyncio.StreamReader + self.writer = None # type: asyncio.StreamWriter + self.lock = asyncio.Lock() self.loop = asyncio.get_event_loop() if proxy: @@ -62,50 +63,39 @@ class TCP: log.info("Using proxy %s", hostname) else: - self.socket = socket.socket( + self.socket = socks.socksocket( socket.AF_INET6 if ipv6 else socket.AF_INET ) - self.socket.setblocking(False) + self.socket.settimeout(TCP.TIMEOUT) async def connect(self, address: tuple): - try: - await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT) - except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 - raise TimeoutError("Connection timed out") + # The socket used by the whole logic is blocking and thus it blocks when connecting. + # Offload the task to a thread executor to avoid blocking the main event loop. + with ThreadPoolExecutor(1) as executor: + await self.loop.run_in_executor(executor, self.socket.connect, address) self.reader, self.writer = await asyncio.open_connection(sock=self.socket) - self.send_task = asyncio.create_task(self.send_worker()) - - async def close(self): - await self.send_queue.put(None) - - if self.send_task is not None: - await self.send_task + def close(self): try: - if self.writer is not None: - self.writer.close() - await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) - except Exception as e: - log.info("Close exception: %s %s", type(e).__name__, e) + self.writer.close() + except AttributeError: + try: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError: + pass + finally: + # A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout. + # This is a workaround that seems to fix the occasional delayed stop of a client. + time.sleep(0.001) + self.socket.close() async def send(self, data: bytes): - await self.send_queue.put(data) - - async def send_worker(self): - while True: - data = await self.send_queue.get() - - if data is None: - break - - try: - self.writer.write(data) - await self.writer.drain() - except Exception as e: - log.info("Send exception: %s %s", type(e).__name__, e) + async with self.lock: + self.writer.write(data) + await self.writer.drain() async def recv(self, length: int = 0): data = b"" diff --git a/pyrogram/methods/auth/initialize.py b/pyrogram/methods/auth/initialize.py index 7188b668..1e7915e0 100644 --- a/pyrogram/methods/auth/initialize.py +++ b/pyrogram/methods/auth/initialize.py @@ -16,7 +16,6 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -import asyncio import logging import pyrogram @@ -47,6 +46,4 @@ class Initialize: await self.dispatcher.start() - self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog()) - self.is_initialized = True diff --git a/pyrogram/methods/auth/terminate.py b/pyrogram/methods/auth/terminate.py index 70cfc80e..5ecb6758 100644 --- a/pyrogram/methods/auth/terminate.py +++ b/pyrogram/methods/auth/terminate.py @@ -51,11 +51,4 @@ class Terminate: self.media_sessions.clear() - self.updates_watchdog_event.set() - - if self.updates_watchdog_task is not None: - await self.updates_watchdog_task - - self.updates_watchdog_event.clear() - self.is_initialized = False diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index c5d9cd9a..d51e18f8 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -278,4 +278,4 @@ class Auth: else: return auth_key finally: - await self.connection.close() + self.connection.close() diff --git a/pyrogram/session/internals/seq_no.py b/pyrogram/session/internals/seq_no.py index 79501d98..0abc4a2f 100644 --- a/pyrogram/session/internals/seq_no.py +++ b/pyrogram/session/internals/seq_no.py @@ -16,15 +16,19 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +from threading import Lock + class SeqNo: def __init__(self): self.content_related_messages_sent = 0 + self.lock = Lock() def __call__(self, is_content_related: bool) -> int: - seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0) + with self.lock: + seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0) - if is_content_related: - self.content_related_messages_sent += 1 + if is_content_related: + self.content_related_messages_sent += 1 - return seq_no + return seq_no diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 3899aa52..5135af69 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -156,11 +156,14 @@ class Session: self.ping_task_event.clear() - await self.connection.close() + self.connection.close() if self.recv_task: await self.recv_task + for i in self.results.values(): + i.event.set() + if not self.is_media and callable(self.client.disconnect_handler): try: await self.client.disconnect_handler(self.client) @@ -185,8 +188,7 @@ class Session: self.stored_msg_ids ) except SecurityCheckMismatch as e: - log.info("Discarding packet: %s", e) - await self.connection.close() + log.warning("Discarding packet: %s", e) return messages = ( @@ -282,6 +284,9 @@ class Session: message = self.msg_factory(data) msg_id = message.msg_id + if wait_response: + self.results[msg_id] = Result() + log.debug("Sent: %s", message) payload = await self.loop.run_in_executor( @@ -294,35 +299,34 @@ class Session: self.auth_key_id ) - await self.connection.send(payload) + try: + await self.connection.send(payload) + except OSError as e: + self.results.pop(msg_id, None) + raise e if wait_response: - self.results[msg_id] = Result() - try: await asyncio.wait_for(self.results[msg_id].event.wait(), timeout) except asyncio.TimeoutError: pass - - result = self.results.pop(msg_id).value + finally: + result = self.results.pop(msg_id).value if result is None: raise TimeoutError("Request timed out") - - if isinstance(result, raw.types.RpcError): + elif isinstance(result, raw.types.RpcError): if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)): data = data.query RPCError.raise_it(result, type(data)) - - if isinstance(result, raw.types.BadMsgNotification): + elif isinstance(result, raw.types.BadMsgNotification): raise BadMsgNotification(result.error_code) - - if isinstance(result, raw.types.BadServerSalt): + elif isinstance(result, raw.types.BadServerSalt): self.salt = result.new_server_salt return await self.send(data, wait_response, timeout) - - return result + else: + return result async def invoke( self, diff --git a/pyrogram/storage/file_storage.py b/pyrogram/storage/file_storage.py index aebe9176..986787cd 100644 --- a/pyrogram/storage/file_storage.py +++ b/pyrogram/storage/file_storage.py @@ -38,13 +38,13 @@ class FileStorage(SQLiteStorage): version = self.version() if version == 1: - with self.conn: + with self.lock, self.conn: self.conn.execute("DELETE FROM peers") version += 1 if version == 2: - with self.conn: + with self.lock, self.conn: self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER") version += 1 @@ -63,7 +63,10 @@ class FileStorage(SQLiteStorage): self.update() with self.conn: - self.conn.execute("VACUUM") + try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum + self.conn.execute("VACUUM") + except sqlite3.OperationalError: + pass async def delete(self): os.remove(self.database) diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index e28b9b74..15e5ddc0 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -19,6 +19,7 @@ import inspect import sqlite3 import time +from threading import Lock from typing import List, Tuple, Any from pyrogram import raw @@ -97,9 +98,10 @@ class SQLiteStorage(Storage): super().__init__(name) self.conn = None # type: sqlite3.Connection + self.lock = Lock() def create(self): - with self.conn: + with self.lock, self.conn: self.conn.executescript(SCHEMA) self.conn.execute( @@ -117,20 +119,24 @@ class SQLiteStorage(Storage): async def save(self): await self.date(int(time.time())) - self.conn.commit() + + with self.lock: + self.conn.commit() async def close(self): - self.conn.close() + with self.lock: + self.conn.close() async def delete(self): raise NotImplementedError async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): - self.conn.executemany( - "REPLACE INTO peers (id, access_hash, type, username, phone_number)" - "VALUES (?, ?, ?, ?, ?)", - peers - ) + with self.lock: + self.conn.executemany( + "REPLACE INTO peers (id, access_hash, type, username, phone_number)" + "VALUES (?, ?, ?, ?, ?)", + peers + ) async def get_peer_by_id(self, peer_id: int): r = self.conn.execute( @@ -179,7 +185,7 @@ class SQLiteStorage(Storage): def _set(self, value: Any): attr = inspect.stack()[2].function - with self.conn: + with self.lock, self.conn: self.conn.execute( f"UPDATE sessions SET {attr} = ?", (value,) @@ -215,7 +221,7 @@ class SQLiteStorage(Storage): "SELECT number FROM version" ).fetchone()[0] else: - with self.conn: + with self.lock, self.conn: self.conn.execute( "UPDATE version SET number = ?", (value,)