2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

Implement extendable session storage and JSON session storage

This commit is contained in:
bakatrouble 2019-02-21 20:12:11 +03:00
parent 567e9611df
commit 9d32b28f94
7 changed files with 278 additions and 96 deletions

View File

@ -36,7 +36,7 @@ from importlib import import_module
from pathlib import Path
from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Thread
from typing import Union, List
from typing import Union, List, Type
from pyrogram.api import functions, types
from pyrogram.api.core import Object
@ -56,6 +56,7 @@ from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher
from .ext import utils, Syncer, BaseClient
from .methods import Methods
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
log = logging.getLogger(__name__)
@ -199,8 +200,9 @@ class Client(Methods, BaseClient):
config_file: str = BaseClient.CONFIG_FILE,
plugins: dict = None,
no_updates: bool = None,
takeout: bool = None):
super().__init__()
takeout: bool = None,
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
super().__init__(session_storage_cls(self))
self.session_name = session_name
self.api_id = int(api_id) if api_id else None
@ -296,8 +298,8 @@ class Client(Methods, BaseClient):
now = time.time()
if abs(now - self.date) > Client.OFFLINE_SLEEP:
self.peers_by_username = {}
self.peers_by_phone = {}
self.peers_by_username.clear()
self.peers_by_phone.clear()
self.get_initial_dialogs()
self.get_contacts()
@ -1101,33 +1103,10 @@ class Client(Methods, BaseClient):
def load_session(self):
try:
with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f:
s = json.load(f)
except FileNotFoundError:
self.dc_id = 1
self.date = 0
self.session_storage.load_session(self.session_name)
except SessionDoesNotExist:
log.info('Session {} was not found, initializing new one')
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
else:
self.dc_id = s["dc_id"]
self.test_mode = s["test_mode"]
self.auth_key = base64.b64decode("".join(s["auth_key"]))
self.user_id = s["user_id"]
self.date = s.get("date", 0)
for k, v in s.get("peers_by_id", {}).items():
self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
for k, v in s.get("peers_by_username", {}).items():
peer = self.peers_by_id.get(v, None)
if peer:
self.peers_by_username[k] = peer
for k, v in s.get("peers_by_phone", {}).items():
peer = self.peers_by_id.get(v, None)
if peer:
self.peers_by_phone[k] = peer
def load_plugins(self):
if self.plugins.get("enabled", False):
@ -1234,23 +1213,7 @@ class Client(Methods, BaseClient):
log.warning('No plugin loaded from "{}"'.format(root))
def save_session(self):
auth_key = base64.b64encode(self.auth_key).decode()
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)]
os.makedirs(self.workdir, exist_ok=True)
with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), "w", encoding="utf-8") as f:
json.dump(
dict(
dc_id=self.dc_id,
test_mode=self.test_mode,
auth_key=auth_key,
user_id=self.user_id,
date=self.date
),
f,
indent=4
)
self.session_storage.save_session(self.session_name)
def get_initial_dialogs_chunk(self,
offset_date: int = 0):

View File

@ -24,9 +24,10 @@ from threading import Lock
from pyrogram import __version__
from ..style import Markdown, HTML
from ...session.internals import MsgId
from ..session_storage import SessionStorageMixin, BaseSessionStorage
class BaseClient:
class BaseClient(SessionStorageMixin):
class StopTransmission(StopIteration):
pass
@ -67,20 +68,13 @@ class BaseClient:
13: "video_note"
}
def __init__(self):
def __init__(self, session_storage: BaseSessionStorage):
self.session_storage = session_storage
self.bot_token = None
self.dc_id = None
self.auth_key = None
self.user_id = None
self.date = None
self.rnd_id = MsgId
self.channels_pts = {}
self.peers_by_id = {}
self.peers_by_username = {}
self.peers_by_phone = {}
self.markdown = Markdown(self.peers_by_id)
self.html = HTML(self.peers_by_id)

View File

@ -81,47 +81,12 @@ class Syncer:
@classmethod
def sync(cls, client):
temporary = os.path.join(client.workdir, "{}.sync".format(client.session_name))
persistent = os.path.join(client.workdir, "{}.session".format(client.session_name))
client.date = int(time.time())
try:
auth_key = base64.b64encode(client.auth_key).decode()
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)]
data = dict(
dc_id=client.dc_id,
test_mode=client.test_mode,
auth_key=auth_key,
user_id=client.user_id,
date=int(time.time()),
peers_by_id={
k: getattr(v, "access_hash", None)
for k, v in client.peers_by_id.copy().items()
},
peers_by_username={
k: utils.get_peer_id(v)
for k, v in client.peers_by_username.copy().items()
},
peers_by_phone={
k: utils.get_peer_id(v)
for k, v in client.peers_by_phone.copy().items()
}
)
os.makedirs(client.workdir, exist_ok=True)
with open(temporary, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
f.flush()
os.fsync(f.fileno())
client.session_storage.save_session(client.session_name, sync=True)
except Exception as e:
log.critical(e, exc_info=True)
else:
shutil.move(temporary, persistent)
log.info("Synced {}".format(client.session_name))
finally:
try:
os.remove(temporary)
except OSError:
pass
client.session_storage.sync_cleanup(client.session_name)

View File

@ -0,0 +1,21 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .session_storage_mixin import SessionStorageMixin
from .base_session_storage import BaseSessionStorage, SessionDoesNotExist
from .json_session_storage import JsonSessionStorage

View File

@ -0,0 +1,50 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import abc
import pyrogram
class SessionDoesNotExist(Exception):
pass
class BaseSessionStorage(abc.ABC):
def __init__(self, client: 'pyrogram.client.BaseClient'):
self.client = client
self.dc_id = 1
self.test_mode = None
self.auth_key = None
self.user_id = None
self.date = 0
self.peers_by_id = {}
self.peers_by_username = {}
self.peers_by_phone = {}
@abc.abstractmethod
def load_session(self, name: str):
...
@abc.abstractmethod
def save_session(self, name: str, sync=False):
...
@abc.abstractmethod
def sync_cleanup(self, name: str):
...

View File

@ -0,0 +1,116 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import base64
import json
import logging
import os
import shutil
from ..ext import utils
from . import BaseSessionStorage, SessionDoesNotExist
log = logging.getLogger(__name__)
class JsonSessionStorage(BaseSessionStorage):
def _get_file_name(self, name: str):
if not name.endswith('.session'):
name += '.session'
return os.path.join(self.client.workdir, name)
def load_session(self, name: str):
file_path = self._get_file_name(name)
log.info('Loading JSON session from {}'.format(file_path))
try:
with open(file_path, encoding='utf-8') as f:
s = json.load(f)
except FileNotFoundError:
raise SessionDoesNotExist()
self.dc_id = s["dc_id"]
self.test_mode = s["test_mode"]
self.auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
self.user_id = s["user_id"]
self.date = s.get("date", 0)
for k, v in s.get("peers_by_id", {}).items():
self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
for k, v in s.get("peers_by_username", {}).items():
peer = self.peers_by_id.get(v, None)
if peer:
self.peers_by_username[k] = peer
for k, v in s.get("peers_by_phone", {}).items():
peer = self.peers_by_id.get(v, None)
if peer:
self.peers_by_phone[k] = peer
def save_session(self, name: str, sync=False):
file_path = self._get_file_name(name)
if sync:
file_path += '.tmp'
log.info('Saving JSON session to {}, sync={}'.format(file_path, sync))
auth_key = base64.b64encode(self.auth_key).decode()
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars
os.makedirs(self.client.workdir, exist_ok=True)
data = {
'dc_id': self.dc_id,
'test_mode': self.test_mode,
'auth_key': auth_key,
'user_id': self.user_id,
'date': self.date,
'peers_by_id': {
k: getattr(v, "access_hash", None)
for k, v in self.peers_by_id.copy().items()
},
'peers_by_username': {
k: utils.get_peer_id(v)
for k, v in self.peers_by_username.copy().items()
},
'peers_by_phone': {
k: utils.get_peer_id(v)
for k, v in self.peers_by_phone.copy().items()
}
}
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4)
f.flush()
os.fsync(f.fileno())
# execution won't be here if an error has occurred earlier
if sync:
shutil.move(file_path, self._get_file_name(name))
def sync_cleanup(self, name: str):
try:
os.remove(self._get_file_name(name) + '.tmp')
except OSError:
pass

View File

@ -0,0 +1,73 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Dict
class SessionStorageMixin:
@property
def dc_id(self) -> int:
return self.session_storage.dc_id
@dc_id.setter
def dc_id(self, val):
self.session_storage.dc_id = val
@property
def test_mode(self) -> bool:
return self.session_storage.test_mode
@test_mode.setter
def test_mode(self, val):
self.session_storage.test_mode = val
@property
def auth_key(self) -> bytes:
return self.session_storage.auth_key
@auth_key.setter
def auth_key(self, val):
self.session_storage.auth_key = val
@property
def user_id(self):
return self.session_storage.user_id
@user_id.setter
def user_id(self, val) -> int:
self.session_storage.user_id = val
@property
def date(self) -> int:
return self.session_storage.date
@date.setter
def date(self, val):
self.session_storage.date = val
@property
def peers_by_id(self) -> Dict[str, int]:
return self.session_storage.peers_by_id
@property
def peers_by_username(self) -> Dict[str, int]:
return self.session_storage.peers_by_username
@property
def peers_by_phone(self) -> Dict[str, int]:
return self.session_storage.peers_by_phone