diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 42bd73d6..ad755977 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -324,8 +324,7 @@ class Client(Methods, BaseClient): now = time.time() if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP: - self.session_storage.peers_by_username.clear() - self.session_storage.peers_by_phone.clear() + self.session_storage.clear_cache() self.get_initial_dialogs() self.get_contacts() @@ -763,60 +762,7 @@ class Client(Methods, BaseClient): types.Chat, types.ChatForbidden, types.Channel, types.ChannelForbidden]]): for entity in entities: - if isinstance(entity, types.User): - user_id = entity.id - - access_hash = entity.access_hash - - if access_hash is None: - continue - - username = entity.username - phone = entity.phone - - input_peer = types.InputPeerUser( - user_id=user_id, - access_hash=access_hash - ) - - self.session_storage.peers_by_id[user_id] = input_peer - - if username is not None: - self.session_storage.peers_by_username[username.lower()] = input_peer - - if phone is not None: - self.session_storage.peers_by_phone[phone] = input_peer - - if isinstance(entity, (types.Chat, types.ChatForbidden)): - chat_id = entity.id - peer_id = -chat_id - - input_peer = types.InputPeerChat( - chat_id=chat_id - ) - - self.session_storage.peers_by_id[peer_id] = input_peer - - if isinstance(entity, (types.Channel, types.ChannelForbidden)): - channel_id = entity.id - peer_id = int("-100" + str(channel_id)) - - access_hash = entity.access_hash - - if access_hash is None: - continue - - username = getattr(entity, "username", None) - - input_peer = types.InputPeerChannel( - channel_id=channel_id, - access_hash=access_hash - ) - - self.session_storage.peers_by_id[peer_id] = input_peer - - if username is not None: - self.session_storage.peers_by_username[username.lower()] = input_peer + self.session_storage.cache_peer(entity) def download_worker(self): name = threading.current_thread().name @@ -1261,7 +1207,7 @@ class Client(Methods, BaseClient): log.warning("get_dialogs flood: waiting {} seconds".format(e.x)) time.sleep(e.x) else: - log.info("Total peers: {}".format(len(self.session_storage.peers_by_id))) + log.info("Total peers: {}".format(self.session_storage.peers_count())) return r def get_initial_dialogs(self): @@ -1297,7 +1243,7 @@ class Client(Methods, BaseClient): ``KeyError`` in case the peer doesn't exist in the internal database. """ try: - return self.session_storage.peers_by_id[peer_id] + return self.session_storage.get_peer_by_id(peer_id) except KeyError: if type(peer_id) is str: if peer_id in ("self", "me"): @@ -1308,17 +1254,19 @@ class Client(Methods, BaseClient): try: int(peer_id) except ValueError: - if peer_id not in self.session_storage.peers_by_username: + try: + self.session_storage.get_peer_by_username(peer_id) + except KeyError: self.send( functions.contacts.ResolveUsername( username=peer_id ) ) - return self.session_storage.peers_by_username[peer_id] + return self.session_storage.get_peer_by_username(peer_id) else: try: - return self.session_storage.peers_by_phone[peer_id] + return self.session_storage.get_peer_by_phone(peer_id) except KeyError: raise PeerIdInvalid @@ -1345,7 +1293,7 @@ class Client(Methods, BaseClient): ) try: - return self.session_storage.peers_by_id[peer_id] + return self.session_storage.get_peer_by_id(peer_id) except KeyError: raise PeerIdInvalid diff --git a/pyrogram/client/ext/base_client.py b/pyrogram/client/ext/base_client.py index 732a600f..1ec65c93 100644 --- a/pyrogram/client/ext/base_client.py +++ b/pyrogram/client/ext/base_client.py @@ -74,8 +74,8 @@ class BaseClient: self.rnd_id = MsgId self.channels_pts = {} - self.markdown = Markdown(self.session_storage.peers_by_id) - self.html = HTML(self.session_storage.peers_by_id) + self.markdown = Markdown(self.session_storage) + self.html = HTML(self.session_storage) self.session = None self.media_sessions = {} diff --git a/pyrogram/client/methods/contacts/get_contacts.py b/pyrogram/client/methods/contacts/get_contacts.py index 35b24592..12419106 100644 --- a/pyrogram/client/methods/contacts/get_contacts.py +++ b/pyrogram/client/methods/contacts/get_contacts.py @@ -44,5 +44,5 @@ class GetContacts(BaseClient): log.warning("get_contacts flood: waiting {} seconds".format(e.x)) time.sleep(e.x) else: - log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone))) + log.info("Total contacts: {}".format(self.session_storage.contacts_count())) return [pyrogram.User._parse(self, user) for user in contacts.users] diff --git a/pyrogram/client/session_storage/abstract.py b/pyrogram/client/session_storage/abstract.py index e8f4441e..39517a01 100644 --- a/pyrogram/client/session_storage/abstract.py +++ b/pyrogram/client/session_storage/abstract.py @@ -17,9 +17,10 @@ # along with Pyrogram. If not, see . import abc -from typing import Type +from typing import Type, Union import pyrogram +from pyrogram.api import types class SessionDoesNotExist(Exception): @@ -102,17 +103,41 @@ class SessionStorage(abc.ABC): def is_bot(self, val): ... - @property @abc.abstractmethod - def peers_by_id(self): + def clear_cache(self): ... - @property @abc.abstractmethod - def peers_by_username(self): + def cache_peer(self, entity: Union[types.User, + types.Chat, types.ChatForbidden, + types.Channel, types.ChannelForbidden]): ... - @property @abc.abstractmethod - def peers_by_phone(self): + def get_peer_by_id(self, val: int): + ... + + @abc.abstractmethod + def get_peer_by_username(self, val: str): + ... + + @abc.abstractmethod + def get_peer_by_phone(self, val: str): + ... + + def get_peer(self, peer_id: Union[int, str]): + if isinstance(peer_id, int): + return self.get_peer_by_id(peer_id) + else: + peer_id = peer_id.lstrip('+@') + if peer_id.isdigit(): + return self.get_peer_by_phone(peer_id) + return self.get_peer_by_username(peer_id) + + @abc.abstractmethod + def peers_count(self): + ... + + @abc.abstractmethod + def contacts_count(self): ... diff --git a/pyrogram/client/session_storage/json.py b/pyrogram/client/session_storage/json.py index 170089a4..aaa6b96f 100644 --- a/pyrogram/client/session_storage/json.py +++ b/pyrogram/client/session_storage/json.py @@ -58,19 +58,19 @@ class JsonSessionStorage(MemorySessionStorage): self._is_bot = s.get('is_bot', self._is_bot) for k, v in s.get("peers_by_id", {}).items(): - self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v) + self._peers_cache['i' + k] = utils.get_input_peer(int(k), v) for k, v in s.get("peers_by_username", {}).items(): - peer = self._peers_by_id.get(v, None) - - if peer: - self._peers_by_username[k] = peer + try: + self._peers_cache['u' + k] = self.get_peer_by_id(v) + except KeyError: + pass for k, v in s.get("peers_by_phone", {}).items(): - peer = self._peers_by_id.get(v, None) - - if peer: - self._peers_by_phone[k] = peer + try: + self._peers_cache['p' + k] = self.get_peer_by_id(v) + except KeyError: + pass def save(self, sync=False): file_path = self._get_file_name(self._session_name) @@ -93,16 +93,19 @@ class JsonSessionStorage(MemorySessionStorage): 'date': self._date, 'is_bot': self._is_bot, 'peers_by_id': { - k: getattr(v, "access_hash", None) - for k, v in self._peers_by_id.copy().items() + k[1:]: getattr(v, "access_hash", None) + for k, v in self._peers_cache.copy().items() + if k[0] == 'i' }, 'peers_by_username': { - k: utils.get_peer_id(v) - for k, v in self._peers_by_username.copy().items() + k[1:]: utils.get_peer_id(v) + for k, v in self._peers_cache.copy().items() + if k[0] == 'u' }, 'peers_by_phone': { - k: utils.get_peer_id(v) - for k, v in self._peers_by_phone.copy().items() + k[1:]: utils.get_peer_id(v) + for k, v in self._peers_cache.copy().items() + if k[0] == 'p' } } diff --git a/pyrogram/client/session_storage/memory.py b/pyrogram/client/session_storage/memory.py index f456f8eb..d5f92f0d 100644 --- a/pyrogram/client/session_storage/memory.py +++ b/pyrogram/client/session_storage/memory.py @@ -1,4 +1,5 @@ import pyrogram +from pyrogram.api import types from . import SessionStorage, SessionDoesNotExist @@ -11,9 +12,7 @@ class MemorySessionStorage(SessionStorage): self._user_id = None self._date = 0 self._is_bot = False - self._peers_by_id = {} - self._peers_by_username = {} - self._peers_by_phone = {} + self._peers_cache = {} def load(self): raise SessionDoesNotExist() @@ -72,14 +71,48 @@ class MemorySessionStorage(SessionStorage): def is_bot(self, val): self._is_bot = val - @property - def peers_by_id(self): - return self._peers_by_id + def clear_cache(self): + keys = list(filter(lambda k: k[0] in 'up', self._peers_cache.keys())) + for key in keys: + try: + del self._peers_cache[key] + except KeyError: + pass - @property - def peers_by_username(self): - return self._peers_by_username + def cache_peer(self, entity): + if isinstance(entity, types.User): + input_peer = types.InputPeerUser( + user_id=entity.id, + access_hash=entity.access_hash + ) + self._peers_cache['i' + str(entity.id)] = input_peer + if entity.username: + self._peers_cache['u' + entity.username.lower()] = input_peer + if entity.phone: + self._peers_cache['p' + entity.phone] = input_peer + elif isinstance(entity, (types.Chat, types.ChatForbidden)): + self._peers_cache['i-' + str(entity.id)] = types.InputPeerChat(chat_id=entity.id) + elif isinstance(entity, (types.Channel, types.ChannelForbidden)): + input_peer = types.InputPeerChannel( + channel_id=entity.id, + access_hash=entity.access_hash + ) + self._peers_cache['i-100' + str(entity.id)] = input_peer + username = getattr(entity, "username", None) + if username: + self._peers_cache['u' + username.lower()] = input_peer - @property - def peers_by_phone(self): - return self._peers_by_phone + def get_peer_by_id(self, val): + return self._peers_cache['i' + str(val)] + + def get_peer_by_username(self, val): + return self._peers_cache['u' + val.lower()] + + def get_peer_by_phone(self, val): + return self._peers_cache['p' + val] + + def peers_count(self): + return len(list(filter(lambda k: k[0] == 'i', self._peers_cache.keys()))) + + def contacts_count(self): + return len(list(filter(lambda k: k[0] == 'p', self._peers_cache.keys()))) diff --git a/pyrogram/client/style/html.py b/pyrogram/client/style/html.py index 9a72a565..88e317cd 100644 --- a/pyrogram/client/style/html.py +++ b/pyrogram/client/style/html.py @@ -29,14 +29,15 @@ from pyrogram.api.types import ( InputMessageEntityMentionName as Mention, ) from . import utils +from ..session_storage import SessionStorage class HTML: HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)") MENTION_RE = re.compile(r"tg://user\?id=(\d+)") - def __init__(self, peers_by_id): - self.peers_by_id = peers_by_id + def __init__(self, session_storage: SessionStorage): + self.session_storage = session_storage def parse(self, message: str): entities = [] @@ -52,7 +53,10 @@ class HTML: if mention: user_id = int(mention.group(1)) - input_user = self.peers_by_id.get(user_id, None) + try: + input_user = self.session_storage.get_peer_by_id(user_id) + except KeyError: + input_user = None entity = ( Mention(start, len(body), input_user) diff --git a/pyrogram/client/style/markdown.py b/pyrogram/client/style/markdown.py index 05a11a25..6793b643 100644 --- a/pyrogram/client/style/markdown.py +++ b/pyrogram/client/style/markdown.py @@ -29,6 +29,7 @@ from pyrogram.api.types import ( InputMessageEntityMentionName as Mention ) from . import utils +from ..session_storage import SessionStorage class Markdown: @@ -52,8 +53,8 @@ class Markdown: )) MENTION_RE = re.compile(r"tg://user\?id=(\d+)") - def __init__(self, peers_by_id: dict): - self.peers_by_id = peers_by_id + def __init__(self, session_storage: SessionStorage): + self.session_storage = session_storage def parse(self, message: str): message = utils.add_surrogates(str(message)).strip() @@ -69,7 +70,10 @@ class Markdown: if mention: user_id = int(mention.group(1)) - input_user = self.peers_by_id.get(user_id, None) + try: + input_user = self.session_storage.get_peer_by_id(user_id) + except KeyError: + input_user = None entity = ( Mention(start, len(text), input_user)