2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-09-02 07:15:23 +00:00

Better handling of expiring server salts

This commit is contained in:
Dan
2022-01-20 09:43:29 +01:00
parent e67fd6efbb
commit 1162e89f26

View File

@@ -19,8 +19,6 @@
import asyncio import asyncio
import logging import logging
import os import os
import time
from datetime import datetime, timedelta
from hashlib import sha1 from hashlib import sha1
from io import BytesIO from io import BytesIO
@@ -33,7 +31,7 @@ from pyrogram.errors import (
SecurityCheckMismatch SecurityCheckMismatch
) )
from pyrogram.raw.all import layer from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
from .internals import MsgId, MsgFactory from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -76,7 +74,7 @@ class Session:
self.session_id = os.urandom(8) self.session_id = os.urandom(8)
self.msg_factory = MsgFactory() self.msg_factory = MsgFactory()
self.current_salt = None self.salt = 0
self.pending_acks = set() self.pending_acks = set()
@@ -87,9 +85,6 @@ class Session:
self.ping_task = None self.ping_task = None
self.ping_task_event = asyncio.Event() self.ping_task_event = asyncio.Event()
self.next_salt_task = None
self.next_salt_task_event = asyncio.Event()
self.network_task = None self.network_task = None
self.is_connected = asyncio.Event() self.is_connected = asyncio.Event()
@@ -111,19 +106,7 @@ class Session:
self.network_task = self.loop.create_task(self.network_worker()) self.network_task = self.loop.create_task(self.network_worker())
self.current_salt = FutureSalt(0, 0, 0) await self._send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
self.current_salt = FutureSalt(
0, 0,
(await self._send(
raw.functions.Ping(ping_id=0),
timeout=self.START_TIMEOUT
)).new_server_salt
)
self.current_salt = (await self._send(
raw.functions.GetFutureSalts(num=1),
timeout=self.START_TIMEOUT)).salts[0]
self.next_salt_task = self.loop.create_task(self.next_salt_worker())
if not self.is_cdn: if not self.is_cdn:
await self._send( await self._send(
@@ -168,16 +151,11 @@ class Session:
self.is_connected.clear() self.is_connected.clear()
self.ping_task_event.set() self.ping_task_event.set()
self.next_salt_task_event.set()
if self.ping_task is not None: if self.ping_task is not None:
await self.ping_task await self.ping_task
if self.next_salt_task is not None:
await self.next_salt_task
self.ping_task_event.clear() self.ping_task_event.clear()
self.next_salt_task_event.clear()
self.connection.close() self.connection.close()
@@ -288,35 +266,6 @@ class Session:
log.info("PingTask stopped") log.info("PingTask stopped")
async def next_salt_worker(self):
log.info("NextSaltTask started")
while True:
now = datetime.fromtimestamp(time.perf_counter() - MsgId.reference_clock + MsgId.server_time)
# Seconds to wait until middle-overlap, which is
# 15 minutes before/after the current/next salt end/start time
valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
dt = (valid_until - now).total_seconds() - 900
minutes, seconds = divmod(int(dt), 60)
log.info(f"Next salt in {minutes:.0f}m {seconds:.0f}s (at {now + timedelta(seconds=dt)})")
try:
await asyncio.wait_for(self.next_salt_task_event.wait(), dt)
except asyncio.TimeoutError:
pass
else:
break
try:
self.current_salt = (await self._send(raw.functions.GetFutureSalts(num=1))).salts[0]
except (OSError, TimeoutError, RPCError):
self.connection.close()
break
log.info("NextSaltTask stopped")
async def network_worker(self): async def network_worker(self):
log.info("NetworkTask started") log.info("NetworkTask started")
@@ -352,7 +301,7 @@ class Session:
pyrogram.crypto_executor, pyrogram.crypto_executor,
mtproto.pack, mtproto.pack,
message, message,
self.current_salt.salt, self.salt,
self.session_id, self.session_id,
self.auth_key, self.auth_key,
self.auth_key_id self.auth_key_id
@@ -381,6 +330,9 @@ class Session:
RPCError.raise_it(result, type(data)) RPCError.raise_it(result, type(data))
elif isinstance(result, raw.types.BadMsgNotification): elif isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code) raise BadMsgNotification(result.error_code)
elif isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt
return await self._send(data, wait_response, timeout)
else: else:
return result return result