2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-25 11:28:05 +00:00

Revert some of the latest changes

This commit is contained in:
Dan 2022-12-26 22:26:55 +01:00
parent a81b8a2254
commit bff583ed75
10 changed files with 86 additions and 111 deletions

View File

@ -26,7 +26,6 @@ import re
import shutil import shutil
import sys import sys
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime, timedelta
from hashlib import sha256 from hashlib import sha256
from importlib import import_module from importlib import import_module
from io import StringIO, BytesIO from io import StringIO, BytesIO
@ -186,9 +185,6 @@ class Client(Methods):
WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None
WORKDIR = PARENT_DIR WORKDIR = PARENT_DIR
# Interval of seconds in which the updates watchdog will kick in
UPDATES_WATCHDOG_INTERVAL = 5 * 60
mimetypes = MimeTypes() mimetypes = MimeTypes()
mimetypes.readfp(StringIO(mime_types)) mimetypes.readfp(StringIO(mime_types))
@ -277,13 +273,6 @@ class Client(Methods):
self.message_cache = Cache(10000) self.message_cache = Cache(10000)
# Sometimes, for some reason, the server will stop sending updates and will only respond to pings.
# This watchdog will invoke updates.GetState in order to wake up the server and enable it sending updates again
# after some idle time has been detected.
self.updates_watchdog_task = None
self.updates_watchdog_event = asyncio.Event()
self.last_update_time = datetime.now()
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
def __enter__(self): def __enter__(self):
@ -304,18 +293,6 @@ class Client(Methods):
except ConnectionError: except ConnectionError:
pass pass
async def updates_watchdog(self):
while True:
try:
await asyncio.wait_for(self.updates_watchdog_event.wait(), self.UPDATES_WATCHDOG_INTERVAL)
except asyncio.TimeoutError:
pass
else:
break
if datetime.now() - self.last_update_time > timedelta(seconds=self.UPDATES_WATCHDOG_INTERVAL):
await self.invoke(raw.functions.updates.GetState())
async def authorize(self) -> User: async def authorize(self) -> User:
if self.bot_token: if self.bot_token:
return await self.sign_in_bot(self.bot_token) return await self.sign_in_bot(self.bot_token)
@ -508,8 +485,6 @@ class Client(Methods):
return is_min return is_min
async def handle_updates(self, updates): async def handle_updates(self, updates):
self.last_update_time = datetime.now()
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)): if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
is_min = any(( is_min = any((
await self.fetch_peers(updates.users), await self.fetch_peers(updates.users),

View File

@ -48,7 +48,7 @@ class Connection:
await self.protocol.connect(self.address) await self.protocol.connect(self.address)
except OSError as e: except OSError as e:
log.warning("Unable to connect due to network issues: %s", e) log.warning("Unable to connect due to network issues: %s", e)
await self.protocol.close() self.protocol.close()
await asyncio.sleep(1) await asyncio.sleep(1)
else: else:
log.info("Connected! %s DC%s%s - IPv%s", log.info("Connected! %s DC%s%s - IPv%s",
@ -59,14 +59,17 @@ class Connection:
break break
else: else:
log.warning("Connection failed! Trying again...") log.warning("Connection failed! Trying again...")
raise ConnectionError raise TimeoutError
async def close(self): def close(self):
await self.protocol.close() self.protocol.close()
log.info("Disconnected") log.info("Disconnected")
async def send(self, data: bytes): async def send(self, data: bytes):
await self.protocol.send(data) try:
await self.protocol.send(data)
except Exception as e:
raise OSError(e)
async def recv(self) -> Optional[bytes]: async def recv(self) -> Optional[bytes]:
return await self.protocol.recv() return await self.protocol.recv()

View File

@ -20,6 +20,9 @@ import asyncio
import ipaddress import ipaddress
import logging import logging
import socket import socket
import time
from concurrent.futures import ThreadPoolExecutor
import socks import socks
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -31,12 +34,10 @@ class TCP:
def __init__(self, ipv6: bool, proxy: dict): def __init__(self, ipv6: bool, proxy: dict):
self.socket = None self.socket = None
self.reader = None self.reader = None # type: asyncio.StreamReader
self.writer = None self.writer = None # type: asyncio.StreamWriter
self.send_queue = asyncio.Queue()
self.send_task = None
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
if proxy: if proxy:
@ -62,50 +63,39 @@ class TCP:
log.info("Using proxy %s", hostname) log.info("Using proxy %s", hostname)
else: else:
self.socket = socket.socket( self.socket = socks.socksocket(
socket.AF_INET6 if ipv6 socket.AF_INET6 if ipv6
else socket.AF_INET else socket.AF_INET
) )
self.socket.setblocking(False) self.socket.settimeout(TCP.TIMEOUT)
async def connect(self, address: tuple): async def connect(self, address: tuple):
try: # The socket used by the whole logic is blocking and thus it blocks when connecting.
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT) # Offload the task to a thread executor to avoid blocking the main event loop.
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 with ThreadPoolExecutor(1) as executor:
raise TimeoutError("Connection timed out") await self.loop.run_in_executor(executor, self.socket.connect, address)
self.reader, self.writer = await asyncio.open_connection(sock=self.socket) self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
self.send_task = asyncio.create_task(self.send_worker())
async def close(self):
await self.send_queue.put(None)
if self.send_task is not None:
await self.send_task
def close(self):
try: try:
if self.writer is not None: self.writer.close()
self.writer.close() except AttributeError:
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) try:
except Exception as e: self.socket.shutdown(socket.SHUT_RDWR)
log.info("Close exception: %s %s", type(e).__name__, e) except OSError:
pass
finally:
# A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout.
# This is a workaround that seems to fix the occasional delayed stop of a client.
time.sleep(0.001)
self.socket.close()
async def send(self, data: bytes): async def send(self, data: bytes):
await self.send_queue.put(data) async with self.lock:
self.writer.write(data)
async def send_worker(self): await self.writer.drain()
while True:
data = await self.send_queue.get()
if data is None:
break
try:
self.writer.write(data)
await self.writer.drain()
except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e)
async def recv(self, length: int = 0): async def recv(self, length: int = 0):
data = b"" data = b""

View File

@ -16,7 +16,6 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import pyrogram import pyrogram
@ -47,6 +46,4 @@ class Initialize:
await self.dispatcher.start() await self.dispatcher.start()
self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog())
self.is_initialized = True self.is_initialized = True

View File

@ -51,11 +51,4 @@ class Terminate:
self.media_sessions.clear() self.media_sessions.clear()
self.updates_watchdog_event.set()
if self.updates_watchdog_task is not None:
await self.updates_watchdog_task
self.updates_watchdog_event.clear()
self.is_initialized = False self.is_initialized = False

View File

@ -278,4 +278,4 @@ class Auth:
else: else:
return auth_key return auth_key
finally: finally:
await self.connection.close() self.connection.close()

View File

@ -16,15 +16,19 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from threading import Lock
class SeqNo: class SeqNo:
def __init__(self): def __init__(self):
self.content_related_messages_sent = 0 self.content_related_messages_sent = 0
self.lock = Lock()
def __call__(self, is_content_related: bool) -> int: def __call__(self, is_content_related: bool) -> int:
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0) with self.lock:
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
if is_content_related: if is_content_related:
self.content_related_messages_sent += 1 self.content_related_messages_sent += 1
return seq_no return seq_no

View File

@ -156,11 +156,14 @@ class Session:
self.ping_task_event.clear() self.ping_task_event.clear()
await self.connection.close() self.connection.close()
if self.recv_task: if self.recv_task:
await self.recv_task await self.recv_task
for i in self.results.values():
i.event.set()
if not self.is_media and callable(self.client.disconnect_handler): if not self.is_media and callable(self.client.disconnect_handler):
try: try:
await self.client.disconnect_handler(self.client) await self.client.disconnect_handler(self.client)
@ -185,8 +188,7 @@ class Session:
self.stored_msg_ids self.stored_msg_ids
) )
except SecurityCheckMismatch as e: except SecurityCheckMismatch as e:
log.info("Discarding packet: %s", e) log.warning("Discarding packet: %s", e)
await self.connection.close()
return return
messages = ( messages = (
@ -282,6 +284,9 @@ class Session:
message = self.msg_factory(data) message = self.msg_factory(data)
msg_id = message.msg_id msg_id = message.msg_id
if wait_response:
self.results[msg_id] = Result()
log.debug("Sent: %s", message) log.debug("Sent: %s", message)
payload = await self.loop.run_in_executor( payload = await self.loop.run_in_executor(
@ -294,35 +299,34 @@ class Session:
self.auth_key_id self.auth_key_id
) )
await self.connection.send(payload) try:
await self.connection.send(payload)
except OSError as e:
self.results.pop(msg_id, None)
raise e
if wait_response: if wait_response:
self.results[msg_id] = Result()
try: try:
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout) await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
finally:
result = self.results.pop(msg_id).value result = self.results.pop(msg_id).value
if result is None: if result is None:
raise TimeoutError("Request timed out") raise TimeoutError("Request timed out")
elif isinstance(result, raw.types.RpcError):
if isinstance(result, raw.types.RpcError):
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)): if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
data = data.query data = data.query
RPCError.raise_it(result, type(data)) RPCError.raise_it(result, type(data))
elif isinstance(result, raw.types.BadMsgNotification):
if isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code) raise BadMsgNotification(result.error_code)
elif isinstance(result, raw.types.BadServerSalt):
if isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt self.salt = result.new_server_salt
return await self.send(data, wait_response, timeout) return await self.send(data, wait_response, timeout)
else:
return result return result
async def invoke( async def invoke(
self, self,

View File

@ -38,13 +38,13 @@ class FileStorage(SQLiteStorage):
version = self.version() version = self.version()
if version == 1: if version == 1:
with self.conn: with self.lock, self.conn:
self.conn.execute("DELETE FROM peers") self.conn.execute("DELETE FROM peers")
version += 1 version += 1
if version == 2: if version == 2:
with self.conn: with self.lock, self.conn:
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER") self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
version += 1 version += 1
@ -63,7 +63,10 @@ class FileStorage(SQLiteStorage):
self.update() self.update()
with self.conn: with self.conn:
self.conn.execute("VACUUM") try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
self.conn.execute("VACUUM")
except sqlite3.OperationalError:
pass
async def delete(self): async def delete(self):
os.remove(self.database) os.remove(self.database)

View File

@ -19,6 +19,7 @@
import inspect import inspect
import sqlite3 import sqlite3
import time import time
from threading import Lock
from typing import List, Tuple, Any from typing import List, Tuple, Any
from pyrogram import raw from pyrogram import raw
@ -97,9 +98,10 @@ class SQLiteStorage(Storage):
super().__init__(name) super().__init__(name)
self.conn = None # type: sqlite3.Connection self.conn = None # type: sqlite3.Connection
self.lock = Lock()
def create(self): def create(self):
with self.conn: with self.lock, self.conn:
self.conn.executescript(SCHEMA) self.conn.executescript(SCHEMA)
self.conn.execute( self.conn.execute(
@ -117,20 +119,24 @@ class SQLiteStorage(Storage):
async def save(self): async def save(self):
await self.date(int(time.time())) await self.date(int(time.time()))
self.conn.commit()
with self.lock:
self.conn.commit()
async def close(self): async def close(self):
self.conn.close() with self.lock:
self.conn.close()
async def delete(self): async def delete(self):
raise NotImplementedError raise NotImplementedError
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
self.conn.executemany( with self.lock:
"REPLACE INTO peers (id, access_hash, type, username, phone_number)" self.conn.executemany(
"VALUES (?, ?, ?, ?, ?)", "REPLACE INTO peers (id, access_hash, type, username, phone_number)"
peers "VALUES (?, ?, ?, ?, ?)",
) peers
)
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
r = self.conn.execute( r = self.conn.execute(
@ -179,7 +185,7 @@ class SQLiteStorage(Storage):
def _set(self, value: Any): def _set(self, value: Any):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
with self.conn: with self.lock, self.conn:
self.conn.execute( self.conn.execute(
f"UPDATE sessions SET {attr} = ?", f"UPDATE sessions SET {attr} = ?",
(value,) (value,)
@ -215,7 +221,7 @@ class SQLiteStorage(Storage):
"SELECT number FROM version" "SELECT number FROM version"
).fetchone()[0] ).fetchone()[0]
else: else:
with self.conn: with self.lock, self.conn:
self.conn.execute( self.conn.execute(
"UPDATE version SET number = ?", "UPDATE version SET number = ?",
(value,) (value,)