From fd732add7062409d3efc2f0f30666a0e31f61d2c Mon Sep 17 00:00:00 2001 From: bakatrouble Date: Fri, 22 Feb 2019 00:03:58 +0300 Subject: [PATCH] Refactor session storages: use session_name arg to detect storage type --- pyrogram/client/client.py | 34 +++++++----- pyrogram/client/ext/syncer.py | 4 +- pyrogram/client/session_storage/__init__.py | 4 +- .../session_storage/base_session_storage.py | 17 ++++-- .../session_storage/json_session_storage.py | 14 ++--- .../session_storage/string_session_storage.py | 53 +++++++++---------- 6 files changed, 71 insertions(+), 55 deletions(-) diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 0e8d5554..f17a054b 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -49,12 +49,15 @@ from pyrogram.api.errors import ( from pyrogram.client.handlers import DisconnectHandler from pyrogram.client.handlers.handler import Handler from pyrogram.client.methods.password.utils import compute_check +from pyrogram.client.session_storage import BaseSessionConfig from pyrogram.crypto import AES from pyrogram.session import Auth, Session from .dispatcher import Dispatcher from .ext import utils, Syncer, BaseClient from .methods import Methods -from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist +from .session_storage import SessionDoesNotExist +from .session_storage.json_session_storage import JsonSessionStorage +from .session_storage.string_session_storage import StringSessionStorage log = logging.getLogger(__name__) @@ -176,7 +179,7 @@ class Client(Methods, BaseClient): """ def __init__(self, - session_name: str, + session_name: Union[str, BaseSessionConfig], api_id: Union[int, str] = None, api_hash: str = None, app_version: str = None, @@ -198,11 +201,21 @@ class Client(Methods, BaseClient): config_file: str = BaseClient.CONFIG_FILE, plugins: dict = None, no_updates: bool = None, - takeout: bool = None, - session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage): - super().__init__(session_storage_cls(self)) + takeout: bool = None): - self.session_name = session_name + if isinstance(session_name, str): + if session_name.startswith(':'): + session_storage = StringSessionStorage(self, session_name) + else: + session_storage = JsonSessionStorage(self, session_name) + elif isinstance(session_name, BaseSessionConfig): + session_storage = session_name.session_storage_cls(self, session_name) + else: + raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass') + + super().__init__(session_storage) + + self.session_name = str(session_name) # TODO: build correct session name self.api_id = int(api_id) if api_id else None self.api_hash = api_hash self.app_version = app_version @@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient): def load_session(self): try: - self.session_storage.load_session(self.session_name) + self.session_storage.load_session() except SessionDoesNotExist: - session_name = self.session_name[:32] - if session_name != self.session_name: - session_name += '...' - log.info('Could not load session "{}", initializing new one'.format(self.session_name)) + log.info('Could not load session "{}", initiate new one'.format(self.session_name)) self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create() def load_plugins(self): @@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient): log.warning('No plugin loaded from "{}"'.format(root)) def save_session(self): - self.session_storage.save_session(self.session_name) + self.session_storage.save_session() def get_initial_dialogs_chunk(self, offset_date: int = 0): diff --git a/pyrogram/client/ext/syncer.py b/pyrogram/client/ext/syncer.py index 8930b13e..70955624 100644 --- a/pyrogram/client/ext/syncer.py +++ b/pyrogram/client/ext/syncer.py @@ -83,10 +83,10 @@ class Syncer: def sync(cls, client): client.date = int(time.time()) try: - client.session_storage.save_session(client.session_name, sync=True) + client.session_storage.save_session(sync=True) except Exception as e: log.critical(e, exc_info=True) else: log.info("Synced {}".format(client.session_name)) finally: - client.session_storage.sync_cleanup(client.session_name) + client.session_storage.sync_cleanup() diff --git a/pyrogram/client/session_storage/__init__.py b/pyrogram/client/session_storage/__init__.py index ced103ce..611ec9b7 100644 --- a/pyrogram/client/session_storage/__init__.py +++ b/pyrogram/client/session_storage/__init__.py @@ -17,6 +17,4 @@ # along with Pyrogram. If not, see . from .session_storage_mixin import SessionStorageMixin -from .base_session_storage import BaseSessionStorage, SessionDoesNotExist -from .json_session_storage import JsonSessionStorage -from .string_session_storage import StringSessionStorage +from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist diff --git a/pyrogram/client/session_storage/base_session_storage.py b/pyrogram/client/session_storage/base_session_storage.py index 75e416b4..a5c879f1 100644 --- a/pyrogram/client/session_storage/base_session_storage.py +++ b/pyrogram/client/session_storage/base_session_storage.py @@ -17,6 +17,7 @@ # along with Pyrogram. If not, see . import abc +from typing import Type import pyrogram @@ -26,8 +27,9 @@ class SessionDoesNotExist(Exception): class BaseSessionStorage(abc.ABC): - def __init__(self, client: 'pyrogram.client.BaseClient'): + def __init__(self, client: 'pyrogram.client.BaseClient', session_data): self.client = client + self.session_data = session_data self.dc_id = 1 self.test_mode = None self.auth_key = None @@ -38,13 +40,20 @@ class BaseSessionStorage(abc.ABC): self.peers_by_phone = {} @abc.abstractmethod - def load_session(self, name: str): + def load_session(self): ... @abc.abstractmethod - def save_session(self, name: str, sync=False): + def save_session(self, sync=False): ... @abc.abstractmethod - def sync_cleanup(self, name: str): + def sync_cleanup(self): + ... + + +class BaseSessionConfig(abc.ABC): + @property + @abc.abstractmethod + def session_storage_cls(self) -> Type[BaseSessionStorage]: ... diff --git a/pyrogram/client/session_storage/json_session_storage.py b/pyrogram/client/session_storage/json_session_storage.py index 679a21f3..f41091af 100644 --- a/pyrogram/client/session_storage/json_session_storage.py +++ b/pyrogram/client/session_storage/json_session_storage.py @@ -35,8 +35,8 @@ class JsonSessionStorage(BaseSessionStorage): name += '.session' return os.path.join(self.client.workdir, name) - def load_session(self, name: str): - file_path = self._get_file_name(name) + def load_session(self): + file_path = self._get_file_name(self.session_data) log.info('Loading JSON session from {}'.format(file_path)) try: @@ -66,8 +66,8 @@ class JsonSessionStorage(BaseSessionStorage): if peer: self.peers_by_phone[k] = peer - def save_session(self, name: str, sync=False): - file_path = self._get_file_name(name) + def save_session(self, sync=False): + file_path = self._get_file_name(self.session_data) if sync: file_path += '.tmp' @@ -107,10 +107,10 @@ class JsonSessionStorage(BaseSessionStorage): # execution won't be here if an error has occurred earlier if sync: - shutil.move(file_path, self._get_file_name(name)) + shutil.move(file_path, self._get_file_name(self.session_data)) - def sync_cleanup(self, name: str): + def sync_cleanup(self): try: - os.remove(self._get_file_name(name) + '.tmp') + os.remove(self._get_file_name(self.session_data) + '.tmp') except OSError: pass diff --git a/pyrogram/client/session_storage/string_session_storage.py b/pyrogram/client/session_storage/string_session_storage.py index 9b6ebf0e..c01a2b35 100644 --- a/pyrogram/client/session_storage/string_session_storage.py +++ b/pyrogram/client/session_storage/string_session_storage.py @@ -5,34 +5,33 @@ import struct from . import BaseSessionStorage, SessionDoesNotExist -def StringSessionStorage(print_session: bool = False): - class StringSessionStorageClass(BaseSessionStorage): - """ - Packs session data as following (forcing little-endian byte order): - Char dc_id (1 byte, unsigned) - Boolean test_mode (1 byte) - Long long user_id (8 bytes, signed) - Bytes auth_key (256 bytes) +class StringSessionStorage(BaseSessionStorage): + """ + Packs session data as following (forcing little-endian byte order): + Char dc_id (1 byte, unsigned) + Boolean test_mode (1 byte) + Long long user_id (8 bytes, signed) + Bytes auth_key (256 bytes) - Uses Base64 encoding for printable representation - """ - PACK_FORMAT = '