From 9d32b28f94e4c0f263e80e758fa530a3471cf91b Mon Sep 17 00:00:00 2001 From: bakatrouble Date: Thu, 21 Feb 2019 20:12:11 +0300 Subject: [PATCH] Implement extendable session storage and JSON session storage --- pyrogram/client/client.py | 59 ++------- pyrogram/client/ext/base_client.py | 14 +-- pyrogram/client/ext/syncer.py | 41 +------ pyrogram/client/session_storage/__init__.py | 21 ++++ .../session_storage/base_session_storage.py | 50 ++++++++ .../session_storage/json_session_storage.py | 116 ++++++++++++++++++ .../session_storage/session_storage_mixin.py | 73 +++++++++++ 7 files changed, 278 insertions(+), 96 deletions(-) create mode 100644 pyrogram/client/session_storage/__init__.py create mode 100644 pyrogram/client/session_storage/base_session_storage.py create mode 100644 pyrogram/client/session_storage/json_session_storage.py create mode 100644 pyrogram/client/session_storage/session_storage_mixin.py diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index f62c046c..9a9f8482 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -36,7 +36,7 @@ from importlib import import_module from pathlib import Path from signal import signal, SIGINT, SIGTERM, SIGABRT from threading import Thread -from typing import Union, List +from typing import Union, List, Type from pyrogram.api import functions, types from pyrogram.api.core import Object @@ -56,6 +56,7 @@ from pyrogram.session import Auth, Session from .dispatcher import Dispatcher from .ext import utils, Syncer, BaseClient from .methods import Methods +from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist log = logging.getLogger(__name__) @@ -199,8 +200,9 @@ class Client(Methods, BaseClient): config_file: str = BaseClient.CONFIG_FILE, plugins: dict = None, no_updates: bool = None, - takeout: bool = None): - super().__init__() + takeout: bool = None, + session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage): + super().__init__(session_storage_cls(self)) self.session_name = session_name self.api_id = int(api_id) if api_id else None @@ -296,8 +298,8 @@ class Client(Methods, BaseClient): now = time.time() if abs(now - self.date) > Client.OFFLINE_SLEEP: - self.peers_by_username = {} - self.peers_by_phone = {} + self.peers_by_username.clear() + self.peers_by_phone.clear() self.get_initial_dialogs() self.get_contacts() @@ -1101,33 +1103,10 @@ class Client(Methods, BaseClient): def load_session(self): try: - with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f: - s = json.load(f) - except FileNotFoundError: - self.dc_id = 1 - self.date = 0 + self.session_storage.load_session(self.session_name) + except SessionDoesNotExist: + log.info('Session {} was not found, initializing new one') self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create() - else: - self.dc_id = s["dc_id"] - self.test_mode = s["test_mode"] - self.auth_key = base64.b64decode("".join(s["auth_key"])) - self.user_id = s["user_id"] - self.date = s.get("date", 0) - - for k, v in s.get("peers_by_id", {}).items(): - self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v) - - for k, v in s.get("peers_by_username", {}).items(): - peer = self.peers_by_id.get(v, None) - - if peer: - self.peers_by_username[k] = peer - - for k, v in s.get("peers_by_phone", {}).items(): - peer = self.peers_by_id.get(v, None) - - if peer: - self.peers_by_phone[k] = peer def load_plugins(self): if self.plugins.get("enabled", False): @@ -1234,23 +1213,7 @@ class Client(Methods, BaseClient): log.warning('No plugin loaded from "{}"'.format(root)) def save_session(self): - auth_key = base64.b64encode(self.auth_key).decode() - auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] - - os.makedirs(self.workdir, exist_ok=True) - - with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), "w", encoding="utf-8") as f: - json.dump( - dict( - dc_id=self.dc_id, - test_mode=self.test_mode, - auth_key=auth_key, - user_id=self.user_id, - date=self.date - ), - f, - indent=4 - ) + self.session_storage.save_session(self.session_name) def get_initial_dialogs_chunk(self, offset_date: int = 0): diff --git a/pyrogram/client/ext/base_client.py b/pyrogram/client/ext/base_client.py index d2c348a8..87f11e23 100644 --- a/pyrogram/client/ext/base_client.py +++ b/pyrogram/client/ext/base_client.py @@ -24,9 +24,10 @@ from threading import Lock from pyrogram import __version__ from ..style import Markdown, HTML from ...session.internals import MsgId +from ..session_storage import SessionStorageMixin, BaseSessionStorage -class BaseClient: +class BaseClient(SessionStorageMixin): class StopTransmission(StopIteration): pass @@ -67,20 +68,13 @@ class BaseClient: 13: "video_note" } - def __init__(self): + def __init__(self, session_storage: BaseSessionStorage): + self.session_storage = session_storage self.bot_token = None - self.dc_id = None - self.auth_key = None - self.user_id = None - self.date = None self.rnd_id = MsgId self.channels_pts = {} - self.peers_by_id = {} - self.peers_by_username = {} - self.peers_by_phone = {} - self.markdown = Markdown(self.peers_by_id) self.html = HTML(self.peers_by_id) diff --git a/pyrogram/client/ext/syncer.py b/pyrogram/client/ext/syncer.py index e169d2a3..8930b13e 100644 --- a/pyrogram/client/ext/syncer.py +++ b/pyrogram/client/ext/syncer.py @@ -81,47 +81,12 @@ class Syncer: @classmethod def sync(cls, client): - temporary = os.path.join(client.workdir, "{}.sync".format(client.session_name)) - persistent = os.path.join(client.workdir, "{}.session".format(client.session_name)) - + client.date = int(time.time()) try: - auth_key = base64.b64encode(client.auth_key).decode() - auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] - - data = dict( - dc_id=client.dc_id, - test_mode=client.test_mode, - auth_key=auth_key, - user_id=client.user_id, - date=int(time.time()), - peers_by_id={ - k: getattr(v, "access_hash", None) - for k, v in client.peers_by_id.copy().items() - }, - peers_by_username={ - k: utils.get_peer_id(v) - for k, v in client.peers_by_username.copy().items() - }, - peers_by_phone={ - k: utils.get_peer_id(v) - for k, v in client.peers_by_phone.copy().items() - } - ) - - os.makedirs(client.workdir, exist_ok=True) - - with open(temporary, "w", encoding="utf-8") as f: - json.dump(data, f, indent=4) - - f.flush() - os.fsync(f.fileno()) + client.session_storage.save_session(client.session_name, sync=True) except Exception as e: log.critical(e, exc_info=True) else: - shutil.move(temporary, persistent) log.info("Synced {}".format(client.session_name)) finally: - try: - os.remove(temporary) - except OSError: - pass + client.session_storage.sync_cleanup(client.session_name) diff --git a/pyrogram/client/session_storage/__init__.py b/pyrogram/client/session_storage/__init__.py new file mode 100644 index 00000000..6ee92ebc --- /dev/null +++ b/pyrogram/client/session_storage/__init__.py @@ -0,0 +1,21 @@ +# Pyrogram - Telegram MTProto API Client Library for Python +# Copyright (C) 2017-2019 Dan Tès +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Pyrogram. If not, see . + +from .session_storage_mixin import SessionStorageMixin +from .base_session_storage import BaseSessionStorage, SessionDoesNotExist +from .json_session_storage import JsonSessionStorage diff --git a/pyrogram/client/session_storage/base_session_storage.py b/pyrogram/client/session_storage/base_session_storage.py new file mode 100644 index 00000000..75e416b4 --- /dev/null +++ b/pyrogram/client/session_storage/base_session_storage.py @@ -0,0 +1,50 @@ +# Pyrogram - Telegram MTProto API Client Library for Python +# Copyright (C) 2017-2019 Dan Tès +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Pyrogram. If not, see . + +import abc + +import pyrogram + + +class SessionDoesNotExist(Exception): + pass + + +class BaseSessionStorage(abc.ABC): + def __init__(self, client: 'pyrogram.client.BaseClient'): + self.client = client + self.dc_id = 1 + self.test_mode = None + self.auth_key = None + self.user_id = None + self.date = 0 + self.peers_by_id = {} + self.peers_by_username = {} + self.peers_by_phone = {} + + @abc.abstractmethod + def load_session(self, name: str): + ... + + @abc.abstractmethod + def save_session(self, name: str, sync=False): + ... + + @abc.abstractmethod + def sync_cleanup(self, name: str): + ... diff --git a/pyrogram/client/session_storage/json_session_storage.py b/pyrogram/client/session_storage/json_session_storage.py new file mode 100644 index 00000000..679a21f3 --- /dev/null +++ b/pyrogram/client/session_storage/json_session_storage.py @@ -0,0 +1,116 @@ +# Pyrogram - Telegram MTProto API Client Library for Python +# Copyright (C) 2017-2019 Dan Tès +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Pyrogram. If not, see . + +import base64 +import json +import logging +import os +import shutil + +from ..ext import utils +from . import BaseSessionStorage, SessionDoesNotExist + + +log = logging.getLogger(__name__) + + +class JsonSessionStorage(BaseSessionStorage): + def _get_file_name(self, name: str): + if not name.endswith('.session'): + name += '.session' + return os.path.join(self.client.workdir, name) + + def load_session(self, name: str): + file_path = self._get_file_name(name) + log.info('Loading JSON session from {}'.format(file_path)) + + try: + with open(file_path, encoding='utf-8') as f: + s = json.load(f) + except FileNotFoundError: + raise SessionDoesNotExist() + + self.dc_id = s["dc_id"] + self.test_mode = s["test_mode"] + self.auth_key = base64.b64decode("".join(s["auth_key"])) # join split key + self.user_id = s["user_id"] + self.date = s.get("date", 0) + + for k, v in s.get("peers_by_id", {}).items(): + self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v) + + for k, v in s.get("peers_by_username", {}).items(): + peer = self.peers_by_id.get(v, None) + + if peer: + self.peers_by_username[k] = peer + + for k, v in s.get("peers_by_phone", {}).items(): + peer = self.peers_by_id.get(v, None) + + if peer: + self.peers_by_phone[k] = peer + + def save_session(self, name: str, sync=False): + file_path = self._get_file_name(name) + + if sync: + file_path += '.tmp' + + log.info('Saving JSON session to {}, sync={}'.format(file_path, sync)) + + auth_key = base64.b64encode(self.auth_key).decode() + auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars + + os.makedirs(self.client.workdir, exist_ok=True) + + data = { + 'dc_id': self.dc_id, + 'test_mode': self.test_mode, + 'auth_key': auth_key, + 'user_id': self.user_id, + 'date': self.date, + 'peers_by_id': { + k: getattr(v, "access_hash", None) + for k, v in self.peers_by_id.copy().items() + }, + 'peers_by_username': { + k: utils.get_peer_id(v) + for k, v in self.peers_by_username.copy().items() + }, + 'peers_by_phone': { + k: utils.get_peer_id(v) + for k, v in self.peers_by_phone.copy().items() + } + } + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=4) + + f.flush() + os.fsync(f.fileno()) + + # execution won't be here if an error has occurred earlier + if sync: + shutil.move(file_path, self._get_file_name(name)) + + def sync_cleanup(self, name: str): + try: + os.remove(self._get_file_name(name) + '.tmp') + except OSError: + pass diff --git a/pyrogram/client/session_storage/session_storage_mixin.py b/pyrogram/client/session_storage/session_storage_mixin.py new file mode 100644 index 00000000..bfe9a590 --- /dev/null +++ b/pyrogram/client/session_storage/session_storage_mixin.py @@ -0,0 +1,73 @@ +# Pyrogram - Telegram MTProto API Client Library for Python +# Copyright (C) 2017-2019 Dan Tès +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Pyrogram. If not, see . + +from typing import Dict + + +class SessionStorageMixin: + @property + def dc_id(self) -> int: + return self.session_storage.dc_id + + @dc_id.setter + def dc_id(self, val): + self.session_storage.dc_id = val + + @property + def test_mode(self) -> bool: + return self.session_storage.test_mode + + @test_mode.setter + def test_mode(self, val): + self.session_storage.test_mode = val + + @property + def auth_key(self) -> bytes: + return self.session_storage.auth_key + + @auth_key.setter + def auth_key(self, val): + self.session_storage.auth_key = val + + @property + def user_id(self): + return self.session_storage.user_id + + @user_id.setter + def user_id(self, val) -> int: + self.session_storage.user_id = val + + @property + def date(self) -> int: + return self.session_storage.date + + @date.setter + def date(self, val): + self.session_storage.date = val + + @property + def peers_by_id(self) -> Dict[str, int]: + return self.session_storage.peers_by_id + + @property + def peers_by_username(self) -> Dict[str, int]: + return self.session_storage.peers_by_username + + @property + def peers_by_phone(self) -> Dict[str, int]: + return self.session_storage.peers_by_phone