mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-28 21:07:59 +00:00
Refactor session storages: use session_name arg to detect storage type
This commit is contained in:
parent
b04cf9ec92
commit
fd732add70
@ -49,12 +49,15 @@ from pyrogram.api.errors import (
|
|||||||
from pyrogram.client.handlers import DisconnectHandler
|
from pyrogram.client.handlers import DisconnectHandler
|
||||||
from pyrogram.client.handlers.handler import Handler
|
from pyrogram.client.handlers.handler import Handler
|
||||||
from pyrogram.client.methods.password.utils import compute_check
|
from pyrogram.client.methods.password.utils import compute_check
|
||||||
|
from pyrogram.client.session_storage import BaseSessionConfig
|
||||||
from pyrogram.crypto import AES
|
from pyrogram.crypto import AES
|
||||||
from pyrogram.session import Auth, Session
|
from pyrogram.session import Auth, Session
|
||||||
from .dispatcher import Dispatcher
|
from .dispatcher import Dispatcher
|
||||||
from .ext import utils, Syncer, BaseClient
|
from .ext import utils, Syncer, BaseClient
|
||||||
from .methods import Methods
|
from .methods import Methods
|
||||||
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
|
from .session_storage import SessionDoesNotExist
|
||||||
|
from .session_storage.json_session_storage import JsonSessionStorage
|
||||||
|
from .session_storage.string_session_storage import StringSessionStorage
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -176,7 +179,7 @@ class Client(Methods, BaseClient):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
session_name: str,
|
session_name: Union[str, BaseSessionConfig],
|
||||||
api_id: Union[int, str] = None,
|
api_id: Union[int, str] = None,
|
||||||
api_hash: str = None,
|
api_hash: str = None,
|
||||||
app_version: str = None,
|
app_version: str = None,
|
||||||
@ -198,11 +201,21 @@ class Client(Methods, BaseClient):
|
|||||||
config_file: str = BaseClient.CONFIG_FILE,
|
config_file: str = BaseClient.CONFIG_FILE,
|
||||||
plugins: dict = None,
|
plugins: dict = None,
|
||||||
no_updates: bool = None,
|
no_updates: bool = None,
|
||||||
takeout: bool = None,
|
takeout: bool = None):
|
||||||
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
|
|
||||||
super().__init__(session_storage_cls(self))
|
|
||||||
|
|
||||||
self.session_name = session_name
|
if isinstance(session_name, str):
|
||||||
|
if session_name.startswith(':'):
|
||||||
|
session_storage = StringSessionStorage(self, session_name)
|
||||||
|
else:
|
||||||
|
session_storage = JsonSessionStorage(self, session_name)
|
||||||
|
elif isinstance(session_name, BaseSessionConfig):
|
||||||
|
session_storage = session_name.session_storage_cls(self, session_name)
|
||||||
|
else:
|
||||||
|
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
|
||||||
|
|
||||||
|
super().__init__(session_storage)
|
||||||
|
|
||||||
|
self.session_name = str(session_name) # TODO: build correct session name
|
||||||
self.api_id = int(api_id) if api_id else None
|
self.api_id = int(api_id) if api_id else None
|
||||||
self.api_hash = api_hash
|
self.api_hash = api_hash
|
||||||
self.app_version = app_version
|
self.app_version = app_version
|
||||||
@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient):
|
|||||||
|
|
||||||
def load_session(self):
|
def load_session(self):
|
||||||
try:
|
try:
|
||||||
self.session_storage.load_session(self.session_name)
|
self.session_storage.load_session()
|
||||||
except SessionDoesNotExist:
|
except SessionDoesNotExist:
|
||||||
session_name = self.session_name[:32]
|
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
|
||||||
if session_name != self.session_name:
|
|
||||||
session_name += '...'
|
|
||||||
log.info('Could not load session "{}", initializing new one'.format(self.session_name))
|
|
||||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||||
|
|
||||||
def load_plugins(self):
|
def load_plugins(self):
|
||||||
@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient):
|
|||||||
log.warning('No plugin loaded from "{}"'.format(root))
|
log.warning('No plugin loaded from "{}"'.format(root))
|
||||||
|
|
||||||
def save_session(self):
|
def save_session(self):
|
||||||
self.session_storage.save_session(self.session_name)
|
self.session_storage.save_session()
|
||||||
|
|
||||||
def get_initial_dialogs_chunk(self,
|
def get_initial_dialogs_chunk(self,
|
||||||
offset_date: int = 0):
|
offset_date: int = 0):
|
||||||
|
@ -83,10 +83,10 @@ class Syncer:
|
|||||||
def sync(cls, client):
|
def sync(cls, client):
|
||||||
client.date = int(time.time())
|
client.date = int(time.time())
|
||||||
try:
|
try:
|
||||||
client.session_storage.save_session(client.session_name, sync=True)
|
client.session_storage.save_session(sync=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.critical(e, exc_info=True)
|
log.critical(e, exc_info=True)
|
||||||
else:
|
else:
|
||||||
log.info("Synced {}".format(client.session_name))
|
log.info("Synced {}".format(client.session_name))
|
||||||
finally:
|
finally:
|
||||||
client.session_storage.sync_cleanup(client.session_name)
|
client.session_storage.sync_cleanup()
|
||||||
|
@ -17,6 +17,4 @@
|
|||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from .session_storage_mixin import SessionStorageMixin
|
from .session_storage_mixin import SessionStorageMixin
|
||||||
from .base_session_storage import BaseSessionStorage, SessionDoesNotExist
|
from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist
|
||||||
from .json_session_storage import JsonSessionStorage
|
|
||||||
from .string_session_storage import StringSessionStorage
|
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
|
|
||||||
@ -26,8 +27,9 @@ class SessionDoesNotExist(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class BaseSessionStorage(abc.ABC):
|
class BaseSessionStorage(abc.ABC):
|
||||||
def __init__(self, client: 'pyrogram.client.BaseClient'):
|
def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
self.session_data = session_data
|
||||||
self.dc_id = 1
|
self.dc_id = 1
|
||||||
self.test_mode = None
|
self.test_mode = None
|
||||||
self.auth_key = None
|
self.auth_key = None
|
||||||
@ -38,13 +40,20 @@ class BaseSessionStorage(abc.ABC):
|
|||||||
self.peers_by_phone = {}
|
self.peers_by_phone = {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def load_session(self, name: str):
|
def load_session(self):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def save_session(self, name: str, sync=False):
|
def save_session(self, sync=False):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def sync_cleanup(self, name: str):
|
def sync_cleanup(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSessionConfig(abc.ABC):
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def session_storage_cls(self) -> Type[BaseSessionStorage]:
|
||||||
...
|
...
|
||||||
|
@ -35,8 +35,8 @@ class JsonSessionStorage(BaseSessionStorage):
|
|||||||
name += '.session'
|
name += '.session'
|
||||||
return os.path.join(self.client.workdir, name)
|
return os.path.join(self.client.workdir, name)
|
||||||
|
|
||||||
def load_session(self, name: str):
|
def load_session(self):
|
||||||
file_path = self._get_file_name(name)
|
file_path = self._get_file_name(self.session_data)
|
||||||
log.info('Loading JSON session from {}'.format(file_path))
|
log.info('Loading JSON session from {}'.format(file_path))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -66,8 +66,8 @@ class JsonSessionStorage(BaseSessionStorage):
|
|||||||
if peer:
|
if peer:
|
||||||
self.peers_by_phone[k] = peer
|
self.peers_by_phone[k] = peer
|
||||||
|
|
||||||
def save_session(self, name: str, sync=False):
|
def save_session(self, sync=False):
|
||||||
file_path = self._get_file_name(name)
|
file_path = self._get_file_name(self.session_data)
|
||||||
|
|
||||||
if sync:
|
if sync:
|
||||||
file_path += '.tmp'
|
file_path += '.tmp'
|
||||||
@ -107,10 +107,10 @@ class JsonSessionStorage(BaseSessionStorage):
|
|||||||
|
|
||||||
# execution won't be here if an error has occurred earlier
|
# execution won't be here if an error has occurred earlier
|
||||||
if sync:
|
if sync:
|
||||||
shutil.move(file_path, self._get_file_name(name))
|
shutil.move(file_path, self._get_file_name(self.session_data))
|
||||||
|
|
||||||
def sync_cleanup(self, name: str):
|
def sync_cleanup(self):
|
||||||
try:
|
try:
|
||||||
os.remove(self._get_file_name(name) + '.tmp')
|
os.remove(self._get_file_name(self.session_data) + '.tmp')
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
@ -5,8 +5,7 @@ import struct
|
|||||||
from . import BaseSessionStorage, SessionDoesNotExist
|
from . import BaseSessionStorage, SessionDoesNotExist
|
||||||
|
|
||||||
|
|
||||||
def StringSessionStorage(print_session: bool = False):
|
class StringSessionStorage(BaseSessionStorage):
|
||||||
class StringSessionStorageClass(BaseSessionStorage):
|
|
||||||
"""
|
"""
|
||||||
Packs session data as following (forcing little-endian byte order):
|
Packs session data as following (forcing little-endian byte order):
|
||||||
Char dc_id (1 byte, unsigned)
|
Char dc_id (1 byte, unsigned)
|
||||||
@ -18,21 +17,21 @@ def StringSessionStorage(print_session: bool = False):
|
|||||||
"""
|
"""
|
||||||
PACK_FORMAT = '<B?q256s'
|
PACK_FORMAT = '<B?q256s'
|
||||||
|
|
||||||
def load_session(self, session_string: str):
|
def load_session(self):
|
||||||
try:
|
try:
|
||||||
decoded = base64.b64decode(session_string)
|
session_string = self.session_data[1:]
|
||||||
|
session_string += '=' * (4 - len(session_string) % 4) # restore padding
|
||||||
|
decoded = base64.b64decode(session_string, b'-_')
|
||||||
self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded)
|
self.dc_id, self.test_mode, self.user_id, self.auth_key = struct.unpack(self.PACK_FORMAT, decoded)
|
||||||
except (struct.error, binascii.Error):
|
except (struct.error, binascii.Error):
|
||||||
raise SessionDoesNotExist()
|
raise SessionDoesNotExist()
|
||||||
|
|
||||||
def save_session(self, session_string: str, sync=False):
|
def save_session(self, sync=False):
|
||||||
if print_session and not sync:
|
if not sync:
|
||||||
packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key)
|
packed = struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.auth_key)
|
||||||
encoded = base64.b64encode(packed).decode('latin-1')
|
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
|
||||||
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
|
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
|
||||||
print('Created session string:\n{}'.format(split))
|
print('Created session string:\n{}'.format(split))
|
||||||
|
|
||||||
def sync_cleanup(self, session_string: str):
|
def sync_cleanup(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return StringSessionStorageClass
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user