2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Update pyrogram/client to accommodate Storage Engines

This commit is contained in:
Dan 2019-06-19 16:10:37 +02:00
parent edaced35a7
commit 30192de1ad
4 changed files with 141 additions and 114 deletions

View File

@ -16,7 +16,6 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import binascii
import logging import logging
import math import math
import mimetypes import mimetypes
@ -51,10 +50,7 @@ from pyrogram.errors import (
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
from .ext import utils, Syncer, BaseClient, Dispatcher from .ext import utils, Syncer, BaseClient, Dispatcher
from .methods import Methods from .methods import Methods
from .session_storage import ( from .storage import Storage, FileStorage, MemoryStorage
SessionDoesNotExist, SessionStorage, MemorySessionStorage, JsonSessionStorage,
StringSessionStorage, SQLiteSessionStorage
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -64,8 +60,13 @@ class Client(Methods, BaseClient):
Parameters: Parameters:
session_name (``str``): session_name (``str``):
Name to uniquely identify a session of either a User or a Bot, e.g.: "my_account". This name will be used Pass a string of your choice to give a name to the client session, e.g.: "*my_account*". This name will be
to save a file to disk that stores details needed for reconnecting without asking again for credentials. used to save a file on disk that stores details needed to reconnect without asking again for credentials.
Alternatively, if you don't want a file to be saved on disk, pass the special name "**:memory:**" to start
an in-memory session that will be discarded as soon as you stop the Client. In order to reconnect again
using a memory storage without having to login again, you can use
:meth:`~pyrogram.Client.export_session_string` before stopping the client to get a session string you can
pass here as argument.
api_id (``int``, *optional*): api_id (``int``, *optional*):
The *api_id* part of your Telegram API Key, as integer. E.g.: 12345 The *api_id* part of your Telegram API Key, as integer. E.g.: 12345
@ -179,7 +180,7 @@ class Client(Methods, BaseClient):
def __init__( def __init__(
self, self,
session_name: str, session_name: Union[str, Storage],
api_id: Union[int, str] = None, api_id: Union[int, str] = None,
api_hash: str = None, api_hash: str = None,
app_version: str = None, app_version: str = None,
@ -226,12 +227,23 @@ class Client(Methods, BaseClient):
self.first_name = first_name self.first_name = first_name
self.last_name = last_name self.last_name = last_name
self.workers = workers self.workers = workers
self.workdir = workdir self.workdir = Path(workdir)
self.config_file = config_file self.config_file = Path(config_file)
self.plugins = plugins self.plugins = plugins
self.no_updates = no_updates self.no_updates = no_updates
self.takeout = takeout self.takeout = takeout
if isinstance(session_name, str):
if session_name == ":memory:" or len(session_name) >= MemoryStorage.SESSION_STRING_SIZE:
session_name = re.sub(r"[\n\s]+", "", session_name)
self.storage = MemoryStorage(session_name)
else:
self.storage = FileStorage(session_name, self.workdir)
elif isinstance(session_name, Storage):
self.storage = session_name
else:
raise ValueError("Unknown storage engine")
self.dispatcher = Dispatcher(self, workers) self.dispatcher = Dispatcher(self, workers)
def __enter__(self): def __enter__(self):
@ -266,50 +278,32 @@ class Client(Methods, BaseClient):
if self.is_started: if self.is_started:
raise ConnectionError("Client has already been started") raise ConnectionError("Client has already been started")
if isinstance(self.session_storage, JsonSessionStorage):
if self.BOT_TOKEN_RE.match(self.session_storage._session_name):
self.session_storage.is_bot = True
self.bot_token = self.session_storage._session_name
self.session_storage._session_name = self.session_storage._session_name.split(":")[0]
warnings.warn('\nWARNING: You are using a bot token as session name!\n'
'This usage will be deprecated soon. Please use a session file name to load '
'an existing session and the bot_token argument to create new sessions.\n'
'More info: https://docs.pyrogram.org/intro/auth#bot-authorization\n')
self.load_config() self.load_config()
self.load_session() self.load_session()
self.load_plugins() self.load_plugins()
self.session = Session( self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
self,
self.session_storage.dc_id,
self.session_storage.auth_key
)
self.session.start() self.session.start()
self.is_started = True self.is_started = True
try: try:
if self.session_storage.user_id is None: if self.storage.user_id is None:
if self.bot_token is None: if self.bot_token is None:
self.is_bot = False self.storage.is_bot = False
self.authorize_user() self.authorize_user()
else: else:
self.session_storage.is_bot = True self.storage.is_bot = True
self.authorize_bot() self.authorize_bot()
self.save_session() if not self.storage.is_bot:
if not self.session_storage.is_bot:
if self.takeout: if self.takeout:
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
log.warning("Takeout session {} initiated".format(self.takeout_id)) log.warning("Takeout session {} initiated".format(self.takeout_id))
now = time.time() now = time.time()
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP: if abs(now - self.storage.date) > Client.OFFLINE_SLEEP:
self.session_storage.clear_cache()
self.get_initial_dialogs() self.get_initial_dialogs()
self.get_contacts() self.get_contacts()
else: else:
@ -508,20 +502,15 @@ class Client(Methods, BaseClient):
except UserMigrate as e: except UserMigrate as e:
self.session.stop() self.session.stop()
self.session_storage.dc_id = e.x self.storage.dc_id = e.x
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode, self.storage.auth_key = Auth(self, self.storage.dc_id).create()
self.ipv6, self._proxy).create() self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
self.session = Session(
self,
self.session_storage.dc_id,
self.session_storage.auth_key
)
self.session.start() self.session.start()
self.authorize_bot() self.authorize_bot()
else: else:
self.session_storage.user_id = r.user.id self.storage.user_id = r.user.id
print("Logged in successfully as @{}".format(r.user.username)) print("Logged in successfully as @{}".format(r.user.username))
@ -562,20 +551,10 @@ class Client(Methods, BaseClient):
except (PhoneMigrate, NetworkMigrate) as e: except (PhoneMigrate, NetworkMigrate) as e:
self.session.stop() self.session.stop()
self.session_storage.dc_id = e.x self.storage.dc_id = e.x
self.storage.auth_key = Auth(self, self.storage.dc_id).create()
self.session_storage.auth_key = Auth( self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
self.session_storage.dc_id,
self.session_storage.test_mode,
self.ipv6,
self._proxy
).create()
self.session = Session(
self,
self.session_storage.dc_id,
self.session_storage.auth_key
)
self.session.start() self.session.start()
except (PhoneNumberInvalid, PhoneNumberBanned) as e: except (PhoneNumberInvalid, PhoneNumberBanned) as e:
@ -755,13 +734,13 @@ class Client(Methods, BaseClient):
) )
self.password = None self.password = None
self.session_storage.user_id = r.user.id self.storage.user_id = r.user.id
print("Logged in successfully as {}".format(r.user.first_name)) print("Logged in successfully as {}".format(r.user.first_name))
def fetch_peers( def fetch_peers(
self, self,
entities: List[ peers: List[
Union[ Union[
types.User, types.User,
types.Chat, types.ChatForbidden, types.Chat, types.ChatForbidden,
@ -770,11 +749,57 @@ class Client(Methods, BaseClient):
] ]
) -> bool: ) -> bool:
is_min = False is_min = False
parsed_peers = []
for entity in entities: for peer in peers:
if isinstance(entity, (types.User, types.Channel, types.ChannelForbidden)) and not entity.access_hash: username = None
phone_number = None
if isinstance(peer, types.User):
peer_id = peer.id
access_hash = peer.access_hash
username = peer.username
phone_number = peer.phone
if peer.bot:
peer_type = "bot"
else:
peer_type = "user"
if access_hash is None:
is_min = True
continue
if username is not None:
username = username.lower()
elif isinstance(peer, (types.Chat, types.ChatForbidden)):
peer_id = -peer.id
access_hash = 0
peer_type = "group"
elif isinstance(peer, (types.Channel, types.ChannelForbidden)):
peer_id = int("-100" + str(peer.id))
access_hash = peer.access_hash
username = getattr(peer, "username", None)
if peer.broadcast:
peer_type = "channel"
else:
peer_type = "supergroup"
if access_hash is None:
is_min = True
continue
if username is not None:
username = username.lower()
else:
continue continue
self.session_storage.cache_peer(entity)
parsed_peers.append((peer_id, access_hash, peer_type, username, phone_number))
self.storage.update_peers(parsed_peers)
return is_min return is_min
@ -1035,12 +1060,23 @@ class Client(Methods, BaseClient):
self.plugins = None self.plugins = None
def load_session(self): def load_session(self):
try: self.storage.open()
self.session_storage.load()
except SessionDoesNotExist: session_empty = any([
log.info('Could not load session "{}", initiate new one'.format(self.session_name)) self.storage.test_mode is None,
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode, self.storage.auth_key is None,
self.ipv6, self._proxy).create() self.storage.user_id is None,
self.storage.is_bot is None
])
if session_empty:
self.storage.dc_id = 1
self.storage.date = 0
self.storage.test_mode = self.test_mode
self.storage.auth_key = Auth(self, self.storage.dc_id).create()
self.storage.user_id = None
self.storage.is_bot = None
def load_plugins(self): def load_plugins(self):
if self.plugins: if self.plugins:
@ -1164,9 +1200,6 @@ class Client(Methods, BaseClient):
log.warning('[{}] No plugin loaded from "{}"'.format( log.warning('[{}] No plugin loaded from "{}"'.format(
self.session_name, root)) self.session_name, root))
def save_session(self):
self.session_storage.save()
def get_initial_dialogs_chunk(self, offset_date: int = 0): def get_initial_dialogs_chunk(self, offset_date: int = 0):
while True: while True:
try: try:
@ -1184,7 +1217,7 @@ class Client(Methods, BaseClient):
log.warning("get_dialogs flood: waiting {} seconds".format(e.x)) log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
log.info("Total peers: {}".format(self.session_storage.peers_count())) log.info("Total peers: {}".format(self.storage.peers_count))
return r return r
def get_initial_dialogs(self): def get_initial_dialogs(self):
@ -1222,7 +1255,7 @@ class Client(Methods, BaseClient):
KeyError: In case the peer doesn't exist in the internal database. KeyError: In case the peer doesn't exist in the internal database.
""" """
try: try:
return self.session_storage.get_peer_by_id(peer_id) return self.storage.get_peer_by_id(peer_id)
except KeyError: except KeyError:
if type(peer_id) is str: if type(peer_id) is str:
if peer_id in ("self", "me"): if peer_id in ("self", "me"):
@ -1234,7 +1267,7 @@ class Client(Methods, BaseClient):
int(peer_id) int(peer_id)
except ValueError: except ValueError:
try: try:
self.session_storage.get_peer_by_username(peer_id) return self.storage.get_peer_by_username(peer_id)
except KeyError: except KeyError:
self.send( self.send(
functions.contacts.ResolveUsername( functions.contacts.ResolveUsername(
@ -1242,10 +1275,10 @@ class Client(Methods, BaseClient):
) )
) )
return self.session_storage.get_peer_by_username(peer_id) return self.storage.get_peer_by_username(peer_id)
else: else:
try: try:
return self.session_storage.get_peer_by_phone(peer_id) return self.storage.get_peer_by_phone_number(peer_id)
except KeyError: except KeyError:
raise PeerIdInvalid raise PeerIdInvalid
@ -1253,7 +1286,10 @@ class Client(Methods, BaseClient):
self.fetch_peers( self.fetch_peers(
self.send( self.send(
functions.users.GetUsers( functions.users.GetUsers(
id=[types.InputUser(user_id=peer_id, access_hash=0)] id=[types.InputUser(
user_id=peer_id,
access_hash=0
)]
) )
) )
) )
@ -1261,7 +1297,10 @@ class Client(Methods, BaseClient):
if str(peer_id).startswith("-100"): if str(peer_id).startswith("-100"):
self.send( self.send(
functions.channels.GetChannels( functions.channels.GetChannels(
id=[types.InputChannel(channel_id=int(str(peer_id)[4:]), access_hash=0)] id=[types.InputChannel(
channel_id=int(str(peer_id)[4:]),
access_hash=0
)]
) )
) )
else: else:
@ -1272,7 +1311,7 @@ class Client(Methods, BaseClient):
) )
try: try:
return self.session_storage.get_peer_by_id(peer_id) return self.storage.get_peer_by_id(peer_id)
except KeyError: except KeyError:
raise PeerIdInvalid raise PeerIdInvalid
@ -1347,7 +1386,7 @@ class Client(Methods, BaseClient):
file_id = file_id or self.rnd_id() file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(self, self.session_storage.dc_id, self.session_storage.auth_key, is_media=True) session = Session(self, self.storage.dc_id, self.storage.auth_key, is_media=True)
session.start() session.start()
try: try:
@ -1433,19 +1472,14 @@ class Client(Methods, BaseClient):
session = self.media_sessions.get(dc_id, None) session = self.media_sessions.get(dc_id, None)
if session is None: if session is None:
if dc_id != self.session_storage.dc_id: if dc_id != self.storage.dc_id:
exported_auth = self.send( exported_auth = self.send(
functions.auth.ExportAuthorization( functions.auth.ExportAuthorization(
dc_id=dc_id dc_id=dc_id
) )
) )
session = Session( session = Session(self, dc_id, Auth(self, dc_id).create(), is_media=True)
self,
dc_id,
Auth(dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
is_media=True
)
session.start() session.start()
@ -1458,12 +1492,7 @@ class Client(Methods, BaseClient):
) )
) )
else: else:
session = Session( session = Session(self, dc_id, self.storage.auth_key, is_media=True)
self,
dc_id,
self.session_storage.auth_key,
is_media=True
)
session.start() session.start()
@ -1548,13 +1577,7 @@ class Client(Methods, BaseClient):
cdn_session = self.media_sessions.get(r.dc_id, None) cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None: if cdn_session is None:
cdn_session = Session( cdn_session = Session(self, r.dc_id, Auth(self, r.dc_id).create(), is_media=True, is_cdn=True)
self,
r.dc_id,
Auth(r.dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
is_media=True,
is_cdn=True
)
cdn_session.start() cdn_session.start()
@ -1650,3 +1673,11 @@ class Client(Methods, BaseClient):
if extensions: if extensions:
return extensions.split(" ")[0] return extensions.split(" ")[0]
def export_session_string(self):
"""Export the current session as serialized string.
Returns:
``str``: The session serialized into a printable, url-safe string.
"""
return self.storage.export_session_string()

View File

@ -87,13 +87,13 @@ class BaseClient:
mime_types_to_extensions[mime_type] = " ".join(extensions) mime_types_to_extensions[mime_type] = " ".join(extensions)
def __init__(self, session_storage: SessionStorage): def __init__(self):
self.session_storage = session_storage self.storage = None
self.rnd_id = MsgId self.rnd_id = MsgId
self.markdown = Markdown(self.session_storage, self) self.markdown = Markdown(self)
self.html = HTML(self.session_storage, self) self.html = HTML(self)
self.session = None self.session = None
self.media_sessions = {} self.media_sessions = {}

View File

@ -16,16 +16,10 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import base64
import json
import logging import logging
import os
import shutil
import time import time
from threading import Thread, Event, Lock from threading import Thread, Event, Lock
from . import utils
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -81,10 +75,13 @@ class Syncer:
@classmethod @classmethod
def sync(cls, client): def sync(cls, client):
client.session_storage.date = int(time.time())
try: try:
client.session_storage.save(sync=True) start = time.time()
client.storage.save()
except Exception as e: except Exception as e:
log.critical(e, exc_info=True) log.critical(e, exc_info=True)
else: else:
log.info("Synced {}".format(client.session_name)) log.info('Synced "{}" in {:.6} ms'.format(
client.storage.name,
(time.time() - start) * 1000
))

View File

@ -46,5 +46,4 @@ class GetContacts(BaseClient):
log.warning("get_contacts flood: waiting {} seconds".format(e.x)) log.warning("get_contacts flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
log.info("Total contacts: {}".format(self.session_storage.contacts_count()))
return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users) return pyrogram.List(pyrogram.User._parse(self, user) for user in contacts.users)