2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-09-10 03:05:58 +00:00

Refactor session storages: use session_name arg to detect storage type

This commit is contained in:
bakatrouble
2019-02-22 00:03:58 +03:00
parent b04cf9ec92
commit fd732add70
6 changed files with 71 additions and 55 deletions

View File

@@ -49,12 +49,15 @@ from pyrogram.api.errors import (
from pyrogram.client.handlers import DisconnectHandler
from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check
from pyrogram.client.session_storage import BaseSessionConfig
from pyrogram.crypto import AES
from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher
from .ext import utils, Syncer, BaseClient
from .methods import Methods
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
from .session_storage import SessionDoesNotExist
from .session_storage.json_session_storage import JsonSessionStorage
from .session_storage.string_session_storage import StringSessionStorage
log = logging.getLogger(__name__)
@@ -176,7 +179,7 @@ class Client(Methods, BaseClient):
"""
def __init__(self,
session_name: str,
session_name: Union[str, BaseSessionConfig],
api_id: Union[int, str] = None,
api_hash: str = None,
app_version: str = None,
@@ -198,11 +201,21 @@ class Client(Methods, BaseClient):
config_file: str = BaseClient.CONFIG_FILE,
plugins: dict = None,
no_updates: bool = None,
takeout: bool = None,
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
super().__init__(session_storage_cls(self))
takeout: bool = None):
self.session_name = session_name
if isinstance(session_name, str):
if session_name.startswith(':'):
session_storage = StringSessionStorage(self, session_name)
else:
session_storage = JsonSessionStorage(self, session_name)
elif isinstance(session_name, BaseSessionConfig):
session_storage = session_name.session_storage_cls(self, session_name)
else:
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
super().__init__(session_storage)
self.session_name = str(session_name) # TODO: build correct session name
self.api_id = int(api_id) if api_id else None
self.api_hash = api_hash
self.app_version = app_version
@@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient):
def load_session(self):
try:
self.session_storage.load_session(self.session_name)
self.session_storage.load_session()
except SessionDoesNotExist:
session_name = self.session_name[:32]
if session_name != self.session_name:
session_name += '...'
log.info('Could not load session "{}", initializing new one'.format(self.session_name))
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
def load_plugins(self):
@@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient):
log.warning('No plugin loaded from "{}"'.format(root))
def save_session(self):
self.session_storage.save_session(self.session_name)
self.session_storage.save_session()
def get_initial_dialogs_chunk(self,
offset_date: int = 0):