2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-09-01 14:55:12 +00:00

Unify peers cache

This commit is contained in:
bakatrouble
2019-02-26 19:24:00 +03:00
parent 5dc33c6337
commit 260043d8ec
8 changed files with 122 additions and 105 deletions

View File

@@ -324,8 +324,7 @@ class Client(Methods, BaseClient):
now = time.time() now = time.time()
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP: if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
self.session_storage.peers_by_username.clear() self.session_storage.clear_cache()
self.session_storage.peers_by_phone.clear()
self.get_initial_dialogs() self.get_initial_dialogs()
self.get_contacts() self.get_contacts()
@@ -763,60 +762,7 @@ class Client(Methods, BaseClient):
types.Chat, types.ChatForbidden, types.Chat, types.ChatForbidden,
types.Channel, types.ChannelForbidden]]): types.Channel, types.ChannelForbidden]]):
for entity in entities: for entity in entities:
if isinstance(entity, types.User): self.session_storage.cache_peer(entity)
user_id = entity.id
access_hash = entity.access_hash
if access_hash is None:
continue
username = entity.username
phone = entity.phone
input_peer = types.InputPeerUser(
user_id=user_id,
access_hash=access_hash
)
self.session_storage.peers_by_id[user_id] = input_peer
if username is not None:
self.session_storage.peers_by_username[username.lower()] = input_peer
if phone is not None:
self.session_storage.peers_by_phone[phone] = input_peer
if isinstance(entity, (types.Chat, types.ChatForbidden)):
chat_id = entity.id
peer_id = -chat_id
input_peer = types.InputPeerChat(
chat_id=chat_id
)
self.session_storage.peers_by_id[peer_id] = input_peer
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
channel_id = entity.id
peer_id = int("-100" + str(channel_id))
access_hash = entity.access_hash
if access_hash is None:
continue
username = getattr(entity, "username", None)
input_peer = types.InputPeerChannel(
channel_id=channel_id,
access_hash=access_hash
)
self.session_storage.peers_by_id[peer_id] = input_peer
if username is not None:
self.session_storage.peers_by_username[username.lower()] = input_peer
def download_worker(self): def download_worker(self):
name = threading.current_thread().name name = threading.current_thread().name
@@ -1261,7 +1207,7 @@ class Client(Methods, BaseClient):
log.warning("get_dialogs flood: waiting {} seconds".format(e.x)) log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
log.info("Total peers: {}".format(len(self.session_storage.peers_by_id))) log.info("Total peers: {}".format(self.session_storage.peers_count()))
return r return r
def get_initial_dialogs(self): def get_initial_dialogs(self):
@@ -1297,7 +1243,7 @@ class Client(Methods, BaseClient):
``KeyError`` in case the peer doesn't exist in the internal database. ``KeyError`` in case the peer doesn't exist in the internal database.
""" """
try: try:
return self.session_storage.peers_by_id[peer_id] return self.session_storage.get_peer_by_id(peer_id)
except KeyError: except KeyError:
if type(peer_id) is str: if type(peer_id) is str:
if peer_id in ("self", "me"): if peer_id in ("self", "me"):
@@ -1308,17 +1254,19 @@ class Client(Methods, BaseClient):
try: try:
int(peer_id) int(peer_id)
except ValueError: except ValueError:
if peer_id not in self.session_storage.peers_by_username: try:
self.session_storage.get_peer_by_username(peer_id)
except KeyError:
self.send( self.send(
functions.contacts.ResolveUsername( functions.contacts.ResolveUsername(
username=peer_id username=peer_id
) )
) )
return self.session_storage.peers_by_username[peer_id] return self.session_storage.get_peer_by_username(peer_id)
else: else:
try: try:
return self.session_storage.peers_by_phone[peer_id] return self.session_storage.get_peer_by_phone(peer_id)
except KeyError: except KeyError:
raise PeerIdInvalid raise PeerIdInvalid
@@ -1345,7 +1293,7 @@ class Client(Methods, BaseClient):
) )
try: try:
return self.session_storage.peers_by_id[peer_id] return self.session_storage.get_peer_by_id(peer_id)
except KeyError: except KeyError:
raise PeerIdInvalid raise PeerIdInvalid

View File

@@ -74,8 +74,8 @@ class BaseClient:
self.rnd_id = MsgId self.rnd_id = MsgId
self.channels_pts = {} self.channels_pts = {}
self.markdown = Markdown(self.session_storage.peers_by_id) self.markdown = Markdown(self.session_storage)
self.html = HTML(self.session_storage.peers_by_id) self.html = HTML(self.session_storage)
self.session = None self.session = None
self.media_sessions = {} self.media_sessions = {}

View File

@@ -44,5 +44,5 @@ class GetContacts(BaseClient):
log.warning("get_contacts flood: waiting {} seconds".format(e.x)) log.warning("get_contacts flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone))) log.info("Total contacts: {}".format(self.session_storage.contacts_count()))
return [pyrogram.User._parse(self, user) for user in contacts.users] return [pyrogram.User._parse(self, user) for user in contacts.users]

View File

@@ -17,9 +17,10 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import abc import abc
from typing import Type from typing import Type, Union
import pyrogram import pyrogram
from pyrogram.api import types
class SessionDoesNotExist(Exception): class SessionDoesNotExist(Exception):
@@ -102,17 +103,41 @@ class SessionStorage(abc.ABC):
def is_bot(self, val): def is_bot(self, val):
... ...
@property
@abc.abstractmethod @abc.abstractmethod
def peers_by_id(self): def clear_cache(self):
... ...
@property
@abc.abstractmethod @abc.abstractmethod
def peers_by_username(self): def cache_peer(self, entity: Union[types.User,
types.Chat, types.ChatForbidden,
types.Channel, types.ChannelForbidden]):
... ...
@property
@abc.abstractmethod @abc.abstractmethod
def peers_by_phone(self): def get_peer_by_id(self, val: int):
...
@abc.abstractmethod
def get_peer_by_username(self, val: str):
...
@abc.abstractmethod
def get_peer_by_phone(self, val: str):
...
def get_peer(self, peer_id: Union[int, str]):
if isinstance(peer_id, int):
return self.get_peer_by_id(peer_id)
else:
peer_id = peer_id.lstrip('+@')
if peer_id.isdigit():
return self.get_peer_by_phone(peer_id)
return self.get_peer_by_username(peer_id)
@abc.abstractmethod
def peers_count(self):
...
@abc.abstractmethod
def contacts_count(self):
... ...

View File

@@ -58,19 +58,19 @@ class JsonSessionStorage(MemorySessionStorage):
self._is_bot = s.get('is_bot', self._is_bot) self._is_bot = s.get('is_bot', self._is_bot)
for k, v in s.get("peers_by_id", {}).items(): for k, v in s.get("peers_by_id", {}).items():
self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v) self._peers_cache['i' + k] = utils.get_input_peer(int(k), v)
for k, v in s.get("peers_by_username", {}).items(): for k, v in s.get("peers_by_username", {}).items():
peer = self._peers_by_id.get(v, None) try:
self._peers_cache['u' + k] = self.get_peer_by_id(v)
if peer: except KeyError:
self._peers_by_username[k] = peer pass
for k, v in s.get("peers_by_phone", {}).items(): for k, v in s.get("peers_by_phone", {}).items():
peer = self._peers_by_id.get(v, None) try:
self._peers_cache['p' + k] = self.get_peer_by_id(v)
if peer: except KeyError:
self._peers_by_phone[k] = peer pass
def save(self, sync=False): def save(self, sync=False):
file_path = self._get_file_name(self._session_name) file_path = self._get_file_name(self._session_name)
@@ -93,16 +93,19 @@ class JsonSessionStorage(MemorySessionStorage):
'date': self._date, 'date': self._date,
'is_bot': self._is_bot, 'is_bot': self._is_bot,
'peers_by_id': { 'peers_by_id': {
k: getattr(v, "access_hash", None) k[1:]: getattr(v, "access_hash", None)
for k, v in self._peers_by_id.copy().items() for k, v in self._peers_cache.copy().items()
if k[0] == 'i'
}, },
'peers_by_username': { 'peers_by_username': {
k: utils.get_peer_id(v) k[1:]: utils.get_peer_id(v)
for k, v in self._peers_by_username.copy().items() for k, v in self._peers_cache.copy().items()
if k[0] == 'u'
}, },
'peers_by_phone': { 'peers_by_phone': {
k: utils.get_peer_id(v) k[1:]: utils.get_peer_id(v)
for k, v in self._peers_by_phone.copy().items() for k, v in self._peers_cache.copy().items()
if k[0] == 'p'
} }
} }

View File

@@ -1,4 +1,5 @@
import pyrogram import pyrogram
from pyrogram.api import types
from . import SessionStorage, SessionDoesNotExist from . import SessionStorage, SessionDoesNotExist
@@ -11,9 +12,7 @@ class MemorySessionStorage(SessionStorage):
self._user_id = None self._user_id = None
self._date = 0 self._date = 0
self._is_bot = False self._is_bot = False
self._peers_by_id = {} self._peers_cache = {}
self._peers_by_username = {}
self._peers_by_phone = {}
def load(self): def load(self):
raise SessionDoesNotExist() raise SessionDoesNotExist()
@@ -72,14 +71,48 @@ class MemorySessionStorage(SessionStorage):
def is_bot(self, val): def is_bot(self, val):
self._is_bot = val self._is_bot = val
@property def clear_cache(self):
def peers_by_id(self): keys = list(filter(lambda k: k[0] in 'up', self._peers_cache.keys()))
return self._peers_by_id for key in keys:
try:
del self._peers_cache[key]
except KeyError:
pass
@property def cache_peer(self, entity):
def peers_by_username(self): if isinstance(entity, types.User):
return self._peers_by_username input_peer = types.InputPeerUser(
user_id=entity.id,
access_hash=entity.access_hash
)
self._peers_cache['i' + str(entity.id)] = input_peer
if entity.username:
self._peers_cache['u' + entity.username.lower()] = input_peer
if entity.phone:
self._peers_cache['p' + entity.phone] = input_peer
elif isinstance(entity, (types.Chat, types.ChatForbidden)):
self._peers_cache['i-' + str(entity.id)] = types.InputPeerChat(chat_id=entity.id)
elif isinstance(entity, (types.Channel, types.ChannelForbidden)):
input_peer = types.InputPeerChannel(
channel_id=entity.id,
access_hash=entity.access_hash
)
self._peers_cache['i-100' + str(entity.id)] = input_peer
username = getattr(entity, "username", None)
if username:
self._peers_cache['u' + username.lower()] = input_peer
@property def get_peer_by_id(self, val):
def peers_by_phone(self): return self._peers_cache['i' + str(val)]
return self._peers_by_phone
def get_peer_by_username(self, val):
return self._peers_cache['u' + val.lower()]
def get_peer_by_phone(self, val):
return self._peers_cache['p' + val]
def peers_count(self):
return len(list(filter(lambda k: k[0] == 'i', self._peers_cache.keys())))
def contacts_count(self):
return len(list(filter(lambda k: k[0] == 'p', self._peers_cache.keys())))

View File

@@ -29,14 +29,15 @@ from pyrogram.api.types import (
InputMessageEntityMentionName as Mention, InputMessageEntityMentionName as Mention,
) )
from . import utils from . import utils
from ..session_storage import SessionStorage
class HTML: class HTML:
HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>") HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>")
MENTION_RE = re.compile(r"tg://user\?id=(\d+)") MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
def __init__(self, peers_by_id): def __init__(self, session_storage: SessionStorage):
self.peers_by_id = peers_by_id self.session_storage = session_storage
def parse(self, message: str): def parse(self, message: str):
entities = [] entities = []
@@ -52,7 +53,10 @@ class HTML:
if mention: if mention:
user_id = int(mention.group(1)) user_id = int(mention.group(1))
input_user = self.peers_by_id.get(user_id, None) try:
input_user = self.session_storage.get_peer_by_id(user_id)
except KeyError:
input_user = None
entity = ( entity = (
Mention(start, len(body), input_user) Mention(start, len(body), input_user)

View File

@@ -29,6 +29,7 @@ from pyrogram.api.types import (
InputMessageEntityMentionName as Mention InputMessageEntityMentionName as Mention
) )
from . import utils from . import utils
from ..session_storage import SessionStorage
class Markdown: class Markdown:
@@ -52,8 +53,8 @@ class Markdown:
)) ))
MENTION_RE = re.compile(r"tg://user\?id=(\d+)") MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
def __init__(self, peers_by_id: dict): def __init__(self, session_storage: SessionStorage):
self.peers_by_id = peers_by_id self.session_storage = session_storage
def parse(self, message: str): def parse(self, message: str):
message = utils.add_surrogates(str(message)).strip() message = utils.add_surrogates(str(message)).strip()
@@ -69,7 +70,10 @@ class Markdown:
if mention: if mention:
user_id = int(mention.group(1)) user_id = int(mention.group(1))
input_user = self.peers_by_id.get(user_id, None) try:
input_user = self.session_storage.get_peer_by_id(user_id)
except KeyError:
input_user = None
entity = ( entity = (
Mention(start, len(text), input_user) Mention(start, len(text), input_user)