2
0
mirror of https://github.com/pyrogram/pyrogram synced 2025-08-28 21:07:59 +00:00

add in-memory session storage, refactor session storages, remove mixin

This commit is contained in:
bakatrouble 2019-02-22 03:37:19 +03:00
parent 9c4e9e166e
commit 5dc33c6337
11 changed files with 267 additions and 188 deletions

View File

@ -50,15 +50,15 @@ from pyrogram.api.errors import (
from pyrogram.client.handlers import DisconnectHandler from pyrogram.client.handlers import DisconnectHandler
from pyrogram.client.handlers.handler import Handler from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check from pyrogram.client.methods.password.utils import compute_check
from pyrogram.client.session_storage import BaseSessionConfig
from pyrogram.crypto import AES from pyrogram.crypto import AES
from pyrogram.session import Auth, Session from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .ext import utils, Syncer, BaseClient from .ext import utils, Syncer, BaseClient
from .methods import Methods from .methods import Methods
from .session_storage import SessionDoesNotExist from .session_storage import (
from .session_storage.json_session_storage import JsonSessionStorage SessionDoesNotExist, SessionStorage, MemorySessionStorage, JsonSessionStorage,
from .session_storage.string_session_storage import StringSessionStorage StringSessionStorage
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -183,7 +183,7 @@ class Client(Methods, BaseClient):
""" """
def __init__(self, def __init__(self,
session_name: Union[str, BaseSessionConfig], session_name: Union[str, SessionStorage],
api_id: Union[int, str] = None, api_id: Union[int, str] = None,
api_hash: str = None, api_hash: str = None,
app_version: str = None, app_version: str = None,
@ -209,14 +209,16 @@ class Client(Methods, BaseClient):
takeout: bool = None): takeout: bool = None):
if isinstance(session_name, str): if isinstance(session_name, str):
if session_name.startswith(':'): if session_name == ':memory:':
session_storage = MemorySessionStorage(self)
elif session_name.startswith(':'):
session_storage = StringSessionStorage(self, session_name) session_storage = StringSessionStorage(self, session_name)
else: else:
session_storage = JsonSessionStorage(self, session_name) session_storage = JsonSessionStorage(self, session_name)
elif isinstance(session_name, BaseSessionConfig): elif isinstance(session_name, SessionStorage):
session_storage = session_name.session_storage_cls(self, session_name) session_storage = session_name
else: else:
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass') raise RuntimeError('Wrong session_name passed, expected str or SessionConfig subclass')
super().__init__(session_storage) super().__init__(session_storage)
@ -230,7 +232,7 @@ class Client(Methods, BaseClient):
self.ipv6 = ipv6 self.ipv6 = ipv6
# TODO: Make code consistent, use underscore for private/protected fields # TODO: Make code consistent, use underscore for private/protected fields
self._proxy = proxy self._proxy = proxy
self.test_mode = test_mode self.session_storage.test_mode = test_mode
self.phone_number = phone_number self.phone_number = phone_number
self.phone_code = phone_code self.phone_code = phone_code
self.password = password self.password = password
@ -282,10 +284,10 @@ class Client(Methods, BaseClient):
raise ConnectionError("Client has already been started") raise ConnectionError("Client has already been started")
if isinstance(self.session_storage, JsonSessionStorage): if isinstance(self.session_storage, JsonSessionStorage):
if self.BOT_TOKEN_RE.match(self.session_storage.session_data): if self.BOT_TOKEN_RE.match(self.session_storage._session_name):
self.is_bot = True self.session_storage.is_bot = True
self.bot_token = self.session_storage.session_data self.bot_token = self.session_storage._session_name
self.session_storage.session_data = self.session_storage.session_data.split(":")[0] self.session_storage._session_name = self.session_storage._session_name.split(":")[0]
warnings.warn('\nYou are using a bot token as session name.\n' warnings.warn('\nYou are using a bot token as session name.\n'
'It will be deprecated in next update, please use session file name to load ' 'It will be deprecated in next update, please use session file name to load '
'existing sessions and bot_token argument to create new sessions.', 'existing sessions and bot_token argument to create new sessions.',
@ -297,33 +299,33 @@ class Client(Methods, BaseClient):
self.session = Session( self.session = Session(
self, self,
self.dc_id, self.session_storage.dc_id,
self.auth_key self.session_storage.auth_key
) )
self.session.start() self.session.start()
self.is_started = True self.is_started = True
try: try:
if self.user_id is None: if self.session_storage.user_id is None:
if self.bot_token is None: if self.bot_token is None:
self.authorize_user() self.authorize_user()
else: else:
self.is_bot = True self.session_storage.is_bot = True
self.authorize_bot() self.authorize_bot()
self.save_session() self.save_session()
if not self.is_bot: if not self.session_storage.is_bot:
if self.takeout: if self.takeout:
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
log.warning("Takeout session {} initiated".format(self.takeout_id)) log.warning("Takeout session {} initiated".format(self.takeout_id))
now = time.time() now = time.time()
if abs(now - self.date) > Client.OFFLINE_SLEEP: if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
self.peers_by_username.clear() self.session_storage.peers_by_username.clear()
self.peers_by_phone.clear() self.session_storage.peers_by_phone.clear()
self.get_initial_dialogs() self.get_initial_dialogs()
self.get_contacts() self.get_contacts()
@ -512,19 +514,20 @@ class Client(Methods, BaseClient):
except UserMigrate as e: except UserMigrate as e:
self.session.stop() self.session.stop()
self.dc_id = e.x self.session_storage.dc_id = e.x
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create() self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
self.ipv6, self._proxy).create()
self.session = Session( self.session = Session(
self, self,
self.dc_id, self.session_storage.dc_id,
self.auth_key self.session_storage.auth_key
) )
self.session.start() self.session.start()
self.authorize_bot() self.authorize_bot()
else: else:
self.user_id = r.user.id self.session_storage.user_id = r.user.id
print("Logged in successfully as @{}".format(r.user.username)) print("Logged in successfully as @{}".format(r.user.username))
@ -564,19 +567,19 @@ class Client(Methods, BaseClient):
except (PhoneMigrate, NetworkMigrate) as e: except (PhoneMigrate, NetworkMigrate) as e:
self.session.stop() self.session.stop()
self.dc_id = e.x self.session_storage.dc_id = e.x
self.auth_key = Auth( self.session_storage.auth_key = Auth(
self.dc_id, self.session_storage.dc_id,
self.test_mode, self.session_storage.test_mode,
self.ipv6, self.ipv6,
self._proxy self._proxy
).create() ).create()
self.session = Session( self.session = Session(
self, self,
self.dc_id, self.session_storage.dc_id,
self.auth_key self.session_storage.auth_key
) )
self.session.start() self.session.start()
@ -752,7 +755,7 @@ class Client(Methods, BaseClient):
assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id)) assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id))
self.password = None self.password = None
self.user_id = r.user.id self.session_storage.user_id = r.user.id
print("Logged in successfully as {}".format(r.user.first_name)) print("Logged in successfully as {}".format(r.user.first_name))
@ -776,13 +779,13 @@ class Client(Methods, BaseClient):
access_hash=access_hash access_hash=access_hash
) )
self.peers_by_id[user_id] = input_peer self.session_storage.peers_by_id[user_id] = input_peer
if username is not None: if username is not None:
self.peers_by_username[username.lower()] = input_peer self.session_storage.peers_by_username[username.lower()] = input_peer
if phone is not None: if phone is not None:
self.peers_by_phone[phone] = input_peer self.session_storage.peers_by_phone[phone] = input_peer
if isinstance(entity, (types.Chat, types.ChatForbidden)): if isinstance(entity, (types.Chat, types.ChatForbidden)):
chat_id = entity.id chat_id = entity.id
@ -792,7 +795,7 @@ class Client(Methods, BaseClient):
chat_id=chat_id chat_id=chat_id
) )
self.peers_by_id[peer_id] = input_peer self.session_storage.peers_by_id[peer_id] = input_peer
if isinstance(entity, (types.Channel, types.ChannelForbidden)): if isinstance(entity, (types.Channel, types.ChannelForbidden)):
channel_id = entity.id channel_id = entity.id
@ -810,10 +813,10 @@ class Client(Methods, BaseClient):
access_hash=access_hash access_hash=access_hash
) )
self.peers_by_id[peer_id] = input_peer self.session_storage.peers_by_id[peer_id] = input_peer
if username is not None: if username is not None:
self.peers_by_username[username.lower()] = input_peer self.session_storage.peers_by_username[username.lower()] = input_peer
def download_worker(self): def download_worker(self):
name = threading.current_thread().name name = threading.current_thread().name
@ -1127,10 +1130,11 @@ class Client(Methods, BaseClient):
def load_session(self): def load_session(self):
try: try:
self.session_storage.load_session() self.session_storage.load()
except SessionDoesNotExist: except SessionDoesNotExist:
log.info('Could not load session "{}", initiate new one'.format(self.session_name)) log.info('Could not load session "{}", initiate new one'.format(self.session_name))
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create() self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
self.ipv6, self._proxy).create()
def load_plugins(self): def load_plugins(self):
if self.plugins.get("enabled", False): if self.plugins.get("enabled", False):
@ -1237,7 +1241,7 @@ class Client(Methods, BaseClient):
log.warning('No plugin loaded from "{}"'.format(root)) log.warning('No plugin loaded from "{}"'.format(root))
def save_session(self): def save_session(self):
self.session_storage.save_session() self.session_storage.save()
def get_initial_dialogs_chunk(self, def get_initial_dialogs_chunk(self,
offset_date: int = 0): offset_date: int = 0):
@ -1257,7 +1261,7 @@ class Client(Methods, BaseClient):
log.warning("get_dialogs flood: waiting {} seconds".format(e.x)) log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
log.info("Total peers: {}".format(len(self.peers_by_id))) log.info("Total peers: {}".format(len(self.session_storage.peers_by_id)))
return r return r
def get_initial_dialogs(self): def get_initial_dialogs(self):
@ -1293,7 +1297,7 @@ class Client(Methods, BaseClient):
``KeyError`` in case the peer doesn't exist in the internal database. ``KeyError`` in case the peer doesn't exist in the internal database.
""" """
try: try:
return self.peers_by_id[peer_id] return self.session_storage.peers_by_id[peer_id]
except KeyError: except KeyError:
if type(peer_id) is str: if type(peer_id) is str:
if peer_id in ("self", "me"): if peer_id in ("self", "me"):
@ -1304,17 +1308,17 @@ class Client(Methods, BaseClient):
try: try:
int(peer_id) int(peer_id)
except ValueError: except ValueError:
if peer_id not in self.peers_by_username: if peer_id not in self.session_storage.peers_by_username:
self.send( self.send(
functions.contacts.ResolveUsername( functions.contacts.ResolveUsername(
username=peer_id username=peer_id
) )
) )
return self.peers_by_username[peer_id] return self.session_storage.peers_by_username[peer_id]
else: else:
try: try:
return self.peers_by_phone[peer_id] return self.session_storage.peers_by_phone[peer_id]
except KeyError: except KeyError:
raise PeerIdInvalid raise PeerIdInvalid
@ -1341,7 +1345,7 @@ class Client(Methods, BaseClient):
) )
try: try:
return self.peers_by_id[peer_id] return self.session_storage.peers_by_id[peer_id]
except KeyError: except KeyError:
raise PeerIdInvalid raise PeerIdInvalid
@ -1411,7 +1415,7 @@ class Client(Methods, BaseClient):
file_id = file_id or self.rnd_id() file_id = file_id or self.rnd_id()
md5_sum = md5() if not is_big and not is_missing_part else None md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(self, self.dc_id, self.auth_key, is_media=True) session = Session(self, self.session_storage.dc_id, self.session_storage.auth_key, is_media=True)
session.start() session.start()
try: try:
@ -1492,7 +1496,7 @@ class Client(Methods, BaseClient):
session = self.media_sessions.get(dc_id, None) session = self.media_sessions.get(dc_id, None)
if session is None: if session is None:
if dc_id != self.dc_id: if dc_id != self.session_storage.dc_id:
exported_auth = self.send( exported_auth = self.send(
functions.auth.ExportAuthorization( functions.auth.ExportAuthorization(
dc_id=dc_id dc_id=dc_id
@ -1502,7 +1506,7 @@ class Client(Methods, BaseClient):
session = Session( session = Session(
self, self,
dc_id, dc_id,
Auth(dc_id, self.test_mode, self.ipv6, self._proxy).create(), Auth(dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
is_media=True is_media=True
) )
@ -1520,7 +1524,7 @@ class Client(Methods, BaseClient):
session = Session( session = Session(
self, self,
dc_id, dc_id,
self.auth_key, self.session_storage.auth_key,
is_media=True is_media=True
) )
@ -1588,7 +1592,7 @@ class Client(Methods, BaseClient):
cdn_session = Session( cdn_session = Session(
self, self,
r.dc_id, r.dc_id,
Auth(r.dc_id, self.test_mode, self.ipv6, self._proxy).create(), Auth(r.dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
is_media=True, is_media=True,
is_cdn=True is_cdn=True
) )

View File

@ -24,10 +24,10 @@ from threading import Lock
from pyrogram import __version__ from pyrogram import __version__
from ..style import Markdown, HTML from ..style import Markdown, HTML
from ...session.internals import MsgId from ...session.internals import MsgId
from ..session_storage import SessionStorageMixin, BaseSessionStorage from ..session_storage import SessionStorage
class BaseClient(SessionStorageMixin): class BaseClient:
class StopTransmission(StopIteration): class StopTransmission(StopIteration):
pass pass
@ -68,14 +68,14 @@ class BaseClient(SessionStorageMixin):
13: "video_note" 13: "video_note"
} }
def __init__(self, session_storage: BaseSessionStorage): def __init__(self, session_storage: SessionStorage):
self.session_storage = session_storage self.session_storage = session_storage
self.rnd_id = MsgId self.rnd_id = MsgId
self.channels_pts = {} self.channels_pts = {}
self.markdown = Markdown(self.peers_by_id) self.markdown = Markdown(self.session_storage.peers_by_id)
self.html = HTML(self.peers_by_id) self.html = HTML(self.session_storage.peers_by_id)
self.session = None self.session = None
self.media_sessions = {} self.media_sessions = {}

View File

@ -81,9 +81,9 @@ class Syncer:
@classmethod @classmethod
def sync(cls, client): def sync(cls, client):
client.date = int(time.time()) client.session_storage.date = int(time.time())
try: try:
client.session_storage.save_session(sync=True) client.session_storage.save(sync=True)
except Exception as e: except Exception as e:
log.critical(e, exc_info=True) log.critical(e, exc_info=True)
else: else:

View File

@ -44,5 +44,5 @@ class GetContacts(BaseClient):
log.warning("get_contacts flood: waiting {} seconds".format(e.x)) log.warning("get_contacts flood: waiting {} seconds".format(e.x))
time.sleep(e.x) time.sleep(e.x)
else: else:
log.info("Total contacts: {}".format(len(self.peers_by_phone))) log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone)))
return [pyrogram.User._parse(self, user) for user in contacts.users] return [pyrogram.User._parse(self, user) for user in contacts.users]

View File

@ -16,5 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .session_storage_mixin import SessionStorageMixin from .abstract import SessionStorage, SessionDoesNotExist
from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist from .memory import MemorySessionStorage
from .json import JsonSessionStorage
from .string import StringSessionStorage

View File

@ -16,66 +16,103 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from typing import Dict import abc
from typing import Type
import pyrogram
class SessionStorageMixin: class SessionDoesNotExist(Exception):
pass
class SessionStorage(abc.ABC):
def __init__(self, client: 'pyrogram.client.BaseClient'):
self._client = client
@abc.abstractmethod
def load(self):
...
@abc.abstractmethod
def save(self, sync=False):
...
@abc.abstractmethod
def sync_cleanup(self):
...
@property @property
def dc_id(self) -> int: @abc.abstractmethod
return self.session_storage.dc_id def dc_id(self):
...
@dc_id.setter @dc_id.setter
@abc.abstractmethod
def dc_id(self, val): def dc_id(self, val):
self.session_storage.dc_id = val ...
@property @property
def test_mode(self) -> bool: @abc.abstractmethod
return self.session_storage.test_mode def test_mode(self):
...
@test_mode.setter @test_mode.setter
@abc.abstractmethod
def test_mode(self, val): def test_mode(self, val):
self.session_storage.test_mode = val ...
@property @property
def auth_key(self) -> bytes: @abc.abstractmethod
return self.session_storage.auth_key def auth_key(self):
...
@auth_key.setter @auth_key.setter
@abc.abstractmethod
def auth_key(self, val): def auth_key(self, val):
self.session_storage.auth_key = val ...
@property @property
@abc.abstractmethod
def user_id(self): def user_id(self):
return self.session_storage.user_id ...
@user_id.setter @user_id.setter
def user_id(self, val) -> int: @abc.abstractmethod
self.session_storage.user_id = val def user_id(self, val):
...
@property @property
def date(self) -> int: @abc.abstractmethod
return self.session_storage.date def date(self):
...
@date.setter @date.setter
@abc.abstractmethod
def date(self, val): def date(self, val):
self.session_storage.date = val ...
@property @property
@abc.abstractmethod
def is_bot(self): def is_bot(self):
return self.session_storage.is_bot ...
@is_bot.setter @is_bot.setter
def is_bot(self, val) -> int: @abc.abstractmethod
self.session_storage.is_bot = val def is_bot(self, val):
...
@property @property
def peers_by_id(self) -> Dict[str, int]: @abc.abstractmethod
return self.session_storage.peers_by_id def peers_by_id(self):
...
@property @property
def peers_by_username(self) -> Dict[str, int]: @abc.abstractmethod
return self.session_storage.peers_by_username def peers_by_username(self):
...
@property @property
def peers_by_phone(self) -> Dict[str, int]: @abc.abstractmethod
return self.session_storage.peers_by_phone def peers_by_phone(self):
...

View File

@ -1,60 +0,0 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import abc
from typing import Type
import pyrogram
class SessionDoesNotExist(Exception):
pass
class BaseSessionStorage(abc.ABC):
def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
self.client = client
self.session_data = session_data
self.dc_id = 1
self.test_mode = None
self.auth_key = None
self.user_id = None
self.date = 0
self.is_bot = False
self.peers_by_id = {}
self.peers_by_username = {}
self.peers_by_phone = {}
@abc.abstractmethod
def load_session(self):
...
@abc.abstractmethod
def save_session(self, sync=False):
...
@abc.abstractmethod
def sync_cleanup(self):
...
class BaseSessionConfig(abc.ABC):
@property
@abc.abstractmethod
def session_storage_cls(self) -> Type[BaseSessionStorage]:
...

View File

@ -22,21 +22,26 @@ import logging
import os import os
import shutil import shutil
import pyrogram
from ..ext import utils from ..ext import utils
from . import BaseSessionStorage, SessionDoesNotExist from . import MemorySessionStorage, SessionDoesNotExist
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class JsonSessionStorage(BaseSessionStorage): class JsonSessionStorage(MemorySessionStorage):
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str):
super(JsonSessionStorage, self).__init__(client)
self._session_name = session_name
def _get_file_name(self, name: str): def _get_file_name(self, name: str):
if not name.endswith('.session'): if not name.endswith('.session'):
name += '.session' name += '.session'
return os.path.join(self.client.workdir, name) return os.path.join(self._client.workdir, name)
def load_session(self): def load(self):
file_path = self._get_file_name(self.session_data) file_path = self._get_file_name(self._session_name)
log.info('Loading JSON session from {}'.format(file_path)) log.info('Loading JSON session from {}'.format(file_path))
try: try:
@ -45,59 +50,59 @@ class JsonSessionStorage(BaseSessionStorage):
except FileNotFoundError: except FileNotFoundError:
raise SessionDoesNotExist() raise SessionDoesNotExist()
self.dc_id = s["dc_id"] self._dc_id = s["dc_id"]
self.test_mode = s["test_mode"] self._test_mode = s["test_mode"]
self.auth_key = base64.b64decode("".join(s["auth_key"])) # join split key self._auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
self.user_id = s["user_id"] self._user_id = s["user_id"]
self.date = s.get("date", 0) self._date = s.get("date", 0)
self.is_bot = s.get('is_bot', self.client.is_bot) self._is_bot = s.get('is_bot', self._is_bot)
for k, v in s.get("peers_by_id", {}).items(): for k, v in s.get("peers_by_id", {}).items():
self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v) self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
for k, v in s.get("peers_by_username", {}).items(): for k, v in s.get("peers_by_username", {}).items():
peer = self.peers_by_id.get(v, None) peer = self._peers_by_id.get(v, None)
if peer: if peer:
self.peers_by_username[k] = peer self._peers_by_username[k] = peer
for k, v in s.get("peers_by_phone", {}).items(): for k, v in s.get("peers_by_phone", {}).items():
peer = self.peers_by_id.get(v, None) peer = self._peers_by_id.get(v, None)
if peer: if peer:
self.peers_by_phone[k] = peer self._peers_by_phone[k] = peer
def save_session(self, sync=False): def save(self, sync=False):
file_path = self._get_file_name(self.session_data) file_path = self._get_file_name(self._session_name)
if sync: if sync:
file_path += '.tmp' file_path += '.tmp'
log.info('Saving JSON session to {}, sync={}'.format(file_path, sync)) log.info('Saving JSON session to {}, sync={}'.format(file_path, sync))
auth_key = base64.b64encode(self.auth_key).decode() auth_key = base64.b64encode(self._auth_key).decode()
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars
os.makedirs(self.client.workdir, exist_ok=True) os.makedirs(self._client.workdir, exist_ok=True)
data = { data = {
'dc_id': self.dc_id, 'dc_id': self._dc_id,
'test_mode': self.test_mode, 'test_mode': self._test_mode,
'auth_key': auth_key, 'auth_key': auth_key,
'user_id': self.user_id, 'user_id': self._user_id,
'date': self.date, 'date': self._date,
'is_bot': self.is_bot, 'is_bot': self._is_bot,
'peers_by_id': { 'peers_by_id': {
k: getattr(v, "access_hash", None) k: getattr(v, "access_hash", None)
for k, v in self.peers_by_id.copy().items() for k, v in self._peers_by_id.copy().items()
}, },
'peers_by_username': { 'peers_by_username': {
k: utils.get_peer_id(v) k: utils.get_peer_id(v)
for k, v in self.peers_by_username.copy().items() for k, v in self._peers_by_username.copy().items()
}, },
'peers_by_phone': { 'peers_by_phone': {
k: utils.get_peer_id(v) k: utils.get_peer_id(v)
for k, v in self.peers_by_phone.copy().items() for k, v in self._peers_by_phone.copy().items()
} }
} }
@ -109,10 +114,10 @@ class JsonSessionStorage(BaseSessionStorage):
# execution won't be here if an error has occurred earlier # execution won't be here if an error has occurred earlier
if sync: if sync:
shutil.move(file_path, self._get_file_name(self.session_data)) shutil.move(file_path, self._get_file_name(self._session_name))
def sync_cleanup(self): def sync_cleanup(self):
try: try:
os.remove(self._get_file_name(self.session_data) + '.tmp') os.remove(self._get_file_name(self._session_name) + '.tmp')
except OSError: except OSError:
pass pass

View File

@ -0,0 +1,85 @@
import pyrogram
from . import SessionStorage, SessionDoesNotExist
class MemorySessionStorage(SessionStorage):
def __init__(self, client: 'pyrogram.client.ext.BaseClient'):
super(MemorySessionStorage, self).__init__(client)
self._dc_id = 1
self._test_mode = None
self._auth_key = None
self._user_id = None
self._date = 0
self._is_bot = False
self._peers_by_id = {}
self._peers_by_username = {}
self._peers_by_phone = {}
def load(self):
raise SessionDoesNotExist()
def save(self, sync=False):
pass
def sync_cleanup(self):
pass
@property
def dc_id(self):
return self._dc_id
@dc_id.setter
def dc_id(self, val):
self._dc_id = val
@property
def test_mode(self):
return self._test_mode
@test_mode.setter
def test_mode(self, val):
self._test_mode = val
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, val):
self._auth_key = val
@property
def user_id(self):
return self._user_id
@user_id.setter
def user_id(self, val):
self._user_id = val
@property
def date(self):
return self._date
@date.setter
def date(self, val):
self._date = val
@property
def is_bot(self):
return self._is_bot
@is_bot.setter
def is_bot(self, val):
self._is_bot = val
@property
def peers_by_id(self):
return self._peers_by_id
@property
def peers_by_username(self):
return self._peers_by_username
@property
def peers_by_phone(self):
return self._peers_by_phone

View File

@ -2,10 +2,11 @@ import base64
import binascii import binascii
import struct import struct
from . import BaseSessionStorage, SessionDoesNotExist import pyrogram
from . import MemorySessionStorage, SessionDoesNotExist
class StringSessionStorage(BaseSessionStorage): class StringSessionStorage(MemorySessionStorage):
""" """
Packs session data as following (forcing little-endian byte order): Packs session data as following (forcing little-endian byte order):
Char dc_id (1 byte, unsigned) Char dc_id (1 byte, unsigned)
@ -18,22 +19,26 @@ class StringSessionStorage(BaseSessionStorage):
""" """
PACK_FORMAT = '<B?q?256s' PACK_FORMAT = '<B?q?256s'
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_string: str):
super(StringSessionStorage, self).__init__(client)
self._session_string = session_string
def _unpack(self, data): def _unpack(self, data):
return struct.unpack(self.PACK_FORMAT, data) return struct.unpack(self.PACK_FORMAT, data)
def _pack(self): def _pack(self):
return struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.is_bot, self.auth_key) return struct.pack(self.PACK_FORMAT, self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key)
def load_session(self): def load(self):
try: try:
session_string = self.session_data[1:] session_string = self._session_string[1:]
session_string += '=' * (4 - len(session_string) % 4) # restore padding session_string += '=' * (4 - len(session_string) % 4) # restore padding
decoded = base64.b64decode(session_string, b'-_') decoded = base64.b64decode(session_string, b'-_')
self.dc_id, self.test_mode, self.user_id, self.is_bot, self.auth_key = self._unpack(decoded) self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key = self._unpack(decoded)
except (struct.error, binascii.Error): except (struct.error, binascii.Error):
raise SessionDoesNotExist() raise SessionDoesNotExist()
def save_session(self, sync=False): def save(self, sync=False):
if not sync: if not sync:
packed = self._pack() packed = self._pack()
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=') encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')

View File

@ -112,7 +112,8 @@ class Session:
def start(self): def start(self):
while True: while True:
self.connection = Connection(self.dc_id, self.client.test_mode, self.client.ipv6, self.client.proxy) self.connection = Connection(self.dc_id, self.client.session_storage.test_mode,
self.client.ipv6, self.client.proxy)
try: try:
self.connection.connect() self.connection.connect()