2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-29 05:18:10 +00:00

Better handling of expiring server salts

This commit is contained in:
Dan 2022-01-20 09:43:29 +01:00
parent e67fd6efbb
commit 1162e89f26

View File

@ -19,8 +19,6 @@
import asyncio
import logging
import os
import time
from datetime import datetime, timedelta
from hashlib import sha1
from io import BytesIO
@ -33,7 +31,7 @@ from pyrogram.errors import (
SecurityCheckMismatch
)
from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
@ -76,7 +74,7 @@ class Session:
self.session_id = os.urandom(8)
self.msg_factory = MsgFactory()
self.current_salt = None
self.salt = 0
self.pending_acks = set()
@ -87,9 +85,6 @@ class Session:
self.ping_task = None
self.ping_task_event = asyncio.Event()
self.next_salt_task = None
self.next_salt_task_event = asyncio.Event()
self.network_task = None
self.is_connected = asyncio.Event()
@ -111,19 +106,7 @@ class Session:
self.network_task = self.loop.create_task(self.network_worker())
self.current_salt = FutureSalt(0, 0, 0)
self.current_salt = FutureSalt(
0, 0,
(await self._send(
raw.functions.Ping(ping_id=0),
timeout=self.START_TIMEOUT
)).new_server_salt
)
self.current_salt = (await self._send(
raw.functions.GetFutureSalts(num=1),
timeout=self.START_TIMEOUT)).salts[0]
self.next_salt_task = self.loop.create_task(self.next_salt_worker())
await self._send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
if not self.is_cdn:
await self._send(
@ -168,16 +151,11 @@ class Session:
self.is_connected.clear()
self.ping_task_event.set()
self.next_salt_task_event.set()
if self.ping_task is not None:
await self.ping_task
if self.next_salt_task is not None:
await self.next_salt_task
self.ping_task_event.clear()
self.next_salt_task_event.clear()
self.connection.close()
@ -288,35 +266,6 @@ class Session:
log.info("PingTask stopped")
async def next_salt_worker(self):
log.info("NextSaltTask started")
while True:
now = datetime.fromtimestamp(time.perf_counter() - MsgId.reference_clock + MsgId.server_time)
# Seconds to wait until middle-overlap, which is
# 15 minutes before/after the current/next salt end/start time
valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
dt = (valid_until - now).total_seconds() - 900
minutes, seconds = divmod(int(dt), 60)
log.info(f"Next salt in {minutes:.0f}m {seconds:.0f}s (at {now + timedelta(seconds=dt)})")
try:
await asyncio.wait_for(self.next_salt_task_event.wait(), dt)
except asyncio.TimeoutError:
pass
else:
break
try:
self.current_salt = (await self._send(raw.functions.GetFutureSalts(num=1))).salts[0]
except (OSError, TimeoutError, RPCError):
self.connection.close()
break
log.info("NextSaltTask stopped")
async def network_worker(self):
log.info("NetworkTask started")
@ -352,7 +301,7 @@ class Session:
pyrogram.crypto_executor,
mtproto.pack,
message,
self.current_salt.salt,
self.salt,
self.session_id,
self.auth_key,
self.auth_key_id
@ -381,6 +330,9 @@ class Session:
RPCError.raise_it(result, type(data))
elif isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code)
elif isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt
return await self._send(data, wait_response, timeout)
else:
return result