From c1321a4c0115906814f687a5dab4cf91e094ee54 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Sat, 10 Aug 2019 22:37:07 +0200 Subject: [PATCH] Add smarter auth import to deal with race conditions by multi sessions - Add a retry mechanism (up to three times) - Narrow the window in which export+import executes - Remove a line of duplicated code Fixes #299 --- pyrogram/client/client.py | 41 ++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 0e1f9bef..5fdcf235 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -45,7 +45,7 @@ from pyrogram.errors import ( PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded, PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned, VolumeLocNotFound, UserMigrate, ChannelPrivate, PhoneNumberOccupied, - PasswordRecoveryNa, PasswordEmpty + PasswordRecoveryNa, PasswordEmpty, AuthBytesInvalid ) from pyrogram.session import Auth, Session from .ext import utils, Syncer, BaseClient, Dispatcher @@ -1229,7 +1229,7 @@ class Client(Methods, BaseClient): def load_config(self): parser = ConfigParser() parser.read(str(self.config_file)) - + if self.bot_token: pass else: @@ -1720,30 +1720,35 @@ class Client(Methods, BaseClient): if session is None: if dc_id != self.storage.dc_id: - exported_auth = self.send( - functions.auth.ExportAuthorization( - dc_id=dc_id - ) - ) - session = Session(self, dc_id, Auth(self, dc_id).create(), is_media=True) - session.start() - self.media_sessions[dc_id] = session - - session.send( - functions.auth.ImportAuthorization( - id=exported_auth.id, - bytes=exported_auth.bytes + for _ in range(3): + exported_auth = self.send( + functions.auth.ExportAuthorization( + dc_id=dc_id + ) ) - ) + + try: + session.send( + functions.auth.ImportAuthorization( + id=exported_auth.id, + bytes=exported_auth.bytes + ) + ) + except AuthBytesInvalid: + continue + else: + break + else: + session.stop() + raise AuthBytesInvalid else: session = Session(self, dc_id, self.storage.auth_key, is_media=True) - session.start() - self.media_sessions[dc_id] = session + self.media_sessions[dc_id] = session if media_type == 1: location = types.InputPeerPhotoFileLocation(