mirror of
https://github.com/pyrogram/pyrogram
synced 2025-09-01 14:55:12 +00:00
Unify peers cache
This commit is contained in:
@@ -324,8 +324,7 @@ class Client(Methods, BaseClient):
|
|||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
|
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
|
||||||
self.session_storage.peers_by_username.clear()
|
self.session_storage.clear_cache()
|
||||||
self.session_storage.peers_by_phone.clear()
|
|
||||||
|
|
||||||
self.get_initial_dialogs()
|
self.get_initial_dialogs()
|
||||||
self.get_contacts()
|
self.get_contacts()
|
||||||
@@ -763,60 +762,7 @@ class Client(Methods, BaseClient):
|
|||||||
types.Chat, types.ChatForbidden,
|
types.Chat, types.ChatForbidden,
|
||||||
types.Channel, types.ChannelForbidden]]):
|
types.Channel, types.ChannelForbidden]]):
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if isinstance(entity, types.User):
|
self.session_storage.cache_peer(entity)
|
||||||
user_id = entity.id
|
|
||||||
|
|
||||||
access_hash = entity.access_hash
|
|
||||||
|
|
||||||
if access_hash is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
username = entity.username
|
|
||||||
phone = entity.phone
|
|
||||||
|
|
||||||
input_peer = types.InputPeerUser(
|
|
||||||
user_id=user_id,
|
|
||||||
access_hash=access_hash
|
|
||||||
)
|
|
||||||
|
|
||||||
self.session_storage.peers_by_id[user_id] = input_peer
|
|
||||||
|
|
||||||
if username is not None:
|
|
||||||
self.session_storage.peers_by_username[username.lower()] = input_peer
|
|
||||||
|
|
||||||
if phone is not None:
|
|
||||||
self.session_storage.peers_by_phone[phone] = input_peer
|
|
||||||
|
|
||||||
if isinstance(entity, (types.Chat, types.ChatForbidden)):
|
|
||||||
chat_id = entity.id
|
|
||||||
peer_id = -chat_id
|
|
||||||
|
|
||||||
input_peer = types.InputPeerChat(
|
|
||||||
chat_id=chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.session_storage.peers_by_id[peer_id] = input_peer
|
|
||||||
|
|
||||||
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
|
|
||||||
channel_id = entity.id
|
|
||||||
peer_id = int("-100" + str(channel_id))
|
|
||||||
|
|
||||||
access_hash = entity.access_hash
|
|
||||||
|
|
||||||
if access_hash is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
username = getattr(entity, "username", None)
|
|
||||||
|
|
||||||
input_peer = types.InputPeerChannel(
|
|
||||||
channel_id=channel_id,
|
|
||||||
access_hash=access_hash
|
|
||||||
)
|
|
||||||
|
|
||||||
self.session_storage.peers_by_id[peer_id] = input_peer
|
|
||||||
|
|
||||||
if username is not None:
|
|
||||||
self.session_storage.peers_by_username[username.lower()] = input_peer
|
|
||||||
|
|
||||||
def download_worker(self):
|
def download_worker(self):
|
||||||
name = threading.current_thread().name
|
name = threading.current_thread().name
|
||||||
@@ -1261,7 +1207,7 @@ class Client(Methods, BaseClient):
|
|||||||
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
|
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
|
||||||
time.sleep(e.x)
|
time.sleep(e.x)
|
||||||
else:
|
else:
|
||||||
log.info("Total peers: {}".format(len(self.session_storage.peers_by_id)))
|
log.info("Total peers: {}".format(self.session_storage.peers_count()))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def get_initial_dialogs(self):
|
def get_initial_dialogs(self):
|
||||||
@@ -1297,7 +1243,7 @@ class Client(Methods, BaseClient):
|
|||||||
``KeyError`` in case the peer doesn't exist in the internal database.
|
``KeyError`` in case the peer doesn't exist in the internal database.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self.session_storage.peers_by_id[peer_id]
|
return self.session_storage.get_peer_by_id(peer_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if type(peer_id) is str:
|
if type(peer_id) is str:
|
||||||
if peer_id in ("self", "me"):
|
if peer_id in ("self", "me"):
|
||||||
@@ -1308,17 +1254,19 @@ class Client(Methods, BaseClient):
|
|||||||
try:
|
try:
|
||||||
int(peer_id)
|
int(peer_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
if peer_id not in self.session_storage.peers_by_username:
|
try:
|
||||||
|
self.session_storage.get_peer_by_username(peer_id)
|
||||||
|
except KeyError:
|
||||||
self.send(
|
self.send(
|
||||||
functions.contacts.ResolveUsername(
|
functions.contacts.ResolveUsername(
|
||||||
username=peer_id
|
username=peer_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.session_storage.peers_by_username[peer_id]
|
return self.session_storage.get_peer_by_username(peer_id)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return self.session_storage.peers_by_phone[peer_id]
|
return self.session_storage.get_peer_by_phone(peer_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise PeerIdInvalid
|
raise PeerIdInvalid
|
||||||
|
|
||||||
@@ -1345,7 +1293,7 @@ class Client(Methods, BaseClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.session_storage.peers_by_id[peer_id]
|
return self.session_storage.get_peer_by_id(peer_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise PeerIdInvalid
|
raise PeerIdInvalid
|
||||||
|
|
||||||
|
@@ -74,8 +74,8 @@ class BaseClient:
|
|||||||
self.rnd_id = MsgId
|
self.rnd_id = MsgId
|
||||||
self.channels_pts = {}
|
self.channels_pts = {}
|
||||||
|
|
||||||
self.markdown = Markdown(self.session_storage.peers_by_id)
|
self.markdown = Markdown(self.session_storage)
|
||||||
self.html = HTML(self.session_storage.peers_by_id)
|
self.html = HTML(self.session_storage)
|
||||||
|
|
||||||
self.session = None
|
self.session = None
|
||||||
self.media_sessions = {}
|
self.media_sessions = {}
|
||||||
|
@@ -44,5 +44,5 @@ class GetContacts(BaseClient):
|
|||||||
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
|
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
|
||||||
time.sleep(e.x)
|
time.sleep(e.x)
|
||||||
else:
|
else:
|
||||||
log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone)))
|
log.info("Total contacts: {}".format(self.session_storage.contacts_count()))
|
||||||
return [pyrogram.User._parse(self, user) for user in contacts.users]
|
return [pyrogram.User._parse(self, user) for user in contacts.users]
|
||||||
|
@@ -17,9 +17,10 @@
|
|||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from typing import Type
|
from typing import Type, Union
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
|
from pyrogram.api import types
|
||||||
|
|
||||||
|
|
||||||
class SessionDoesNotExist(Exception):
|
class SessionDoesNotExist(Exception):
|
||||||
@@ -102,17 +103,41 @@ class SessionStorage(abc.ABC):
|
|||||||
def is_bot(self, val):
|
def is_bot(self, val):
|
||||||
...
|
...
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def peers_by_id(self):
|
def clear_cache(self):
|
||||||
...
|
...
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def peers_by_username(self):
|
def cache_peer(self, entity: Union[types.User,
|
||||||
|
types.Chat, types.ChatForbidden,
|
||||||
|
types.Channel, types.ChannelForbidden]):
|
||||||
...
|
...
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def peers_by_phone(self):
|
def get_peer_by_id(self, val: int):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_peer_by_username(self, val: str):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_peer_by_phone(self, val: str):
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_peer(self, peer_id: Union[int, str]):
|
||||||
|
if isinstance(peer_id, int):
|
||||||
|
return self.get_peer_by_id(peer_id)
|
||||||
|
else:
|
||||||
|
peer_id = peer_id.lstrip('+@')
|
||||||
|
if peer_id.isdigit():
|
||||||
|
return self.get_peer_by_phone(peer_id)
|
||||||
|
return self.get_peer_by_username(peer_id)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def peers_count(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def contacts_count(self):
|
||||||
...
|
...
|
||||||
|
@@ -58,19 +58,19 @@ class JsonSessionStorage(MemorySessionStorage):
|
|||||||
self._is_bot = s.get('is_bot', self._is_bot)
|
self._is_bot = s.get('is_bot', self._is_bot)
|
||||||
|
|
||||||
for k, v in s.get("peers_by_id", {}).items():
|
for k, v in s.get("peers_by_id", {}).items():
|
||||||
self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
self._peers_cache['i' + k] = utils.get_input_peer(int(k), v)
|
||||||
|
|
||||||
for k, v in s.get("peers_by_username", {}).items():
|
for k, v in s.get("peers_by_username", {}).items():
|
||||||
peer = self._peers_by_id.get(v, None)
|
try:
|
||||||
|
self._peers_cache['u' + k] = self.get_peer_by_id(v)
|
||||||
if peer:
|
except KeyError:
|
||||||
self._peers_by_username[k] = peer
|
pass
|
||||||
|
|
||||||
for k, v in s.get("peers_by_phone", {}).items():
|
for k, v in s.get("peers_by_phone", {}).items():
|
||||||
peer = self._peers_by_id.get(v, None)
|
try:
|
||||||
|
self._peers_cache['p' + k] = self.get_peer_by_id(v)
|
||||||
if peer:
|
except KeyError:
|
||||||
self._peers_by_phone[k] = peer
|
pass
|
||||||
|
|
||||||
def save(self, sync=False):
|
def save(self, sync=False):
|
||||||
file_path = self._get_file_name(self._session_name)
|
file_path = self._get_file_name(self._session_name)
|
||||||
@@ -93,16 +93,19 @@ class JsonSessionStorage(MemorySessionStorage):
|
|||||||
'date': self._date,
|
'date': self._date,
|
||||||
'is_bot': self._is_bot,
|
'is_bot': self._is_bot,
|
||||||
'peers_by_id': {
|
'peers_by_id': {
|
||||||
k: getattr(v, "access_hash", None)
|
k[1:]: getattr(v, "access_hash", None)
|
||||||
for k, v in self._peers_by_id.copy().items()
|
for k, v in self._peers_cache.copy().items()
|
||||||
|
if k[0] == 'i'
|
||||||
},
|
},
|
||||||
'peers_by_username': {
|
'peers_by_username': {
|
||||||
k: utils.get_peer_id(v)
|
k[1:]: utils.get_peer_id(v)
|
||||||
for k, v in self._peers_by_username.copy().items()
|
for k, v in self._peers_cache.copy().items()
|
||||||
|
if k[0] == 'u'
|
||||||
},
|
},
|
||||||
'peers_by_phone': {
|
'peers_by_phone': {
|
||||||
k: utils.get_peer_id(v)
|
k[1:]: utils.get_peer_id(v)
|
||||||
for k, v in self._peers_by_phone.copy().items()
|
for k, v in self._peers_cache.copy().items()
|
||||||
|
if k[0] == 'p'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import pyrogram
|
import pyrogram
|
||||||
|
from pyrogram.api import types
|
||||||
from . import SessionStorage, SessionDoesNotExist
|
from . import SessionStorage, SessionDoesNotExist
|
||||||
|
|
||||||
|
|
||||||
@@ -11,9 +12,7 @@ class MemorySessionStorage(SessionStorage):
|
|||||||
self._user_id = None
|
self._user_id = None
|
||||||
self._date = 0
|
self._date = 0
|
||||||
self._is_bot = False
|
self._is_bot = False
|
||||||
self._peers_by_id = {}
|
self._peers_cache = {}
|
||||||
self._peers_by_username = {}
|
|
||||||
self._peers_by_phone = {}
|
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
raise SessionDoesNotExist()
|
raise SessionDoesNotExist()
|
||||||
@@ -72,14 +71,48 @@ class MemorySessionStorage(SessionStorage):
|
|||||||
def is_bot(self, val):
|
def is_bot(self, val):
|
||||||
self._is_bot = val
|
self._is_bot = val
|
||||||
|
|
||||||
@property
|
def clear_cache(self):
|
||||||
def peers_by_id(self):
|
keys = list(filter(lambda k: k[0] in 'up', self._peers_cache.keys()))
|
||||||
return self._peers_by_id
|
for key in keys:
|
||||||
|
try:
|
||||||
|
del self._peers_cache[key]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
def cache_peer(self, entity):
|
||||||
def peers_by_username(self):
|
if isinstance(entity, types.User):
|
||||||
return self._peers_by_username
|
input_peer = types.InputPeerUser(
|
||||||
|
user_id=entity.id,
|
||||||
|
access_hash=entity.access_hash
|
||||||
|
)
|
||||||
|
self._peers_cache['i' + str(entity.id)] = input_peer
|
||||||
|
if entity.username:
|
||||||
|
self._peers_cache['u' + entity.username.lower()] = input_peer
|
||||||
|
if entity.phone:
|
||||||
|
self._peers_cache['p' + entity.phone] = input_peer
|
||||||
|
elif isinstance(entity, (types.Chat, types.ChatForbidden)):
|
||||||
|
self._peers_cache['i-' + str(entity.id)] = types.InputPeerChat(chat_id=entity.id)
|
||||||
|
elif isinstance(entity, (types.Channel, types.ChannelForbidden)):
|
||||||
|
input_peer = types.InputPeerChannel(
|
||||||
|
channel_id=entity.id,
|
||||||
|
access_hash=entity.access_hash
|
||||||
|
)
|
||||||
|
self._peers_cache['i-100' + str(entity.id)] = input_peer
|
||||||
|
username = getattr(entity, "username", None)
|
||||||
|
if username:
|
||||||
|
self._peers_cache['u' + username.lower()] = input_peer
|
||||||
|
|
||||||
@property
|
def get_peer_by_id(self, val):
|
||||||
def peers_by_phone(self):
|
return self._peers_cache['i' + str(val)]
|
||||||
return self._peers_by_phone
|
|
||||||
|
def get_peer_by_username(self, val):
|
||||||
|
return self._peers_cache['u' + val.lower()]
|
||||||
|
|
||||||
|
def get_peer_by_phone(self, val):
|
||||||
|
return self._peers_cache['p' + val]
|
||||||
|
|
||||||
|
def peers_count(self):
|
||||||
|
return len(list(filter(lambda k: k[0] == 'i', self._peers_cache.keys())))
|
||||||
|
|
||||||
|
def contacts_count(self):
|
||||||
|
return len(list(filter(lambda k: k[0] == 'p', self._peers_cache.keys())))
|
||||||
|
@@ -29,14 +29,15 @@ from pyrogram.api.types import (
|
|||||||
InputMessageEntityMentionName as Mention,
|
InputMessageEntityMentionName as Mention,
|
||||||
)
|
)
|
||||||
from . import utils
|
from . import utils
|
||||||
|
from ..session_storage import SessionStorage
|
||||||
|
|
||||||
|
|
||||||
class HTML:
|
class HTML:
|
||||||
HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>")
|
HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>")
|
||||||
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
||||||
|
|
||||||
def __init__(self, peers_by_id):
|
def __init__(self, session_storage: SessionStorage):
|
||||||
self.peers_by_id = peers_by_id
|
self.session_storage = session_storage
|
||||||
|
|
||||||
def parse(self, message: str):
|
def parse(self, message: str):
|
||||||
entities = []
|
entities = []
|
||||||
@@ -52,7 +53,10 @@ class HTML:
|
|||||||
|
|
||||||
if mention:
|
if mention:
|
||||||
user_id = int(mention.group(1))
|
user_id = int(mention.group(1))
|
||||||
input_user = self.peers_by_id.get(user_id, None)
|
try:
|
||||||
|
input_user = self.session_storage.get_peer_by_id(user_id)
|
||||||
|
except KeyError:
|
||||||
|
input_user = None
|
||||||
|
|
||||||
entity = (
|
entity = (
|
||||||
Mention(start, len(body), input_user)
|
Mention(start, len(body), input_user)
|
||||||
|
@@ -29,6 +29,7 @@ from pyrogram.api.types import (
|
|||||||
InputMessageEntityMentionName as Mention
|
InputMessageEntityMentionName as Mention
|
||||||
)
|
)
|
||||||
from . import utils
|
from . import utils
|
||||||
|
from ..session_storage import SessionStorage
|
||||||
|
|
||||||
|
|
||||||
class Markdown:
|
class Markdown:
|
||||||
@@ -52,8 +53,8 @@ class Markdown:
|
|||||||
))
|
))
|
||||||
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
||||||
|
|
||||||
def __init__(self, peers_by_id: dict):
|
def __init__(self, session_storage: SessionStorage):
|
||||||
self.peers_by_id = peers_by_id
|
self.session_storage = session_storage
|
||||||
|
|
||||||
def parse(self, message: str):
|
def parse(self, message: str):
|
||||||
message = utils.add_surrogates(str(message)).strip()
|
message = utils.add_surrogates(str(message)).strip()
|
||||||
@@ -69,7 +70,10 @@ class Markdown:
|
|||||||
|
|
||||||
if mention:
|
if mention:
|
||||||
user_id = int(mention.group(1))
|
user_id = int(mention.group(1))
|
||||||
input_user = self.peers_by_id.get(user_id, None)
|
try:
|
||||||
|
input_user = self.session_storage.get_peer_by_id(user_id)
|
||||||
|
except KeyError:
|
||||||
|
input_user = None
|
||||||
|
|
||||||
entity = (
|
entity = (
|
||||||
Mention(start, len(text), input_user)
|
Mention(start, len(text), input_user)
|
||||||
|
Reference in New Issue
Block a user