mirror of
https://github.com/pyrogram/pyrogram
synced 2025-09-10 03:05:58 +00:00
Refactor session storages: use session_name arg to detect storage type
This commit is contained in:
@@ -49,12 +49,15 @@ from pyrogram.api.errors import (
|
||||
from pyrogram.client.handlers import DisconnectHandler
|
||||
from pyrogram.client.handlers.handler import Handler
|
||||
from pyrogram.client.methods.password.utils import compute_check
|
||||
from pyrogram.client.session_storage import BaseSessionConfig
|
||||
from pyrogram.crypto import AES
|
||||
from pyrogram.session import Auth, Session
|
||||
from .dispatcher import Dispatcher
|
||||
from .ext import utils, Syncer, BaseClient
|
||||
from .methods import Methods
|
||||
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
|
||||
from .session_storage import SessionDoesNotExist
|
||||
from .session_storage.json_session_storage import JsonSessionStorage
|
||||
from .session_storage.string_session_storage import StringSessionStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -176,7 +179,7 @@ class Client(Methods, BaseClient):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
session_name: str,
|
||||
session_name: Union[str, BaseSessionConfig],
|
||||
api_id: Union[int, str] = None,
|
||||
api_hash: str = None,
|
||||
app_version: str = None,
|
||||
@@ -198,11 +201,21 @@ class Client(Methods, BaseClient):
|
||||
config_file: str = BaseClient.CONFIG_FILE,
|
||||
plugins: dict = None,
|
||||
no_updates: bool = None,
|
||||
takeout: bool = None,
|
||||
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
|
||||
super().__init__(session_storage_cls(self))
|
||||
takeout: bool = None):
|
||||
|
||||
self.session_name = session_name
|
||||
if isinstance(session_name, str):
|
||||
if session_name.startswith(':'):
|
||||
session_storage = StringSessionStorage(self, session_name)
|
||||
else:
|
||||
session_storage = JsonSessionStorage(self, session_name)
|
||||
elif isinstance(session_name, BaseSessionConfig):
|
||||
session_storage = session_name.session_storage_cls(self, session_name)
|
||||
else:
|
||||
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
|
||||
|
||||
super().__init__(session_storage)
|
||||
|
||||
self.session_name = str(session_name) # TODO: build correct session name
|
||||
self.api_id = int(api_id) if api_id else None
|
||||
self.api_hash = api_hash
|
||||
self.app_version = app_version
|
||||
@@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient):
|
||||
|
||||
def load_session(self):
|
||||
try:
|
||||
self.session_storage.load_session(self.session_name)
|
||||
self.session_storage.load_session()
|
||||
except SessionDoesNotExist:
|
||||
session_name = self.session_name[:32]
|
||||
if session_name != self.session_name:
|
||||
session_name += '...'
|
||||
log.info('Could not load session "{}", initializing new one'.format(self.session_name))
|
||||
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
|
||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||
|
||||
def load_plugins(self):
|
||||
@@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient):
|
||||
log.warning('No plugin loaded from "{}"'.format(root))
|
||||
|
||||
def save_session(self):
|
||||
self.session_storage.save_session(self.session_name)
|
||||
self.session_storage.save_session()
|
||||
|
||||
def get_initial_dialogs_chunk(self,
|
||||
offset_date: int = 0):
|
||||
|
Reference in New Issue
Block a user