2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Merge develop -> asyncio

This commit is contained in:
Dan 2020-05-07 13:39:48 +02:00
commit db4a00da36
2 changed files with 41 additions and 33 deletions

View File

@ -184,7 +184,7 @@ class Client(Methods, BaseClient):
plugins: dict = None, plugins: dict = None,
no_updates: bool = None, no_updates: bool = None,
takeout: bool = None, takeout: bool = None,
sleep_threshold: int = 60 sleep_threshold: int = Session.SLEEP_THRESHOLD
): ):
super().__init__() super().__init__()
@ -1410,31 +1410,13 @@ class Client(Methods, BaseClient):
if not self.is_connected: if not self.is_connected:
raise ConnectionError("Client has not been started yet") raise ConnectionError("Client has not been started yet")
# Some raw methods that expect a query as argument are used here.
# Keep the original request query because is needed.
unwrapped_data = data
if self.no_updates: if self.no_updates:
data = functions.InvokeWithoutUpdates(query=data) data = functions.InvokeWithoutUpdates(query=data)
if self.takeout_id: if self.takeout_id:
data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data) data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data)
while True: r = await self.session.send(data, retries, timeout, self.sleep_threshold)
try:
r = await self.session.send(data, retries, timeout)
except FloodWait as e:
amount = e.x
if amount > self.sleep_threshold:
raise
log.warning('[{}] Sleeping for {}s (required by "{}")'.format(
self.session_name, amount, ".".join(unwrapped_data.QUALNAME.split(".")[1:])))
await asyncio.sleep(amount)
else:
break
self.fetch_peers(getattr(r, "users", [])) self.fetch_peers(getattr(r, "users", []))
self.fetch_peers(getattr(r, "chats", [])) self.fetch_peers(getattr(r, "chats", []))

View File

@ -29,7 +29,7 @@ from pyrogram.api.all import layer
from pyrogram.api.core import TLObject, MsgContainer, Int, Long, FutureSalt, FutureSalts from pyrogram.api.core import TLObject, MsgContainer, Int, Long, FutureSalt, FutureSalts
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import MTProto from pyrogram.crypto import MTProto
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated, FloodWait
from .internals import MsgId, MsgFactory from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -46,6 +46,7 @@ class Session:
NET_WORKERS = 1 NET_WORKERS = 1
START_TIMEOUT = 1 START_TIMEOUT = 1
WAIT_TIMEOUT = 15 WAIT_TIMEOUT = 15
SLEEP_THRESHOLD = 60
MAX_RETRIES = 5 MAX_RETRIES = 5
ACKS_THRESHOLD = 8 ACKS_THRESHOLD = 8
PING_INTERVAL = 5 PING_INTERVAL = 5
@ -402,22 +403,47 @@ class Session:
else: else:
return result return result
async def send(self, data: TLObject, retries: int = MAX_RETRIES, timeout: float = WAIT_TIMEOUT): async def send(
self,
data: TLObject,
retries: int = MAX_RETRIES,
timeout: float = WAIT_TIMEOUT,
sleep_threshold: float = SLEEP_THRESHOLD
):
try: try:
await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT) await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
try: if isinstance(data, (functions.InvokeWithoutUpdates, functions.InvokeWithTakeout)):
return await self._send(data, timeout=timeout) query = data.query
except (OSError, TimeoutError, InternalServerError) as e: else:
if retries == 0: query = data
raise e from None
(log.warning if retries < 2 else log.info)( query = ".".join(query.QUALNAME.split(".")[1:])
"[{}] Retrying {} due to {}".format(
Session.MAX_RETRIES - retries + 1,
data.QUALNAME, e))
await asyncio.sleep(0.5) while True:
return await self.send(data, retries - 1, timeout) try:
return await self._send(data, timeout=timeout)
except FloodWait as e:
amount = e.x
if amount > sleep_threshold:
raise
log.warning('[{}] Sleeping for {}s (required by "{}")'.format(
self.client.session_name, amount, query))
await asyncio.sleep(amount)
except (OSError, TimeoutError, InternalServerError) as e:
if retries == 0:
raise e from None
(log.warning if retries < 2 else log.info)(
'[{}] Retrying "{}" due to {}'.format(
Session.MAX_RETRIES - retries + 1,
query, e))
await asyncio.sleep(0.5)
return await self.send(data, retries - 1, timeout)