mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-25 11:28:05 +00:00
Revert some of the latest changes
This commit is contained in:
parent
a81b8a2254
commit
bff583ed75
@ -26,7 +26,6 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from concurrent.futures.thread import ThreadPoolExecutor
|
from concurrent.futures.thread import ThreadPoolExecutor
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from io import StringIO, BytesIO
|
from io import StringIO, BytesIO
|
||||||
@ -186,9 +185,6 @@ class Client(Methods):
|
|||||||
WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None
|
WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None
|
||||||
WORKDIR = PARENT_DIR
|
WORKDIR = PARENT_DIR
|
||||||
|
|
||||||
# Interval of seconds in which the updates watchdog will kick in
|
|
||||||
UPDATES_WATCHDOG_INTERVAL = 5 * 60
|
|
||||||
|
|
||||||
mimetypes = MimeTypes()
|
mimetypes = MimeTypes()
|
||||||
mimetypes.readfp(StringIO(mime_types))
|
mimetypes.readfp(StringIO(mime_types))
|
||||||
|
|
||||||
@ -277,13 +273,6 @@ class Client(Methods):
|
|||||||
|
|
||||||
self.message_cache = Cache(10000)
|
self.message_cache = Cache(10000)
|
||||||
|
|
||||||
# Sometimes, for some reason, the server will stop sending updates and will only respond to pings.
|
|
||||||
# This watchdog will invoke updates.GetState in order to wake up the server and enable it sending updates again
|
|
||||||
# after some idle time has been detected.
|
|
||||||
self.updates_watchdog_task = None
|
|
||||||
self.updates_watchdog_event = asyncio.Event()
|
|
||||||
self.last_update_time = datetime.now()
|
|
||||||
|
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -304,18 +293,6 @@ class Client(Methods):
|
|||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def updates_watchdog(self):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.updates_watchdog_event.wait(), self.UPDATES_WATCHDOG_INTERVAL)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
if datetime.now() - self.last_update_time > timedelta(seconds=self.UPDATES_WATCHDOG_INTERVAL):
|
|
||||||
await self.invoke(raw.functions.updates.GetState())
|
|
||||||
|
|
||||||
async def authorize(self) -> User:
|
async def authorize(self) -> User:
|
||||||
if self.bot_token:
|
if self.bot_token:
|
||||||
return await self.sign_in_bot(self.bot_token)
|
return await self.sign_in_bot(self.bot_token)
|
||||||
@ -508,8 +485,6 @@ class Client(Methods):
|
|||||||
return is_min
|
return is_min
|
||||||
|
|
||||||
async def handle_updates(self, updates):
|
async def handle_updates(self, updates):
|
||||||
self.last_update_time = datetime.now()
|
|
||||||
|
|
||||||
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
|
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
|
||||||
is_min = any((
|
is_min = any((
|
||||||
await self.fetch_peers(updates.users),
|
await self.fetch_peers(updates.users),
|
||||||
|
@ -48,7 +48,7 @@ class Connection:
|
|||||||
await self.protocol.connect(self.address)
|
await self.protocol.connect(self.address)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
log.warning("Unable to connect due to network issues: %s", e)
|
log.warning("Unable to connect due to network issues: %s", e)
|
||||||
await self.protocol.close()
|
self.protocol.close()
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
else:
|
else:
|
||||||
log.info("Connected! %s DC%s%s - IPv%s",
|
log.info("Connected! %s DC%s%s - IPv%s",
|
||||||
@ -59,14 +59,17 @@ class Connection:
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
log.warning("Connection failed! Trying again...")
|
log.warning("Connection failed! Trying again...")
|
||||||
raise ConnectionError
|
raise TimeoutError
|
||||||
|
|
||||||
async def close(self):
|
def close(self):
|
||||||
await self.protocol.close()
|
self.protocol.close()
|
||||||
log.info("Disconnected")
|
log.info("Disconnected")
|
||||||
|
|
||||||
async def send(self, data: bytes):
|
async def send(self, data: bytes):
|
||||||
await self.protocol.send(data)
|
try:
|
||||||
|
await self.protocol.send(data)
|
||||||
|
except Exception as e:
|
||||||
|
raise OSError(e)
|
||||||
|
|
||||||
async def recv(self) -> Optional[bytes]:
|
async def recv(self) -> Optional[bytes]:
|
||||||
return await self.protocol.recv()
|
return await self.protocol.recv()
|
||||||
|
@ -20,6 +20,9 @@ import asyncio
|
|||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import socks
|
import socks
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -31,12 +34,10 @@ class TCP:
|
|||||||
def __init__(self, ipv6: bool, proxy: dict):
|
def __init__(self, ipv6: bool, proxy: dict):
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
|
||||||
self.reader = None
|
self.reader = None # type: asyncio.StreamReader
|
||||||
self.writer = None
|
self.writer = None # type: asyncio.StreamWriter
|
||||||
|
|
||||||
self.send_queue = asyncio.Queue()
|
|
||||||
self.send_task = None
|
|
||||||
|
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
if proxy:
|
if proxy:
|
||||||
@ -62,50 +63,39 @@ class TCP:
|
|||||||
|
|
||||||
log.info("Using proxy %s", hostname)
|
log.info("Using proxy %s", hostname)
|
||||||
else:
|
else:
|
||||||
self.socket = socket.socket(
|
self.socket = socks.socksocket(
|
||||||
socket.AF_INET6 if ipv6
|
socket.AF_INET6 if ipv6
|
||||||
else socket.AF_INET
|
else socket.AF_INET
|
||||||
)
|
)
|
||||||
|
|
||||||
self.socket.setblocking(False)
|
self.socket.settimeout(TCP.TIMEOUT)
|
||||||
|
|
||||||
async def connect(self, address: tuple):
|
async def connect(self, address: tuple):
|
||||||
try:
|
# The socket used by the whole logic is blocking and thus it blocks when connecting.
|
||||||
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
|
# Offload the task to a thread executor to avoid blocking the main event loop.
|
||||||
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
|
with ThreadPoolExecutor(1) as executor:
|
||||||
raise TimeoutError("Connection timed out")
|
await self.loop.run_in_executor(executor, self.socket.connect, address)
|
||||||
|
|
||||||
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
|
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
|
||||||
self.send_task = asyncio.create_task(self.send_worker())
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
await self.send_queue.put(None)
|
|
||||||
|
|
||||||
if self.send_task is not None:
|
|
||||||
await self.send_task
|
|
||||||
|
|
||||||
|
def close(self):
|
||||||
try:
|
try:
|
||||||
if self.writer is not None:
|
self.writer.close()
|
||||||
self.writer.close()
|
except AttributeError:
|
||||||
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
|
try:
|
||||||
except Exception as e:
|
self.socket.shutdown(socket.SHUT_RDWR)
|
||||||
log.info("Close exception: %s %s", type(e).__name__, e)
|
except OSError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
# A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout.
|
||||||
|
# This is a workaround that seems to fix the occasional delayed stop of a client.
|
||||||
|
time.sleep(0.001)
|
||||||
|
self.socket.close()
|
||||||
|
|
||||||
async def send(self, data: bytes):
|
async def send(self, data: bytes):
|
||||||
await self.send_queue.put(data)
|
async with self.lock:
|
||||||
|
self.writer.write(data)
|
||||||
async def send_worker(self):
|
await self.writer.drain()
|
||||||
while True:
|
|
||||||
data = await self.send_queue.get()
|
|
||||||
|
|
||||||
if data is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.writer.write(data)
|
|
||||||
await self.writer.drain()
|
|
||||||
except Exception as e:
|
|
||||||
log.info("Send exception: %s %s", type(e).__name__, e)
|
|
||||||
|
|
||||||
async def recv(self, length: int = 0):
|
async def recv(self, length: int = 0):
|
||||||
data = b""
|
data = b""
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
@ -47,6 +46,4 @@ class Initialize:
|
|||||||
|
|
||||||
await self.dispatcher.start()
|
await self.dispatcher.start()
|
||||||
|
|
||||||
self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog())
|
|
||||||
|
|
||||||
self.is_initialized = True
|
self.is_initialized = True
|
||||||
|
@ -51,11 +51,4 @@ class Terminate:
|
|||||||
|
|
||||||
self.media_sessions.clear()
|
self.media_sessions.clear()
|
||||||
|
|
||||||
self.updates_watchdog_event.set()
|
|
||||||
|
|
||||||
if self.updates_watchdog_task is not None:
|
|
||||||
await self.updates_watchdog_task
|
|
||||||
|
|
||||||
self.updates_watchdog_event.clear()
|
|
||||||
|
|
||||||
self.is_initialized = False
|
self.is_initialized = False
|
||||||
|
@ -278,4 +278,4 @@ class Auth:
|
|||||||
else:
|
else:
|
||||||
return auth_key
|
return auth_key
|
||||||
finally:
|
finally:
|
||||||
await self.connection.close()
|
self.connection.close()
|
||||||
|
@ -16,15 +16,19 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
|
||||||
class SeqNo:
|
class SeqNo:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.content_related_messages_sent = 0
|
self.content_related_messages_sent = 0
|
||||||
|
self.lock = Lock()
|
||||||
|
|
||||||
def __call__(self, is_content_related: bool) -> int:
|
def __call__(self, is_content_related: bool) -> int:
|
||||||
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
|
with self.lock:
|
||||||
|
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
|
||||||
|
|
||||||
if is_content_related:
|
if is_content_related:
|
||||||
self.content_related_messages_sent += 1
|
self.content_related_messages_sent += 1
|
||||||
|
|
||||||
return seq_no
|
return seq_no
|
||||||
|
@ -156,11 +156,14 @@ class Session:
|
|||||||
|
|
||||||
self.ping_task_event.clear()
|
self.ping_task_event.clear()
|
||||||
|
|
||||||
await self.connection.close()
|
self.connection.close()
|
||||||
|
|
||||||
if self.recv_task:
|
if self.recv_task:
|
||||||
await self.recv_task
|
await self.recv_task
|
||||||
|
|
||||||
|
for i in self.results.values():
|
||||||
|
i.event.set()
|
||||||
|
|
||||||
if not self.is_media and callable(self.client.disconnect_handler):
|
if not self.is_media and callable(self.client.disconnect_handler):
|
||||||
try:
|
try:
|
||||||
await self.client.disconnect_handler(self.client)
|
await self.client.disconnect_handler(self.client)
|
||||||
@ -185,8 +188,7 @@ class Session:
|
|||||||
self.stored_msg_ids
|
self.stored_msg_ids
|
||||||
)
|
)
|
||||||
except SecurityCheckMismatch as e:
|
except SecurityCheckMismatch as e:
|
||||||
log.info("Discarding packet: %s", e)
|
log.warning("Discarding packet: %s", e)
|
||||||
await self.connection.close()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
messages = (
|
messages = (
|
||||||
@ -282,6 +284,9 @@ class Session:
|
|||||||
message = self.msg_factory(data)
|
message = self.msg_factory(data)
|
||||||
msg_id = message.msg_id
|
msg_id = message.msg_id
|
||||||
|
|
||||||
|
if wait_response:
|
||||||
|
self.results[msg_id] = Result()
|
||||||
|
|
||||||
log.debug("Sent: %s", message)
|
log.debug("Sent: %s", message)
|
||||||
|
|
||||||
payload = await self.loop.run_in_executor(
|
payload = await self.loop.run_in_executor(
|
||||||
@ -294,35 +299,34 @@ class Session:
|
|||||||
self.auth_key_id
|
self.auth_key_id
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.connection.send(payload)
|
try:
|
||||||
|
await self.connection.send(payload)
|
||||||
|
except OSError as e:
|
||||||
|
self.results.pop(msg_id, None)
|
||||||
|
raise e
|
||||||
|
|
||||||
if wait_response:
|
if wait_response:
|
||||||
self.results[msg_id] = Result()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
|
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
|
finally:
|
||||||
result = self.results.pop(msg_id).value
|
result = self.results.pop(msg_id).value
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise TimeoutError("Request timed out")
|
raise TimeoutError("Request timed out")
|
||||||
|
elif isinstance(result, raw.types.RpcError):
|
||||||
if isinstance(result, raw.types.RpcError):
|
|
||||||
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
|
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
|
||||||
data = data.query
|
data = data.query
|
||||||
|
|
||||||
RPCError.raise_it(result, type(data))
|
RPCError.raise_it(result, type(data))
|
||||||
|
elif isinstance(result, raw.types.BadMsgNotification):
|
||||||
if isinstance(result, raw.types.BadMsgNotification):
|
|
||||||
raise BadMsgNotification(result.error_code)
|
raise BadMsgNotification(result.error_code)
|
||||||
|
elif isinstance(result, raw.types.BadServerSalt):
|
||||||
if isinstance(result, raw.types.BadServerSalt):
|
|
||||||
self.salt = result.new_server_salt
|
self.salt = result.new_server_salt
|
||||||
return await self.send(data, wait_response, timeout)
|
return await self.send(data, wait_response, timeout)
|
||||||
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def invoke(
|
async def invoke(
|
||||||
self,
|
self,
|
||||||
|
@ -38,13 +38,13 @@ class FileStorage(SQLiteStorage):
|
|||||||
version = self.version()
|
version = self.version()
|
||||||
|
|
||||||
if version == 1:
|
if version == 1:
|
||||||
with self.conn:
|
with self.lock, self.conn:
|
||||||
self.conn.execute("DELETE FROM peers")
|
self.conn.execute("DELETE FROM peers")
|
||||||
|
|
||||||
version += 1
|
version += 1
|
||||||
|
|
||||||
if version == 2:
|
if version == 2:
|
||||||
with self.conn:
|
with self.lock, self.conn:
|
||||||
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
|
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
|
||||||
|
|
||||||
version += 1
|
version += 1
|
||||||
@ -63,7 +63,10 @@ class FileStorage(SQLiteStorage):
|
|||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
with self.conn:
|
with self.conn:
|
||||||
self.conn.execute("VACUUM")
|
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
|
||||||
|
self.conn.execute("VACUUM")
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
pass
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
os.remove(self.database)
|
os.remove(self.database)
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
import time
|
||||||
|
from threading import Lock
|
||||||
from typing import List, Tuple, Any
|
from typing import List, Tuple, Any
|
||||||
|
|
||||||
from pyrogram import raw
|
from pyrogram import raw
|
||||||
@ -97,9 +98,10 @@ class SQLiteStorage(Storage):
|
|||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
|
|
||||||
self.conn = None # type: sqlite3.Connection
|
self.conn = None # type: sqlite3.Connection
|
||||||
|
self.lock = Lock()
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
with self.conn:
|
with self.lock, self.conn:
|
||||||
self.conn.executescript(SCHEMA)
|
self.conn.executescript(SCHEMA)
|
||||||
|
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
@ -117,20 +119,24 @@ class SQLiteStorage(Storage):
|
|||||||
|
|
||||||
async def save(self):
|
async def save(self):
|
||||||
await self.date(int(time.time()))
|
await self.date(int(time.time()))
|
||||||
self.conn.commit()
|
|
||||||
|
with self.lock:
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
self.conn.close()
|
with self.lock:
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
|
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
|
||||||
self.conn.executemany(
|
with self.lock:
|
||||||
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
|
self.conn.executemany(
|
||||||
"VALUES (?, ?, ?, ?, ?)",
|
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
|
||||||
peers
|
"VALUES (?, ?, ?, ?, ?)",
|
||||||
)
|
peers
|
||||||
|
)
|
||||||
|
|
||||||
async def get_peer_by_id(self, peer_id: int):
|
async def get_peer_by_id(self, peer_id: int):
|
||||||
r = self.conn.execute(
|
r = self.conn.execute(
|
||||||
@ -179,7 +185,7 @@ class SQLiteStorage(Storage):
|
|||||||
def _set(self, value: Any):
|
def _set(self, value: Any):
|
||||||
attr = inspect.stack()[2].function
|
attr = inspect.stack()[2].function
|
||||||
|
|
||||||
with self.conn:
|
with self.lock, self.conn:
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
f"UPDATE sessions SET {attr} = ?",
|
f"UPDATE sessions SET {attr} = ?",
|
||||||
(value,)
|
(value,)
|
||||||
@ -215,7 +221,7 @@ class SQLiteStorage(Storage):
|
|||||||
"SELECT number FROM version"
|
"SELECT number FROM version"
|
||||||
).fetchone()[0]
|
).fetchone()[0]
|
||||||
else:
|
else:
|
||||||
with self.conn:
|
with self.lock, self.conn:
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
"UPDATE version SET number = ?",
|
"UPDATE version SET number = ?",
|
||||||
(value,)
|
(value,)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user