mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-28 04:48:06 +00:00
Refactor session storages: use session_name arg to detect storage type
This commit is contained in:
parent
b04cf9ec92
commit
fd732add70
@ -49,12 +49,15 @@ from pyrogram.api.errors import (
|
||||
from pyrogram.client.handlers import DisconnectHandler
|
||||
from pyrogram.client.handlers.handler import Handler
|
||||
from pyrogram.client.methods.password.utils import compute_check
|
||||
from pyrogram.client.session_storage import BaseSessionConfig
|
||||
from pyrogram.crypto import AES
|
||||
from pyrogram.session import Auth, Session
|
||||
from .dispatcher import Dispatcher
|
||||
from .ext import utils, Syncer, BaseClient
|
||||
from .methods import Methods
|
||||
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
|
||||
from .session_storage import SessionDoesNotExist
|
||||
from .session_storage.json_session_storage import JsonSessionStorage
|
||||
from .session_storage.string_session_storage import StringSessionStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -176,7 +179,7 @@ class Client(Methods, BaseClient):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
session_name: str,
|
||||
session_name: Union[str, BaseSessionConfig],
|
||||
api_id: Union[int, str] = None,
|
||||
api_hash: str = None,
|
||||
app_version: str = None,
|
||||
@ -198,11 +201,21 @@ class Client(Methods, BaseClient):
|
||||
config_file: str = BaseClient.CONFIG_FILE,
|
||||
plugins: dict = None,
|
||||
no_updates: bool = None,
|
||||
takeout: bool = None,
|
||||
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
|
||||
super().__init__(session_storage_cls(self))
|
||||
takeout: bool = None):
|
||||
|
||||
self.session_name = session_name
|
||||
if isinstance(session_name, str):
|
||||
if session_name.startswith(':'):
|
||||
session_storage = StringSessionStorage(self, session_name)
|
||||
else:
|
||||
session_storage = JsonSessionStorage(self, session_name)
|
||||
elif isinstance(session_name, BaseSessionConfig):
|
||||
session_storage = session_name.session_storage_cls(self, session_name)
|
||||
else:
|
||||
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
|
||||
|
||||
super().__init__(session_storage)
|
||||
|
||||
self.session_name = str(session_name) # TODO: build correct session name
|
||||
self.api_id = int(api_id) if api_id else None
|
||||
self.api_hash = api_hash
|
||||
self.app_version = app_version
|
||||
@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient):
|
||||
|
||||
def load_session(self):
|
||||
try:
|
||||
self.session_storage.load_session(self.session_name)
|
||||
self.session_storage.load_session()
|
||||
except SessionDoesNotExist:
|
||||
session_name = self.session_name[:32]
|
||||
if session_name != self.session_name:
|
||||
session_name += '...'
|
||||
log.info('Could not load session "{}", initializing new one'.format(self.session_name))
|
||||
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
|
||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||
|
||||
def load_plugins(self):
|
||||
@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient):
|
||||
log.warning('No plugin loaded from "{}"'.format(root))
|
||||
|
||||
def save_session(self):
|
||||
self.session_storage.save_session(self.session_name)
|
||||
self.session_storage.save_session()
|
||||
|
||||
def get_initial_dialogs_chunk(self,
|
||||
offset_date: int = 0):
|
||||
|
@ -83,10 +83,10 @@ class Syncer:
|
||||
def sync(cls, client):
|
||||
client.date = int(time.time())
|
||||
try:
|
||||
client.session_storage.save_session(client.session_name, sync=True)
|
||||
client.session_storage.save_session(sync=True)
|
||||
except Exception as e:
|
||||
log.critical(e, exc_info=True)
|
||||
else:
|
||||
log.info("Synced {}".format(client.session_name))
|
||||
finally:
|
||||
client.session_storage.sync_cleanup(client.session_name)
|
||||
client.session_storage.sync_cleanup()
|
||||
|
@ -17,6 +17,4 @@
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from .session_storage_mixin import SessionStorageMixin
|
||||
from .base_session_storage import BaseSessionStorage, SessionDoesNotExist
|
||||
from .json_session_storage import JsonSessionStorage
|
||||
from .string_session_storage import StringSessionStorage
|
||||
from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist
|
||||
|
@ -17,6 +17,7 @@
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import abc
|
||||
from typing import Type
|
||||
|
||||
import pyrogram
|
||||
|
||||
@ -26,8 +27,9 @@ class SessionDoesNotExist(Exception):
|
||||
|
||||
|
||||
class BaseSessionStorage(abc.ABC):
|
||||
def __init__(self, client: 'pyrogram.client.BaseClient'):
|
||||
def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
|
||||
self.client = client
|
||||
self.session_data = session_data
|
||||
self.dc_id = 1
|
||||
self.test_mode = None
|
||||
self.auth_key = None
|
||||
@ -38,13 +40,20 @@ class BaseSessionStorage(abc.ABC):
|
||||
self.peers_by_phone = {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_session(self, name: str):
|
||||
def load_session(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_session(self, name: str, sync=False):
|
||||
def save_session(self, sync=False):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_cleanup(self, name: str):
|
||||
def sync_cleanup(self):
|
||||
...
|
||||
|
||||
|
||||
class BaseSessionConfig(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def session_storage_cls(self) -> Type[BaseSessionStorage]:
|
||||
...
|
||||
|
@ -35,8 +35,8 @@ class JsonSessionStorage(BaseSessionStorage):
|
||||
name += '.session'
|
||||
return os.path.join(self.client.workdir, name)
|
||||
|
||||
def load_session(self, name: str):
|
||||
file_path = self._get_file_name(name)
|
||||
def load_session(self):
|
||||
file_path = self._get_file_name(self.session_data)
|
||||
log.info('Loading JSON session from {}'.format(file_path))
|
||||
|
||||
try:
|
||||
@ -66,8 +66,8 @@ class JsonSessionStorage(BaseSessionStorage):
|
||||
if peer:
|
||||
self.peers_by_phone[k] = peer
|
||||
|
||||
def save_session(self, name: str, sync=False):
|
||||
file_path = self._get_file_name(name)
|
||||
def save_session(self, sync=False):
|
||||
file_path = self._get_file_name(self.session_data)
|
||||
|
||||
if sync:
|
||||
file_path += '.tmp'
|
||||
@ -107,10 +107,10 @@ class JsonSessionStorage(BaseSessionStorage):
|
||||
|
||||
# execution won't be here if an error has occurred earlier
|
||||
if sync:
|
||||
shutil.move(file_path, self._get_file_name(name))
|
||||
shutil.move(file_path, self._get_file_name(self.session_data))
|
||||
|
||||
def sync_cleanup(self, name: str):
|
||||
def sync_cleanup(self):
|
||||
try:
|
||||
os.remove(self._get_file_name(name) + '.tmp')
|
||||
os.remove(self._get_file_name(self.session_data) + '.tmp')
|
||||
except OSError:
|
||||
pass
|
||||
|
@ -5,34 +5,33 @@ import struct
|
||||
from . import BaseSessionStorage, SessionDoesNotExist
|
||||
|
||||
|
||||
def StringSessionStorage(print_session: bool = False):
|
||||
class StringSessionStorageClass(BaseSessionStorage):
|
||||
"""
|
||||
Packs session data as following (forcing little-endian byte order):
|
||||
Char dc_id (1 byte, unsigned)
|
||||
Boolean test_mode (1 byte)
|
||||
Long long user_id (8 bytes, signed)
|
||||
Bytes auth_key (256 bytes)
|
||||
class StringSessionStorage(BaseSessionStorage):
|
||||
"""
|
||||
Packs session data as following (forcing little-endian byte order):
|
||||
Char dc_id (1 byte, unsigned)
|
||||
Boolean test_mode (1 byte)
|
||||
Long long user_id (8 bytes, signed)
|
||||
Bytes auth_key (256 bytes)
|
||||
|
||||
Uses Base64 encoding for printable representation
|
||||
"""
|
||||
PACK_FORMAT = '<B?q256s'
|
||||
Uses Base64 encoding for printable representation
|
||||
"""
|
||||
PACK_FORMAT = '<B?q256s'
|
||||
|
||||
def load_session(self, session_string: str):
|
||||
try:
|
||||
decoded = base64.b64decode(session_string)
|
||||
self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded)
|
||||
except (struct.error, binascii.Error):
|
||||
raise SessionDoesNotExist()
|
||||
def load_session(self):
|
||||
try:
|
||||
session_string = self.session_data[1:]
|
||||
session_string += '=' * (4 - len(session_string) % 4) # restore padding
|
||||
decoded = base64.b64decode(session_string, b'-_')
|
||||
self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded)
|
||||
except (struct.error, binascii.Error):
|
||||
raise SessionDoesNotExist()
|
||||
|
||||
def save_session(self, session_string: str, sync=False):
|
||||
if print_session and not sync:
|
||||
packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key)
|
||||
encoded = base64.b64encode(packed).decode('latin-1')
|
||||
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
|
||||
print('Created session string:\n{}'.format(split))
|
||||
def save_session(self, sync=False):
|
||||
if not sync:
|
||||
packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key)
|
||||
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
|
||||
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
|
||||
print('Created session string:\n{}'.format(split))
|
||||
|
||||
def sync_cleanup(self, session_string: str):
|
||||
pass
|
||||
|
||||
return StringSessionStorageClass
|
||||
def sync_cleanup(self):
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user