2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-29 05:18:10 +00:00

Cache media sessions. Related to #40. Fixes #68

This commit is contained in:
Dan 2018-05-01 18:15:33 +02:00
parent 943691fd18
commit 83febf2e0c

View File

@ -195,6 +195,8 @@ class Client:
self.html = HTML(self.peers_by_id) self.html = HTML(self.peers_by_id)
self.session = None self.session = None
self.media_sessions = {}
self.media_sessions_lock = threading.Lock()
self.is_started = None self.is_started = None
self.is_idle = None self.is_idle = None
@ -399,13 +401,8 @@ class Client:
if not self.is_started: if not self.is_started:
raise ConnectionError("Client is already stopped") raise ConnectionError("Client is already stopped")
for _ in range(self.UPDATES_WORKERS): Syncer.remove(self)
self.updates_queue.put(None) self.dispatcher.stop()
for i in self.updates_workers_list:
i.join()
self.updates_workers_list.clear()
for _ in range(self.DOWNLOAD_WORKERS): for _ in range(self.DOWNLOAD_WORKERS):
self.download_queue.put(None) self.download_queue.put(None)
@ -415,13 +412,20 @@ class Client:
self.download_workers_list.clear() self.download_workers_list.clear()
self.dispatcher.stop() for _ in range(self.UPDATES_WORKERS):
self.updates_queue.put(None)
for i in self.updates_workers_list:
i.join()
self.updates_workers_list.clear()
for i in self.media_sessions.values():
i.stop()
self.is_started = False self.is_started = False
self.session.stop() self.session.stop()
Syncer.remove(self)
def authorize_bot(self): def authorize_bot(self):
try: try:
r = self.send( r = self.send(
@ -724,9 +728,7 @@ class Client:
media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None) media_type_str = Client.MEDIA_TYPE_ID.get(media_type, None)
if media_type_str: if media_type_str is None:
log.info("The file_id belongs to a {}".format(media_type_str))
else:
raise FileIdInvalid("Unknown media type: {}".format(unpacked[0])) raise FileIdInvalid("Unknown media type: {}".format(unpacked[0]))
file_name = file_name or getattr(media, "file_name", None) file_name = file_name or getattr(media, "file_name", None)
@ -2971,6 +2973,10 @@ class Client:
version: int = 0, version: int = 0,
size: int = None, size: int = None,
progress: callable = None) -> str: progress: callable = None) -> str:
with self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
if session is None:
if dc_id != self.dc_id: if dc_id != self.dc_id:
exported_auth = self.send( exported_auth = self.send(
functions.auth.ExportAuthorization( functions.auth.ExportAuthorization(
@ -2987,17 +2993,13 @@ class Client:
) )
session.start() session.start()
try:
session.send( session.send(
functions.auth.ImportAuthorization( functions.auth.ImportAuthorization(
id=exported_auth.id, id=exported_auth.id,
bytes=exported_auth.bytes bytes=exported_auth.bytes
) )
) )
except Exception as e:
session.stop()
raise e
else: else:
session = Session( session = Session(
dc_id, dc_id,
@ -3009,6 +3011,8 @@ class Client:
session.start() session.start()
self.media_sessions[dc_id] = session
if volume_id: # Photos are accessed by volume_id, local_id, secret if volume_id: # Photos are accessed by volume_id, local_id, secret
location = types.InputFileLocation( location = types.InputFileLocation(
volume_id=volume_id, volume_id=volume_id,
@ -3063,6 +3067,10 @@ class Client:
) )
elif isinstance(r, types.upload.FileCdnRedirect): elif isinstance(r, types.upload.FileCdnRedirect):
with self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None:
cdn_session = Session( cdn_session = Session(
r.dc_id, r.dc_id,
self.test_mode, self.test_mode,
@ -3074,6 +3082,8 @@ class Client:
cdn_session.start() cdn_session.start()
self.media_sessions[r.dc_id] = cdn_session
try: try:
with tempfile.NamedTemporaryFile("wb", delete=False) as f: with tempfile.NamedTemporaryFile("wb", delete=False) as f:
file_name = f.name file_name = f.name
@ -3136,7 +3146,8 @@ class Client:
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
cdn_session.stop() pass # Don't stop sessions, they are now cached and kept online
# cdn_session.stop() TODO: Remove this branch
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
@ -3149,7 +3160,8 @@ class Client:
else: else:
return file_name return file_name
finally: finally:
session.stop() pass # Don't stop sessions, they are now cached and kept online
# session.stop() TODO: Remove this branch
def join_chat(self, chat_id: str): def join_chat(self, chat_id: str):
"""Use this method to join a group chat or channel. """Use this method to join a group chat or channel.