mirror of
https://github.com/pyrogram/pyrogram
synced 2025-09-02 07:15:23 +00:00
Better handling of expiring server salts
This commit is contained in:
@@ -19,8 +19,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
@@ -33,7 +31,7 @@ from pyrogram.errors import (
|
|||||||
SecurityCheckMismatch
|
SecurityCheckMismatch
|
||||||
)
|
)
|
||||||
from pyrogram.raw.all import layer
|
from pyrogram.raw.all import layer
|
||||||
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts
|
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
|
||||||
from .internals import MsgId, MsgFactory
|
from .internals import MsgId, MsgFactory
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -76,7 +74,7 @@ class Session:
|
|||||||
self.session_id = os.urandom(8)
|
self.session_id = os.urandom(8)
|
||||||
self.msg_factory = MsgFactory()
|
self.msg_factory = MsgFactory()
|
||||||
|
|
||||||
self.current_salt = None
|
self.salt = 0
|
||||||
|
|
||||||
self.pending_acks = set()
|
self.pending_acks = set()
|
||||||
|
|
||||||
@@ -87,9 +85,6 @@ class Session:
|
|||||||
self.ping_task = None
|
self.ping_task = None
|
||||||
self.ping_task_event = asyncio.Event()
|
self.ping_task_event = asyncio.Event()
|
||||||
|
|
||||||
self.next_salt_task = None
|
|
||||||
self.next_salt_task_event = asyncio.Event()
|
|
||||||
|
|
||||||
self.network_task = None
|
self.network_task = None
|
||||||
|
|
||||||
self.is_connected = asyncio.Event()
|
self.is_connected = asyncio.Event()
|
||||||
@@ -111,19 +106,7 @@ class Session:
|
|||||||
|
|
||||||
self.network_task = self.loop.create_task(self.network_worker())
|
self.network_task = self.loop.create_task(self.network_worker())
|
||||||
|
|
||||||
self.current_salt = FutureSalt(0, 0, 0)
|
await self._send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
||||||
self.current_salt = FutureSalt(
|
|
||||||
0, 0,
|
|
||||||
(await self._send(
|
|
||||||
raw.functions.Ping(ping_id=0),
|
|
||||||
timeout=self.START_TIMEOUT
|
|
||||||
)).new_server_salt
|
|
||||||
)
|
|
||||||
self.current_salt = (await self._send(
|
|
||||||
raw.functions.GetFutureSalts(num=1),
|
|
||||||
timeout=self.START_TIMEOUT)).salts[0]
|
|
||||||
|
|
||||||
self.next_salt_task = self.loop.create_task(self.next_salt_worker())
|
|
||||||
|
|
||||||
if not self.is_cdn:
|
if not self.is_cdn:
|
||||||
await self._send(
|
await self._send(
|
||||||
@@ -168,16 +151,11 @@ class Session:
|
|||||||
self.is_connected.clear()
|
self.is_connected.clear()
|
||||||
|
|
||||||
self.ping_task_event.set()
|
self.ping_task_event.set()
|
||||||
self.next_salt_task_event.set()
|
|
||||||
|
|
||||||
if self.ping_task is not None:
|
if self.ping_task is not None:
|
||||||
await self.ping_task
|
await self.ping_task
|
||||||
|
|
||||||
if self.next_salt_task is not None:
|
|
||||||
await self.next_salt_task
|
|
||||||
|
|
||||||
self.ping_task_event.clear()
|
self.ping_task_event.clear()
|
||||||
self.next_salt_task_event.clear()
|
|
||||||
|
|
||||||
self.connection.close()
|
self.connection.close()
|
||||||
|
|
||||||
@@ -288,35 +266,6 @@ class Session:
|
|||||||
|
|
||||||
log.info("PingTask stopped")
|
log.info("PingTask stopped")
|
||||||
|
|
||||||
async def next_salt_worker(self):
|
|
||||||
log.info("NextSaltTask started")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
now = datetime.fromtimestamp(time.perf_counter() - MsgId.reference_clock + MsgId.server_time)
|
|
||||||
|
|
||||||
# Seconds to wait until middle-overlap, which is
|
|
||||||
# 15 minutes before/after the current/next salt end/start time
|
|
||||||
valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
|
|
||||||
dt = (valid_until - now).total_seconds() - 900
|
|
||||||
|
|
||||||
minutes, seconds = divmod(int(dt), 60)
|
|
||||||
log.info(f"Next salt in {minutes:.0f}m {seconds:.0f}s (at {now + timedelta(seconds=dt)})")
|
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.next_salt_task_event.wait(), dt)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.current_salt = (await self._send(raw.functions.GetFutureSalts(num=1))).salts[0]
|
|
||||||
except (OSError, TimeoutError, RPCError):
|
|
||||||
self.connection.close()
|
|
||||||
break
|
|
||||||
|
|
||||||
log.info("NextSaltTask stopped")
|
|
||||||
|
|
||||||
async def network_worker(self):
|
async def network_worker(self):
|
||||||
log.info("NetworkTask started")
|
log.info("NetworkTask started")
|
||||||
|
|
||||||
@@ -352,7 +301,7 @@ class Session:
|
|||||||
pyrogram.crypto_executor,
|
pyrogram.crypto_executor,
|
||||||
mtproto.pack,
|
mtproto.pack,
|
||||||
message,
|
message,
|
||||||
self.current_salt.salt,
|
self.salt,
|
||||||
self.session_id,
|
self.session_id,
|
||||||
self.auth_key,
|
self.auth_key,
|
||||||
self.auth_key_id
|
self.auth_key_id
|
||||||
@@ -381,6 +330,9 @@ class Session:
|
|||||||
RPCError.raise_it(result, type(data))
|
RPCError.raise_it(result, type(data))
|
||||||
elif isinstance(result, raw.types.BadMsgNotification):
|
elif isinstance(result, raw.types.BadMsgNotification):
|
||||||
raise BadMsgNotification(result.error_code)
|
raise BadMsgNotification(result.error_code)
|
||||||
|
elif isinstance(result, raw.types.BadServerSalt):
|
||||||
|
self.salt = result.new_server_salt
|
||||||
|
return await self._send(data, wait_response, timeout)
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user