From 83febf2e0cfb485830ace4993382cef35e2baaba Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Tue, 1 May 2018 18:15:33 +0200 Subject: [PATCH] Cache media sessions. Related to #40. Fixes #68 --- pyrogram/client/client.py | 126 +++++++++++++++++++++----------------- 1 file changed, 69 insertions(+), 57 deletions(-) diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 621952f5..feeea853 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -195,6 +195,8 @@ class Client: self.html = HTML(self.peers_by_id) self.session = None + self.media_sessions = {} + self.media_sessions_lock = threading.Lock() self.is_started = None self.is_idle = None @@ -399,13 +401,8 @@ class Client: if not self.is_started: raise ConnectionError("Client is already stopped") - for _ in range(self.UPDATES_WORKERS): - self.updates_queue.put(None) - - for i in self.updates_workers_list: - i.join() - - self.updates_workers_list.clear() + Syncer.remove(self) + self.dispatcher.stop() for _ in range(self.DOWNLOAD_WORKERS): self.download_queue.put(None) @@ -415,13 +412,20 @@ class Client: self.download_workers_list.clear() - self.dispatcher.stop() + for _ in range(self.UPDATES_WORKERS): + self.updates_queue.put(None) + + for i in self.updates_workers_list: + i.join() + + self.updates_workers_list.clear() + + for i in self.media_sessions.values(): + i.stop() self.is_started = False self.session.stop() - Syncer.remove(self) - def authorize_bot(self): try: r = self.send( @@ -724,9 +728,7 @@ class Client: media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None) - if media_type_str: - log.info("The file_id belongs to a {}".format(media_type_str)) - else: + if media_type_str is None: raise FileIdInvalid("Unknown media type: {}".format(unpacked[0])) file_name = file_name or getattr(media, "file_name", None) @@ -2971,43 +2973,45 @@ class Client: version: int = 0, size: int = None, progress: callable = None) -> str: - if dc_id != self.dc_id: - exported_auth = self.send( - functions.auth.ExportAuthorization( - dc_id=dc_id - ) - ) + with self.media_sessions_lock: + session = self.media_sessions.get(dc_id, None) - session = Session( - dc_id, - self.test_mode, - self.proxy, - Auth(dc_id, self.test_mode, self.proxy).create(), - self.api_id - ) - - session.start() - try: - session.send( - functions.auth.ImportAuthorization( - id=exported_auth.id, - bytes=exported_auth.bytes + if session is None: + if dc_id != self.dc_id: + exported_auth = self.send( + functions.auth.ExportAuthorization( + dc_id=dc_id + ) ) - ) - except Exception as e: - session.stop() - raise e - else: - session = Session( - dc_id, - self.test_mode, - self.proxy, - self.auth_key, - self.api_id - ) + session = Session( + dc_id, + self.test_mode, + self.proxy, + Auth(dc_id, self.test_mode, self.proxy).create(), + self.api_id + ) - session.start() + session.start() + + session.send( + functions.auth.ImportAuthorization( + id=exported_auth.id, + bytes=exported_auth.bytes + ) + ) + else: + session = Session( + dc_id, + self.test_mode, + self.proxy, + self.auth_key, + self.api_id + ) + + session.start() + + self.media_sessions[dc_id] = session if volume_id: # Photos are accessed by volume_id, local_id, secret location = types.InputFileLocation( @@ -3063,16 +3067,22 @@ class Client: ) elif isinstance(r, types.upload.FileCdnRedirect): - cdn_session = Session( - r.dc_id, - self.test_mode, - self.proxy, - Auth(r.dc_id, self.test_mode, self.proxy).create(), - self.api_id, - is_cdn=True - ) + with self.media_sessions_lock: + cdn_session = self.media_sessions.get(r.dc_id, None) - cdn_session.start() + if cdn_session is None: + cdn_session = Session( + r.dc_id, + self.test_mode, + self.proxy, + Auth(r.dc_id, self.test_mode, self.proxy).create(), + self.api_id, + is_cdn=True + ) + + cdn_session.start() + + self.media_sessions[r.dc_id] = cdn_session try: with tempfile.NamedTemporaryFile("wb", delete=False) as f: @@ -3136,7 +3146,8 @@ class Client: except Exception as e: raise e finally: - cdn_session.stop() + pass # Don't stop sessions, they are now cached and kept online + # cdn_session.stop() TODO: Remove this branch except Exception as e: log.error(e, exc_info=True) @@ -3149,7 +3160,8 @@ class Client: else: return file_name finally: - session.stop() + pass # Don't stop sessions, they are now cached and kept online + # session.stop() TODO: Remove this branch def join_chat(self, chat_id: str): """Use this method to join a group chat or channel.