mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-30 13:57:54 +00:00
Implement extendable session storage and JSON session storage
This commit is contained in:
@@ -36,7 +36,7 @@ from importlib import import_module
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Union, List
|
from typing import Union, List, Type
|
||||||
|
|
||||||
from pyrogram.api import functions, types
|
from pyrogram.api import functions, types
|
||||||
from pyrogram.api.core import Object
|
from pyrogram.api.core import Object
|
||||||
@@ -56,6 +56,7 @@ from pyrogram.session import Auth, Session
|
|||||||
from .dispatcher import Dispatcher
|
from .dispatcher import Dispatcher
|
||||||
from .ext import utils, Syncer, BaseClient
|
from .ext import utils, Syncer, BaseClient
|
||||||
from .methods import Methods
|
from .methods import Methods
|
||||||
|
from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -199,8 +200,9 @@ class Client(Methods, BaseClient):
|
|||||||
config_file: str = BaseClient.CONFIG_FILE,
|
config_file: str = BaseClient.CONFIG_FILE,
|
||||||
plugins: dict = None,
|
plugins: dict = None,
|
||||||
no_updates: bool = None,
|
no_updates: bool = None,
|
||||||
takeout: bool = None):
|
takeout: bool = None,
|
||||||
super().__init__()
|
session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
|
||||||
|
super().__init__(session_storage_cls(self))
|
||||||
|
|
||||||
self.session_name = session_name
|
self.session_name = session_name
|
||||||
self.api_id = int(api_id) if api_id else None
|
self.api_id = int(api_id) if api_id else None
|
||||||
@@ -296,8 +298,8 @@ class Client(Methods, BaseClient):
|
|||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
if abs(now - self.date) > Client.OFFLINE_SLEEP:
|
if abs(now - self.date) > Client.OFFLINE_SLEEP:
|
||||||
self.peers_by_username = {}
|
self.peers_by_username.clear()
|
||||||
self.peers_by_phone = {}
|
self.peers_by_phone.clear()
|
||||||
|
|
||||||
self.get_initial_dialogs()
|
self.get_initial_dialogs()
|
||||||
self.get_contacts()
|
self.get_contacts()
|
||||||
@@ -1101,33 +1103,10 @@ class Client(Methods, BaseClient):
|
|||||||
|
|
||||||
def load_session(self):
|
def load_session(self):
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f:
|
self.session_storage.load_session(self.session_name)
|
||||||
s = json.load(f)
|
except SessionDoesNotExist:
|
||||||
except FileNotFoundError:
|
log.info('Session {} was not found, initializing new one')
|
||||||
self.dc_id = 1
|
|
||||||
self.date = 0
|
|
||||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||||
else:
|
|
||||||
self.dc_id = s["dc_id"]
|
|
||||||
self.test_mode = s["test_mode"]
|
|
||||||
self.auth_key = base64.b64decode("".join(s["auth_key"]))
|
|
||||||
self.user_id = s["user_id"]
|
|
||||||
self.date = s.get("date", 0)
|
|
||||||
|
|
||||||
for k, v in s.get("peers_by_id", {}).items():
|
|
||||||
self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
|
||||||
|
|
||||||
for k, v in s.get("peers_by_username", {}).items():
|
|
||||||
peer = self.peers_by_id.get(v, None)
|
|
||||||
|
|
||||||
if peer:
|
|
||||||
self.peers_by_username[k] = peer
|
|
||||||
|
|
||||||
for k, v in s.get("peers_by_phone", {}).items():
|
|
||||||
peer = self.peers_by_id.get(v, None)
|
|
||||||
|
|
||||||
if peer:
|
|
||||||
self.peers_by_phone[k] = peer
|
|
||||||
|
|
||||||
def load_plugins(self):
|
def load_plugins(self):
|
||||||
if self.plugins.get("enabled", False):
|
if self.plugins.get("enabled", False):
|
||||||
@@ -1234,23 +1213,7 @@ class Client(Methods, BaseClient):
|
|||||||
log.warning('No plugin loaded from "{}"'.format(root))
|
log.warning('No plugin loaded from "{}"'.format(root))
|
||||||
|
|
||||||
def save_session(self):
|
def save_session(self):
|
||||||
auth_key = base64.b64encode(self.auth_key).decode()
|
self.session_storage.save_session(self.session_name)
|
||||||
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)]
|
|
||||||
|
|
||||||
os.makedirs(self.workdir, exist_ok=True)
|
|
||||||
|
|
||||||
with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), "w", encoding="utf-8") as f:
|
|
||||||
json.dump(
|
|
||||||
dict(
|
|
||||||
dc_id=self.dc_id,
|
|
||||||
test_mode=self.test_mode,
|
|
||||||
auth_key=auth_key,
|
|
||||||
user_id=self.user_id,
|
|
||||||
date=self.date
|
|
||||||
),
|
|
||||||
f,
|
|
||||||
indent=4
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_initial_dialogs_chunk(self,
|
def get_initial_dialogs_chunk(self,
|
||||||
offset_date: int = 0):
|
offset_date: int = 0):
|
||||||
|
@@ -24,9 +24,10 @@ from threading import Lock
|
|||||||
from pyrogram import __version__
|
from pyrogram import __version__
|
||||||
from ..style import Markdown, HTML
|
from ..style import Markdown, HTML
|
||||||
from ...session.internals import MsgId
|
from ...session.internals import MsgId
|
||||||
|
from ..session_storage import SessionStorageMixin, BaseSessionStorage
|
||||||
|
|
||||||
|
|
||||||
class BaseClient:
|
class BaseClient(SessionStorageMixin):
|
||||||
class StopTransmission(StopIteration):
|
class StopTransmission(StopIteration):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -67,20 +68,13 @@ class BaseClient:
|
|||||||
13: "video_note"
|
13: "video_note"
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, session_storage: BaseSessionStorage):
|
||||||
|
self.session_storage = session_storage
|
||||||
self.bot_token = None
|
self.bot_token = None
|
||||||
self.dc_id = None
|
|
||||||
self.auth_key = None
|
|
||||||
self.user_id = None
|
|
||||||
self.date = None
|
|
||||||
|
|
||||||
self.rnd_id = MsgId
|
self.rnd_id = MsgId
|
||||||
self.channels_pts = {}
|
self.channels_pts = {}
|
||||||
|
|
||||||
self.peers_by_id = {}
|
|
||||||
self.peers_by_username = {}
|
|
||||||
self.peers_by_phone = {}
|
|
||||||
|
|
||||||
self.markdown = Markdown(self.peers_by_id)
|
self.markdown = Markdown(self.peers_by_id)
|
||||||
self.html = HTML(self.peers_by_id)
|
self.html = HTML(self.peers_by_id)
|
||||||
|
|
||||||
|
@@ -81,47 +81,12 @@ class Syncer:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync(cls, client):
|
def sync(cls, client):
|
||||||
temporary = os.path.join(client.workdir, "{}.sync".format(client.session_name))
|
client.date = int(time.time())
|
||||||
persistent = os.path.join(client.workdir, "{}.session".format(client.session_name))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_key = base64.b64encode(client.auth_key).decode()
|
client.session_storage.save_session(client.session_name, sync=True)
|
||||||
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)]
|
|
||||||
|
|
||||||
data = dict(
|
|
||||||
dc_id=client.dc_id,
|
|
||||||
test_mode=client.test_mode,
|
|
||||||
auth_key=auth_key,
|
|
||||||
user_id=client.user_id,
|
|
||||||
date=int(time.time()),
|
|
||||||
peers_by_id={
|
|
||||||
k: getattr(v, "access_hash", None)
|
|
||||||
for k, v in client.peers_by_id.copy().items()
|
|
||||||
},
|
|
||||||
peers_by_username={
|
|
||||||
k: utils.get_peer_id(v)
|
|
||||||
for k, v in client.peers_by_username.copy().items()
|
|
||||||
},
|
|
||||||
peers_by_phone={
|
|
||||||
k: utils.get_peer_id(v)
|
|
||||||
for k, v in client.peers_by_phone.copy().items()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
os.makedirs(client.workdir, exist_ok=True)
|
|
||||||
|
|
||||||
with open(temporary, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(data, f, indent=4)
|
|
||||||
|
|
||||||
f.flush()
|
|
||||||
os.fsync(f.fileno())
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.critical(e, exc_info=True)
|
log.critical(e, exc_info=True)
|
||||||
else:
|
else:
|
||||||
shutil.move(temporary, persistent)
|
|
||||||
log.info("Synced {}".format(client.session_name))
|
log.info("Synced {}".format(client.session_name))
|
||||||
finally:
|
finally:
|
||||||
try:
|
client.session_storage.sync_cleanup(client.session_name)
|
||||||
os.remove(temporary)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
21
pyrogram/client/session_storage/__init__.py
Normal file
21
pyrogram/client/session_storage/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Pyrogram - Telegram MTProto API Client Library for Python
|
||||||
|
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
|
||||||
|
#
|
||||||
|
# This file is part of Pyrogram.
|
||||||
|
#
|
||||||
|
# Pyrogram is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Lesser General Public License as published
|
||||||
|
# by the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# Pyrogram is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Lesser General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from .session_storage_mixin import SessionStorageMixin
|
||||||
|
from .base_session_storage import BaseSessionStorage, SessionDoesNotExist
|
||||||
|
from .json_session_storage import JsonSessionStorage
|
50
pyrogram/client/session_storage/base_session_storage.py
Normal file
50
pyrogram/client/session_storage/base_session_storage.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Pyrogram - Telegram MTProto API Client Library for Python
|
||||||
|
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
|
||||||
|
#
|
||||||
|
# This file is part of Pyrogram.
|
||||||
|
#
|
||||||
|
# Pyrogram is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Lesser General Public License as published
|
||||||
|
# by the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# Pyrogram is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Lesser General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
|
||||||
|
import pyrogram
|
||||||
|
|
||||||
|
|
||||||
|
class SessionDoesNotExist(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSessionStorage(abc.ABC):
|
||||||
|
def __init__(self, client: 'pyrogram.client.BaseClient'):
|
||||||
|
self.client = client
|
||||||
|
self.dc_id = 1
|
||||||
|
self.test_mode = None
|
||||||
|
self.auth_key = None
|
||||||
|
self.user_id = None
|
||||||
|
self.date = 0
|
||||||
|
self.peers_by_id = {}
|
||||||
|
self.peers_by_username = {}
|
||||||
|
self.peers_by_phone = {}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_session(self, name: str):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def save_session(self, name: str, sync=False):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def sync_cleanup(self, name: str):
|
||||||
|
...
|
116
pyrogram/client/session_storage/json_session_storage.py
Normal file
116
pyrogram/client/session_storage/json_session_storage.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# Pyrogram - Telegram MTProto API Client Library for Python
|
||||||
|
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
|
||||||
|
#
|
||||||
|
# This file is part of Pyrogram.
|
||||||
|
#
|
||||||
|
# Pyrogram is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Lesser General Public License as published
|
||||||
|
# by the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# Pyrogram is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Lesser General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from ..ext import utils
|
||||||
|
from . import BaseSessionStorage, SessionDoesNotExist
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonSessionStorage(BaseSessionStorage):
|
||||||
|
def _get_file_name(self, name: str):
|
||||||
|
if not name.endswith('.session'):
|
||||||
|
name += '.session'
|
||||||
|
return os.path.join(self.client.workdir, name)
|
||||||
|
|
||||||
|
def load_session(self, name: str):
|
||||||
|
file_path = self._get_file_name(name)
|
||||||
|
log.info('Loading JSON session from {}'.format(file_path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, encoding='utf-8') as f:
|
||||||
|
s = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise SessionDoesNotExist()
|
||||||
|
|
||||||
|
self.dc_id = s["dc_id"]
|
||||||
|
self.test_mode = s["test_mode"]
|
||||||
|
self.auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
|
||||||
|
self.user_id = s["user_id"]
|
||||||
|
self.date = s.get("date", 0)
|
||||||
|
|
||||||
|
for k, v in s.get("peers_by_id", {}).items():
|
||||||
|
self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
||||||
|
|
||||||
|
for k, v in s.get("peers_by_username", {}).items():
|
||||||
|
peer = self.peers_by_id.get(v, None)
|
||||||
|
|
||||||
|
if peer:
|
||||||
|
self.peers_by_username[k] = peer
|
||||||
|
|
||||||
|
for k, v in s.get("peers_by_phone", {}).items():
|
||||||
|
peer = self.peers_by_id.get(v, None)
|
||||||
|
|
||||||
|
if peer:
|
||||||
|
self.peers_by_phone[k] = peer
|
||||||
|
|
||||||
|
def save_session(self, name: str, sync=False):
|
||||||
|
file_path = self._get_file_name(name)
|
||||||
|
|
||||||
|
if sync:
|
||||||
|
file_path += '.tmp'
|
||||||
|
|
||||||
|
log.info('Saving JSON session to {}, sync={}'.format(file_path, sync))
|
||||||
|
|
||||||
|
auth_key = base64.b64encode(self.auth_key).decode()
|
||||||
|
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars
|
||||||
|
|
||||||
|
os.makedirs(self.client.workdir, exist_ok=True)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'dc_id': self.dc_id,
|
||||||
|
'test_mode': self.test_mode,
|
||||||
|
'auth_key': auth_key,
|
||||||
|
'user_id': self.user_id,
|
||||||
|
'date': self.date,
|
||||||
|
'peers_by_id': {
|
||||||
|
k: getattr(v, "access_hash", None)
|
||||||
|
for k, v in self.peers_by_id.copy().items()
|
||||||
|
},
|
||||||
|
'peers_by_username': {
|
||||||
|
k: utils.get_peer_id(v)
|
||||||
|
for k, v in self.peers_by_username.copy().items()
|
||||||
|
},
|
||||||
|
'peers_by_phone': {
|
||||||
|
k: utils.get_peer_id(v)
|
||||||
|
for k, v in self.peers_by_phone.copy().items()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=4)
|
||||||
|
|
||||||
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
|
|
||||||
|
# execution won't be here if an error has occurred earlier
|
||||||
|
if sync:
|
||||||
|
shutil.move(file_path, self._get_file_name(name))
|
||||||
|
|
||||||
|
def sync_cleanup(self, name: str):
|
||||||
|
try:
|
||||||
|
os.remove(self._get_file_name(name) + '.tmp')
|
||||||
|
except OSError:
|
||||||
|
pass
|
73
pyrogram/client/session_storage/session_storage_mixin.py
Normal file
73
pyrogram/client/session_storage/session_storage_mixin.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# Pyrogram - Telegram MTProto API Client Library for Python
|
||||||
|
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
|
||||||
|
#
|
||||||
|
# This file is part of Pyrogram.
|
||||||
|
#
|
||||||
|
# Pyrogram is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Lesser General Public License as published
|
||||||
|
# by the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# Pyrogram is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Lesser General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStorageMixin:
|
||||||
|
@property
|
||||||
|
def dc_id(self) -> int:
|
||||||
|
return self.session_storage.dc_id
|
||||||
|
|
||||||
|
@dc_id.setter
|
||||||
|
def dc_id(self, val):
|
||||||
|
self.session_storage.dc_id = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_mode(self) -> bool:
|
||||||
|
return self.session_storage.test_mode
|
||||||
|
|
||||||
|
@test_mode.setter
|
||||||
|
def test_mode(self, val):
|
||||||
|
self.session_storage.test_mode = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_key(self) -> bytes:
|
||||||
|
return self.session_storage.auth_key
|
||||||
|
|
||||||
|
@auth_key.setter
|
||||||
|
def auth_key(self, val):
|
||||||
|
self.session_storage.auth_key = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_id(self):
|
||||||
|
return self.session_storage.user_id
|
||||||
|
|
||||||
|
@user_id.setter
|
||||||
|
def user_id(self, val) -> int:
|
||||||
|
self.session_storage.user_id = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def date(self) -> int:
|
||||||
|
return self.session_storage.date
|
||||||
|
|
||||||
|
@date.setter
|
||||||
|
def date(self, val):
|
||||||
|
self.session_storage.date = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def peers_by_id(self) -> Dict[str, int]:
|
||||||
|
return self.session_storage.peers_by_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def peers_by_username(self) -> Dict[str, int]:
|
||||||
|
return self.session_storage.peers_by_username
|
||||||
|
|
||||||
|
@property
|
||||||
|
def peers_by_phone(self) -> Dict[str, int]:
|
||||||
|
return self.session_storage.peers_by_phone
|
Reference in New Issue
Block a user