2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Refactor session storages: use session_name arg to detect storage type

This commit is contained in:
bakatrouble 2019-02-22 00:03:58 +03:00
parent b04cf9ec92
commit fd732add70
6 changed files with 71 additions and 55 deletions

View File

@ -49,12 +49,15 @@ from pyrogram.api.errors import (
from pyrogram.client.handlers import DisconnectHandler from pyrogram.client.handlers import DisconnectHandler
from pyrogram.client.handlers.handler import Handler from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check from pyrogram.client.methods.password.utils import compute_check
from pyrogram.client.session_storage import BaseSessionConfig
from pyrogram.crypto import AES from pyrogram.crypto import AES
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .ext import utils, Syncer, BaseClient from .ext import utils, Syncer, BaseClient
from .methods import Methods from .methods import Methods
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist from .session_storage import SessionDoesNotExist
from .session_storage.json_session_storage import JsonSessionStorage
from .session_storage.string_session_storage import StringSessionStorage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -176,7 +179,7 @@ class Client(Methods, BaseClient):
""" """
def __init__(self, def __init__(self,
session_name: str, session_name: Union[str, BaseSessionConfig],
api_id: Union[int, str] = None, api_id: Union[int, str] = None,
api_hash: str = None, api_hash: str = None,
app_version: str = None, app_version: str = None,
@ -198,11 +201,21 @@ class Client(Methods, BaseClient):
config_file: str = BaseClient.CONFIG_FILE, config_file: str = BaseClient.CONFIG_FILE,
plugins: dict = None, plugins: dict = None,
no_updates: bool = None, no_updates: bool = None,
takeout: bool = None, takeout: bool = None):
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
super().__init__(session_storage_cls(self))
self.session_name = session_name if isinstance(session_name, str):
if session_name.startswith(':'):
session_storage = StringSessionStorage(self, session_name)
else:
session_storage = JsonSessionStorage(self, session_name)
elif isinstance(session_name, BaseSessionConfig):
session_storage = session_name.session_storage_cls(self, session_name)
else:
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
super().__init__(session_storage)
self.session_name = str(session_name) # TODO: build correct session name
self.api_id = int(api_id) if api_id else None self.api_id = int(api_id) if api_id else None
self.api_hash = api_hash self.api_hash = api_hash
self.app_version = app_version self.app_version = app_version
@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient):
def load_session(self): def load_session(self):
try: try:
self.session_storage.load_session(self.session_name) self.session_storage.load_session()
except SessionDoesNotExist: except SessionDoesNotExist:
session_name = self.session_name[:32] log.info('Could not load session "{}", initiate new one'.format(self.session_name))
if session_name != self.session_name:
session_name += '...'
log.info('Could not load session "{}", initializing new one'.format(self.session_name))
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create() self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
def load_plugins(self): def load_plugins(self):
@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient):
log.warning('No plugin loaded from "{}"'.format(root)) log.warning('No plugin loaded from "{}"'.format(root))
def save_session(self): def save_session(self):
self.session_storage.save_session(self.session_name) self.session_storage.save_session()
def get_initial_dialogs_chunk(self, def get_initial_dialogs_chunk(self,
offset_date: int = 0): offset_date: int = 0):

View File

@ -83,10 +83,10 @@ class Syncer:
def sync(cls, client): def sync(cls, client):
client.date = int(time.time()) client.date = int(time.time())
try: try:
client.session_storage.save_session(client.session_name, sync=True) client.session_storage.save_session(sync=True)
except Exception as e: except Exception as e:
log.critical(e, exc_info=True) log.critical(e, exc_info=True)
else: else:
log.info("Synced {}".format(client.session_name)) log.info("Synced {}".format(client.session_name))
finally: finally:
client.session_storage.sync_cleanup(client.session_name) client.session_storage.sync_cleanup()

View File

@ -17,6 +17,4 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .session_storage_mixin import SessionStorageMixin from .session_storage_mixin import SessionStorageMixin
from .base_session_storage import BaseSessionStorage, SessionDoesNotExist from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist
from .json_session_storage import JsonSessionStorage
from .string_session_storage import StringSessionStorage

View File

@ -17,6 +17,7 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import abc import abc
from typing import Type
import pyrogram import pyrogram
@ -26,8 +27,9 @@ class SessionDoesNotExist(Exception):
class BaseSessionStorage(abc.ABC): class BaseSessionStorage(abc.ABC):
def __init__(self, client: 'pyrogram.client.BaseClient'): def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
self.client = client self.client = client
self.session_data = session_data
self.dc_id = 1 self.dc_id = 1
self.test_mode = None self.test_mode = None
self.auth_key = None self.auth_key = None
@ -38,13 +40,20 @@ class BaseSessionStorage(abc.ABC):
self.peers_by_phone = {} self.peers_by_phone = {}
@abc.abstractmethod @abc.abstractmethod
def load_session(self, name: str): def load_session(self):
... ...
@abc.abstractmethod @abc.abstractmethod
def save_session(self, name: str, sync=False): def save_session(self, sync=False):
... ...
@abc.abstractmethod @abc.abstractmethod
def sync_cleanup(self, name: str): def sync_cleanup(self):
...
class BaseSessionConfig(abc.ABC):
@property
@abc.abstractmethod
def session_storage_cls(self) -> Type[BaseSessionStorage]:
... ...

View File

@ -35,8 +35,8 @@ class JsonSessionStorage(BaseSessionStorage):
name += '.session' name += '.session'
return os.path.join(self.client.workdir, name) return os.path.join(self.client.workdir, name)
def load_session(self, name: str): def load_session(self):
file_path = self._get_file_name(name) file_path = self._get_file_name(self.session_data)
log.info('Loading JSON session from {}'.format(file_path)) log.info('Loading JSON session from {}'.format(file_path))
try: try:
@ -66,8 +66,8 @@ class JsonSessionStorage(BaseSessionStorage):
if peer: if peer:
self.peers_by_phone[k] = peer self.peers_by_phone[k] = peer
def save_session(self, name: str, sync=False): def save_session(self, sync=False):
file_path = self._get_file_name(name) file_path = self._get_file_name(self.session_data)
if sync: if sync:
file_path += '.tmp' file_path += '.tmp'
@ -107,10 +107,10 @@ class JsonSessionStorage(BaseSessionStorage):
# execution won't be here if an error has occurred earlier # execution won't be here if an error has occurred earlier
if sync: if sync:
shutil.move(file_path, self._get_file_name(name)) shutil.move(file_path, self._get_file_name(self.session_data))
def sync_cleanup(self, name: str): def sync_cleanup(self):
try: try:
os.remove(self._get_file_name(name) + '.tmp') os.remove(self._get_file_name(self.session_data) + '.tmp')
except OSError: except OSError:
pass pass

View File

@ -5,8 +5,7 @@ import struct
from . import BaseSessionStorage, SessionDoesNotExist from . import BaseSessionStorage, SessionDoesNotExist
def StringSessionStorage(print_session: bool = False): class StringSessionStorage(BaseSessionStorage):
class StringSessionStorageClass(BaseSessionStorage):
""" """
Packs session data as following (forcing little-endian byte order): Packs session data as following (forcing little-endian byte order):
Char dc_id (1 byte, unsigned) Char dc_id (1 byte, unsigned)
@ -18,21 +17,21 @@ def StringSessionStorage(print_session: bool = False):
""" """
PACK_FORMAT = '<B?q256s' PACK_FORMAT = '<B?q256s'
def load_session(self, session_string: str): def load_session(self):
try: try:
decoded = base64.b64decode(session_string) session_string = self.session_data[1:]
session_string += '=' * (4 - len(session_string) % 4) # restore padding
decoded = base64.b64decode(session_string, b'-_')
self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded) self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded)
except (struct.error, binascii.Error): except (struct.error, binascii.Error):
raise SessionDoesNotExist() raise SessionDoesNotExist()
def save_session(self, session_string: str, sync=False): def save_session(self, sync=False):
if print_session and not sync: if not sync:
packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key) packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key)
encoded = base64.b64encode(packed).decode('latin-1') encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)]) split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
print('Created session string:\n{}'.format(split)) print('Created session string:\n{}'.format(split))
def sync_cleanup(self, session_string: str): def sync_cleanup(self):
pass pass
return StringSessionStorageClass