From 1162e89f26cc5f623ba6edb851129afd8ce2cab7 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Thu, 20 Jan 2022 09:43:29 +0100 Subject: [PATCH] Better handling of expiring server salts --- pyrogram/session/session.py | 62 +++++-------------------------------- 1 file changed, 7 insertions(+), 55 deletions(-) diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 751d6e00..6455e958 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -19,8 +19,6 @@ import asyncio import logging import os -import time -from datetime import datetime, timedelta from hashlib import sha1 from io import BytesIO @@ -33,7 +31,7 @@ from pyrogram.errors import ( SecurityCheckMismatch ) from pyrogram.raw.all import layer -from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts +from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts from .internals import MsgId, MsgFactory log = logging.getLogger(__name__) @@ -76,7 +74,7 @@ class Session: self.session_id = os.urandom(8) self.msg_factory = MsgFactory() - self.current_salt = None + self.salt = 0 self.pending_acks = set() @@ -87,9 +85,6 @@ class Session: self.ping_task = None self.ping_task_event = asyncio.Event() - self.next_salt_task = None - self.next_salt_task_event = asyncio.Event() - self.network_task = None self.is_connected = asyncio.Event() @@ -111,19 +106,7 @@ class Session: self.network_task = self.loop.create_task(self.network_worker()) - self.current_salt = FutureSalt(0, 0, 0) - self.current_salt = FutureSalt( - 0, 0, - (await self._send( - raw.functions.Ping(ping_id=0), - timeout=self.START_TIMEOUT - )).new_server_salt - ) - self.current_salt = (await self._send( - raw.functions.GetFutureSalts(num=1), - timeout=self.START_TIMEOUT)).salts[0] - - self.next_salt_task = self.loop.create_task(self.next_salt_worker()) + await self._send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT) if not self.is_cdn: await self._send( @@ -168,16 +151,11 @@ class Session: self.is_connected.clear() self.ping_task_event.set() - self.next_salt_task_event.set() if self.ping_task is not None: await self.ping_task - if self.next_salt_task is not None: - await self.next_salt_task - self.ping_task_event.clear() - self.next_salt_task_event.clear() self.connection.close() @@ -288,35 +266,6 @@ class Session: log.info("PingTask stopped") - async def next_salt_worker(self): - log.info("NextSaltTask started") - - while True: - now = datetime.fromtimestamp(time.perf_counter() - MsgId.reference_clock + MsgId.server_time) - - # Seconds to wait until middle-overlap, which is - # 15 minutes before/after the current/next salt end/start time - valid_until = datetime.fromtimestamp(self.current_salt.valid_until) - dt = (valid_until - now).total_seconds() - 900 - - minutes, seconds = divmod(int(dt), 60) - log.info(f"Next salt in {minutes:.0f}m {seconds:.0f}s (at {now + timedelta(seconds=dt)})") - - try: - await asyncio.wait_for(self.next_salt_task_event.wait(), dt) - except asyncio.TimeoutError: - pass - else: - break - - try: - self.current_salt = (await self._send(raw.functions.GetFutureSalts(num=1))).salts[0] - except (OSError, TimeoutError, RPCError): - self.connection.close() - break - - log.info("NextSaltTask stopped") - async def network_worker(self): log.info("NetworkTask started") @@ -352,7 +301,7 @@ class Session: pyrogram.crypto_executor, mtproto.pack, message, - self.current_salt.salt, + self.salt, self.session_id, self.auth_key, self.auth_key_id @@ -381,6 +330,9 @@ class Session: RPCError.raise_it(result, type(data)) elif isinstance(result, raw.types.BadMsgNotification): raise BadMsgNotification(result.error_code) + elif isinstance(result, raw.types.BadServerSalt): + self.salt = result.new_server_salt + return await self._send(data, wait_response, timeout) else: return result