2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 04:48:06 +00:00

Refactor session storages: use session_name arg to detect storage type

This commit is contained in:
bakatrouble 2019-02-22 00:03:58 +03:00
parent b04cf9ec92
commit fd732add70
6 changed files with 71 additions and 55 deletions

View File

@ -49,12 +49,15 @@ from pyrogram.api.errors import (
from pyrogram.client.handlers import DisconnectHandler
from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check
from pyrogram.client.session_storage import BaseSessionConfig
from pyrogram.crypto import AES
from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher
from .ext import utils, Syncer, BaseClient
from .methods import Methods
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
from .session_storage import SessionDoesNotExist
from .session_storage.json_session_storage import JsonSessionStorage
from .session_storage.string_session_storage import StringSessionStorage
log = logging.getLogger(__name__)
@ -176,7 +179,7 @@ class Client(Methods, BaseClient):
"""
def __init__(self,
session_name: str,
session_name: Union[str, BaseSessionConfig],
api_id: Union[int, str] = None,
api_hash: str = None,
app_version: str = None,
@ -198,11 +201,21 @@ class Client(Methods, BaseClient):
config_file: str = BaseClient.CONFIG_FILE,
plugins: dict = None,
no_updates: bool = None,
takeout: bool = None,
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
super().__init__(session_storage_cls(self))
takeout: bool = None):
self.session_name = session_name
if isinstance(session_name, str):
if session_name.startswith(':'):
session_storage = StringSessionStorage(self, session_name)
else:
session_storage = JsonSessionStorage(self, session_name)
elif isinstance(session_name, BaseSessionConfig):
session_storage = session_name.session_storage_cls(self, session_name)
else:
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
super().__init__(session_storage)
self.session_name = str(session_name) # TODO: build correct session name
self.api_id = int(api_id) if api_id else None
self.api_hash = api_hash
self.app_version = app_version
@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient):
def load_session(self):
try:
self.session_storage.load_session(self.session_name)
self.session_storage.load_session()
except SessionDoesNotExist:
session_name = self.session_name[:32]
if session_name != self.session_name:
session_name += '...'
log.info('Could not load session "{}", initializing new one'.format(self.session_name))
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
def load_plugins(self):
@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient):
log.warning('No plugin loaded from "{}"'.format(root))
def save_session(self):
self.session_storage.save_session(self.session_name)
self.session_storage.save_session()
def get_initial_dialogs_chunk(self,
offset_date: int = 0):

View File

@ -83,10 +83,10 @@ class Syncer:
def sync(cls, client):
client.date = int(time.time())
try:
client.session_storage.save_session(client.session_name, sync=True)
client.session_storage.save_session(sync=True)
except Exception as e:
log.critical(e, exc_info=True)
else:
log.info("Synced {}".format(client.session_name))
finally:
client.session_storage.sync_cleanup(client.session_name)
client.session_storage.sync_cleanup()

View File

@ -17,6 +17,4 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .session_storage_mixin import SessionStorageMixin
from .base_session_storage import BaseSessionStorage, SessionDoesNotExist
from .json_session_storage import JsonSessionStorage
from .string_session_storage import StringSessionStorage
from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist

View File

@ -17,6 +17,7 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import abc
from typing import Type
import pyrogram
@ -26,8 +27,9 @@ class SessionDoesNotExist(Exception):
class BaseSessionStorage(abc.ABC):
def __init__(self, client: 'pyrogram.client.BaseClient'):
def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
self.client = client
self.session_data = session_data
self.dc_id = 1
self.test_mode = None
self.auth_key = None
@ -38,13 +40,20 @@ class BaseSessionStorage(abc.ABC):
self.peers_by_phone = {}
@abc.abstractmethod
def load_session(self, name: str):
def load_session(self):
...
@abc.abstractmethod
def save_session(self, name: str, sync=False):
def save_session(self, sync=False):
...
@abc.abstractmethod
def sync_cleanup(self, name: str):
def sync_cleanup(self):
...
class BaseSessionConfig(abc.ABC):
@property
@abc.abstractmethod
def session_storage_cls(self) -> Type[BaseSessionStorage]:
...

View File

@ -35,8 +35,8 @@ class JsonSessionStorage(BaseSessionStorage):
name += '.session'
return os.path.join(self.client.workdir, name)
def load_session(self, name: str):
file_path = self._get_file_name(name)
def load_session(self):
file_path = self._get_file_name(self.session_data)
log.info('Loading JSON session from {}'.format(file_path))
try:
@ -66,8 +66,8 @@ class JsonSessionStorage(BaseSessionStorage):
if peer:
self.peers_by_phone[k] = peer
def save_session(self, name: str, sync=False):
file_path = self._get_file_name(name)
def save_session(self, sync=False):
file_path = self._get_file_name(self.session_data)
if sync:
file_path += '.tmp'
@ -107,10 +107,10 @@ class JsonSessionStorage(BaseSessionStorage):
# execution won't be here if an error has occurred earlier
if sync:
shutil.move(file_path, self._get_file_name(name))
shutil.move(file_path, self._get_file_name(self.session_data))
def sync_cleanup(self, name: str):
def sync_cleanup(self):
try:
os.remove(self._get_file_name(name) + '.tmp')
os.remove(self._get_file_name(self.session_data) + '.tmp')
except OSError:
pass

View File

@ -5,34 +5,33 @@ import struct
from . import BaseSessionStorage, SessionDoesNotExist
def StringSessionStorage(print_session: bool = False):
class StringSessionStorageClass(BaseSessionStorage):
"""
Packs session data as following (forcing little-endian byte order):
Char dc_id (1 byte, unsigned)
Boolean test_mode (1 byte)
Long long user_id (8 bytes, signed)
Bytes auth_key (256 bytes)
class StringSessionStorage(BaseSessionStorage):
"""
Packs session data as following (forcing little-endian byte order):
Char dc_id (1 byte, unsigned)
Boolean test_mode (1 byte)
Long long user_id (8 bytes, signed)
Bytes auth_key (256 bytes)
Uses Base64 encoding for printable representation
"""
PACK_FORMAT = '<B?q256s'
Uses Base64 encoding for printable representation
"""
PACK_FORMAT = '<B?q256s'
def load_session(self, session_string: str):
try:
decoded = base64.b64decode(session_string)
self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded)
except (struct.error, binascii.Error):
raise SessionDoesNotExist()
def load_session(self):
try:
session_string = self.session_data[1:]
session_string += '=' * (4 - len(session_string) % 4) # restore padding
decoded = base64.b64decode(session_string, b'-_')
self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded)
except (struct.error, binascii.Error):
raise SessionDoesNotExist()
def save_session(self, session_string: str, sync=False):
if print_session and not sync:
packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key)
encoded = base64.b64encode(packed).decode('latin-1')
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
print('Created session string:\n{}'.format(split))
def save_session(self, sync=False):
if not sync:
packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key)
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
print('Created session string:\n{}'.format(split))
def sync_cleanup(self, session_string: str):
pass
return StringSessionStorageClass
def sync_cleanup(self):
pass