mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-28 21:07:59 +00:00
add in-memory session storage, refactor session storages, remove mixin
This commit is contained in:
parent
9c4e9e166e
commit
5dc33c6337
@ -50,15 +50,15 @@ from pyrogram.api.errors import (
|
||||
from pyrogram.client.handlers import DisconnectHandler
|
||||
from pyrogram.client.handlers.handler import Handler
|
||||
from pyrogram.client.methods.password.utils import compute_check
|
||||
from pyrogram.client.session_storage import BaseSessionConfig
|
||||
from pyrogram.crypto import AES
|
||||
from pyrogram.session import Auth, Session
|
||||
from .dispatcher import Dispatcher
|
||||
from .ext import utils, Syncer, BaseClient
|
||||
from .methods import Methods
|
||||
from .session_storage import SessionDoesNotExist
|
||||
from .session_storage.json_session_storage import JsonSessionStorage
|
||||
from .session_storage.string_session_storage import StringSessionStorage
|
||||
from .session_storage import (
|
||||
SessionDoesNotExist, SessionStorage, MemorySessionStorage, JsonSessionStorage,
|
||||
StringSessionStorage
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -183,7 +183,7 @@ class Client(Methods, BaseClient):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
session_name: Union[str, BaseSessionConfig],
|
||||
session_name: Union[str, SessionStorage],
|
||||
api_id: Union[int, str] = None,
|
||||
api_hash: str = None,
|
||||
app_version: str = None,
|
||||
@ -209,14 +209,16 @@ class Client(Methods, BaseClient):
|
||||
takeout: bool = None):
|
||||
|
||||
if isinstance(session_name, str):
|
||||
if session_name.startswith(':'):
|
||||
if session_name == ':memory:':
|
||||
session_storage = MemorySessionStorage(self)
|
||||
elif session_name.startswith(':'):
|
||||
session_storage = StringSessionStorage(self, session_name)
|
||||
else:
|
||||
session_storage = JsonSessionStorage(self, session_name)
|
||||
elif isinstance(session_name, BaseSessionConfig):
|
||||
session_storage = session_name.session_storage_cls(self, session_name)
|
||||
elif isinstance(session_name, SessionStorage):
|
||||
session_storage = session_name
|
||||
else:
|
||||
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
|
||||
raise RuntimeError('Wrong session_name passed, expected str or SessionConfig subclass')
|
||||
|
||||
super().__init__(session_storage)
|
||||
|
||||
@ -230,7 +232,7 @@ class Client(Methods, BaseClient):
|
||||
self.ipv6 = ipv6
|
||||
# TODO: Make code consistent, use underscore for private/protected fields
|
||||
self._proxy = proxy
|
||||
self.test_mode = test_mode
|
||||
self.session_storage.test_mode = test_mode
|
||||
self.phone_number = phone_number
|
||||
self.phone_code = phone_code
|
||||
self.password = password
|
||||
@ -282,10 +284,10 @@ class Client(Methods, BaseClient):
|
||||
raise ConnectionError("Client has already been started")
|
||||
|
||||
if isinstance(self.session_storage, JsonSessionStorage):
|
||||
if self.BOT_TOKEN_RE.match(self.session_storage.session_data):
|
||||
self.is_bot = True
|
||||
self.bot_token = self.session_storage.session_data
|
||||
self.session_storage.session_data = self.session_storage.session_data.split(":")[0]
|
||||
if self.BOT_TOKEN_RE.match(self.session_storage._session_name):
|
||||
self.session_storage.is_bot = True
|
||||
self.bot_token = self.session_storage._session_name
|
||||
self.session_storage._session_name = self.session_storage._session_name.split(":")[0]
|
||||
warnings.warn('\nYou are using a bot token as session name.\n'
|
||||
'It will be deprecated in next update, please use session file name to load '
|
||||
'existing sessions and bot_token argument to create new sessions.',
|
||||
@ -297,33 +299,33 @@ class Client(Methods, BaseClient):
|
||||
|
||||
self.session = Session(
|
||||
self,
|
||||
self.dc_id,
|
||||
self.auth_key
|
||||
self.session_storage.dc_id,
|
||||
self.session_storage.auth_key
|
||||
)
|
||||
|
||||
self.session.start()
|
||||
self.is_started = True
|
||||
|
||||
try:
|
||||
if self.user_id is None:
|
||||
if self.session_storage.user_id is None:
|
||||
if self.bot_token is None:
|
||||
self.authorize_user()
|
||||
else:
|
||||
self.is_bot = True
|
||||
self.session_storage.is_bot = True
|
||||
self.authorize_bot()
|
||||
|
||||
self.save_session()
|
||||
|
||||
if not self.is_bot:
|
||||
if not self.session_storage.is_bot:
|
||||
if self.takeout:
|
||||
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
|
||||
log.warning("Takeout session {} initiated".format(self.takeout_id))
|
||||
|
||||
now = time.time()
|
||||
|
||||
if abs(now - self.date) > Client.OFFLINE_SLEEP:
|
||||
self.peers_by_username.clear()
|
||||
self.peers_by_phone.clear()
|
||||
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
|
||||
self.session_storage.peers_by_username.clear()
|
||||
self.session_storage.peers_by_phone.clear()
|
||||
|
||||
self.get_initial_dialogs()
|
||||
self.get_contacts()
|
||||
@ -512,19 +514,20 @@ class Client(Methods, BaseClient):
|
||||
except UserMigrate as e:
|
||||
self.session.stop()
|
||||
|
||||
self.dc_id = e.x
|
||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||
self.session_storage.dc_id = e.x
|
||||
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
|
||||
self.ipv6, self._proxy).create()
|
||||
|
||||
self.session = Session(
|
||||
self,
|
||||
self.dc_id,
|
||||
self.auth_key
|
||||
self.session_storage.dc_id,
|
||||
self.session_storage.auth_key
|
||||
)
|
||||
|
||||
self.session.start()
|
||||
self.authorize_bot()
|
||||
else:
|
||||
self.user_id = r.user.id
|
||||
self.session_storage.user_id = r.user.id
|
||||
|
||||
print("Logged in successfully as @{}".format(r.user.username))
|
||||
|
||||
@ -564,19 +567,19 @@ class Client(Methods, BaseClient):
|
||||
except (PhoneMigrate, NetworkMigrate) as e:
|
||||
self.session.stop()
|
||||
|
||||
self.dc_id = e.x
|
||||
self.session_storage.dc_id = e.x
|
||||
|
||||
self.auth_key = Auth(
|
||||
self.dc_id,
|
||||
self.test_mode,
|
||||
self.session_storage.auth_key = Auth(
|
||||
self.session_storage.dc_id,
|
||||
self.session_storage.test_mode,
|
||||
self.ipv6,
|
||||
self._proxy
|
||||
).create()
|
||||
|
||||
self.session = Session(
|
||||
self,
|
||||
self.dc_id,
|
||||
self.auth_key
|
||||
self.session_storage.dc_id,
|
||||
self.session_storage.auth_key
|
||||
)
|
||||
|
||||
self.session.start()
|
||||
@ -752,7 +755,7 @@ class Client(Methods, BaseClient):
|
||||
assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id))
|
||||
|
||||
self.password = None
|
||||
self.user_id = r.user.id
|
||||
self.session_storage.user_id = r.user.id
|
||||
|
||||
print("Logged in successfully as {}".format(r.user.first_name))
|
||||
|
||||
@ -776,13 +779,13 @@ class Client(Methods, BaseClient):
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
self.peers_by_id[user_id] = input_peer
|
||||
self.session_storage.peers_by_id[user_id] = input_peer
|
||||
|
||||
if username is not None:
|
||||
self.peers_by_username[username.lower()] = input_peer
|
||||
self.session_storage.peers_by_username[username.lower()] = input_peer
|
||||
|
||||
if phone is not None:
|
||||
self.peers_by_phone[phone] = input_peer
|
||||
self.session_storage.peers_by_phone[phone] = input_peer
|
||||
|
||||
if isinstance(entity, (types.Chat, types.ChatForbidden)):
|
||||
chat_id = entity.id
|
||||
@ -792,7 +795,7 @@ class Client(Methods, BaseClient):
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
self.peers_by_id[peer_id] = input_peer
|
||||
self.session_storage.peers_by_id[peer_id] = input_peer
|
||||
|
||||
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
|
||||
channel_id = entity.id
|
||||
@ -810,10 +813,10 @@ class Client(Methods, BaseClient):
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
self.peers_by_id[peer_id] = input_peer
|
||||
self.session_storage.peers_by_id[peer_id] = input_peer
|
||||
|
||||
if username is not None:
|
||||
self.peers_by_username[username.lower()] = input_peer
|
||||
self.session_storage.peers_by_username[username.lower()] = input_peer
|
||||
|
||||
def download_worker(self):
|
||||
name = threading.current_thread().name
|
||||
@ -1127,10 +1130,11 @@ class Client(Methods, BaseClient):
|
||||
|
||||
def load_session(self):
|
||||
try:
|
||||
self.session_storage.load_session()
|
||||
self.session_storage.load()
|
||||
except SessionDoesNotExist:
|
||||
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
|
||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
||||
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
|
||||
self.ipv6, self._proxy).create()
|
||||
|
||||
def load_plugins(self):
|
||||
if self.plugins.get("enabled", False):
|
||||
@ -1237,7 +1241,7 @@ class Client(Methods, BaseClient):
|
||||
log.warning('No plugin loaded from "{}"'.format(root))
|
||||
|
||||
def save_session(self):
|
||||
self.session_storage.save_session()
|
||||
self.session_storage.save()
|
||||
|
||||
def get_initial_dialogs_chunk(self,
|
||||
offset_date: int = 0):
|
||||
@ -1257,7 +1261,7 @@ class Client(Methods, BaseClient):
|
||||
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
|
||||
time.sleep(e.x)
|
||||
else:
|
||||
log.info("Total peers: {}".format(len(self.peers_by_id)))
|
||||
log.info("Total peers: {}".format(len(self.session_storage.peers_by_id)))
|
||||
return r
|
||||
|
||||
def get_initial_dialogs(self):
|
||||
@ -1293,7 +1297,7 @@ class Client(Methods, BaseClient):
|
||||
``KeyError`` in case the peer doesn't exist in the internal database.
|
||||
"""
|
||||
try:
|
||||
return self.peers_by_id[peer_id]
|
||||
return self.session_storage.peers_by_id[peer_id]
|
||||
except KeyError:
|
||||
if type(peer_id) is str:
|
||||
if peer_id in ("self", "me"):
|
||||
@ -1304,17 +1308,17 @@ class Client(Methods, BaseClient):
|
||||
try:
|
||||
int(peer_id)
|
||||
except ValueError:
|
||||
if peer_id not in self.peers_by_username:
|
||||
if peer_id not in self.session_storage.peers_by_username:
|
||||
self.send(
|
||||
functions.contacts.ResolveUsername(
|
||||
username=peer_id
|
||||
)
|
||||
)
|
||||
|
||||
return self.peers_by_username[peer_id]
|
||||
return self.session_storage.peers_by_username[peer_id]
|
||||
else:
|
||||
try:
|
||||
return self.peers_by_phone[peer_id]
|
||||
return self.session_storage.peers_by_phone[peer_id]
|
||||
except KeyError:
|
||||
raise PeerIdInvalid
|
||||
|
||||
@ -1341,7 +1345,7 @@ class Client(Methods, BaseClient):
|
||||
)
|
||||
|
||||
try:
|
||||
return self.peers_by_id[peer_id]
|
||||
return self.session_storage.peers_by_id[peer_id]
|
||||
except KeyError:
|
||||
raise PeerIdInvalid
|
||||
|
||||
@ -1411,7 +1415,7 @@ class Client(Methods, BaseClient):
|
||||
file_id = file_id or self.rnd_id()
|
||||
md5_sum = md5() if not is_big and not is_missing_part else None
|
||||
|
||||
session = Session(self, self.dc_id, self.auth_key, is_media=True)
|
||||
session = Session(self, self.session_storage.dc_id, self.session_storage.auth_key, is_media=True)
|
||||
session.start()
|
||||
|
||||
try:
|
||||
@ -1492,7 +1496,7 @@ class Client(Methods, BaseClient):
|
||||
session = self.media_sessions.get(dc_id, None)
|
||||
|
||||
if session is None:
|
||||
if dc_id != self.dc_id:
|
||||
if dc_id != self.session_storage.dc_id:
|
||||
exported_auth = self.send(
|
||||
functions.auth.ExportAuthorization(
|
||||
dc_id=dc_id
|
||||
@ -1502,7 +1506,7 @@ class Client(Methods, BaseClient):
|
||||
session = Session(
|
||||
self,
|
||||
dc_id,
|
||||
Auth(dc_id, self.test_mode, self.ipv6, self._proxy).create(),
|
||||
Auth(dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
|
||||
is_media=True
|
||||
)
|
||||
|
||||
@ -1520,7 +1524,7 @@ class Client(Methods, BaseClient):
|
||||
session = Session(
|
||||
self,
|
||||
dc_id,
|
||||
self.auth_key,
|
||||
self.session_storage.auth_key,
|
||||
is_media=True
|
||||
)
|
||||
|
||||
@ -1588,7 +1592,7 @@ class Client(Methods, BaseClient):
|
||||
cdn_session = Session(
|
||||
self,
|
||||
r.dc_id,
|
||||
Auth(r.dc_id, self.test_mode, self.ipv6, self._proxy).create(),
|
||||
Auth(r.dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
|
||||
is_media=True,
|
||||
is_cdn=True
|
||||
)
|
||||
|
@ -24,10 +24,10 @@ from threading import Lock
|
||||
from pyrogram import __version__
|
||||
from ..style import Markdown, HTML
|
||||
from ...session.internals import MsgId
|
||||
from ..session_storage import SessionStorageMixin, BaseSessionStorage
|
||||
from ..session_storage import SessionStorage
|
||||
|
||||
|
||||
class BaseClient(SessionStorageMixin):
|
||||
class BaseClient:
|
||||
class StopTransmission(StopIteration):
|
||||
pass
|
||||
|
||||
@ -68,14 +68,14 @@ class BaseClient(SessionStorageMixin):
|
||||
13: "video_note"
|
||||
}
|
||||
|
||||
def __init__(self, session_storage: BaseSessionStorage):
|
||||
def __init__(self, session_storage: SessionStorage):
|
||||
self.session_storage = session_storage
|
||||
|
||||
self.rnd_id = MsgId
|
||||
self.channels_pts = {}
|
||||
|
||||
self.markdown = Markdown(self.peers_by_id)
|
||||
self.html = HTML(self.peers_by_id)
|
||||
self.markdown = Markdown(self.session_storage.peers_by_id)
|
||||
self.html = HTML(self.session_storage.peers_by_id)
|
||||
|
||||
self.session = None
|
||||
self.media_sessions = {}
|
||||
|
@ -81,9 +81,9 @@ class Syncer:
|
||||
|
||||
@classmethod
|
||||
def sync(cls, client):
|
||||
client.date = int(time.time())
|
||||
client.session_storage.date = int(time.time())
|
||||
try:
|
||||
client.session_storage.save_session(sync=True)
|
||||
client.session_storage.save(sync=True)
|
||||
except Exception as e:
|
||||
log.critical(e, exc_info=True)
|
||||
else:
|
||||
|
@ -44,5 +44,5 @@ class GetContacts(BaseClient):
|
||||
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
|
||||
time.sleep(e.x)
|
||||
else:
|
||||
log.info("Total contacts: {}".format(len(self.peers_by_phone)))
|
||||
log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone)))
|
||||
return [pyrogram.User._parse(self, user) for user in contacts.users]
|
||||
|
@ -16,5 +16,7 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from .session_storage_mixin import SessionStorageMixin
|
||||
from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist
|
||||
from .abstract import SessionStorage, SessionDoesNotExist
|
||||
from .memory import MemorySessionStorage
|
||||
from .json import JsonSessionStorage
|
||||
from .string import StringSessionStorage
|
||||
|
@ -16,66 +16,103 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from typing import Dict
|
||||
import abc
|
||||
from typing import Type
|
||||
|
||||
import pyrogram
|
||||
|
||||
|
||||
class SessionStorageMixin:
|
||||
class SessionDoesNotExist(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SessionStorage(abc.ABC):
|
||||
def __init__(self, client: 'pyrogram.client.BaseClient'):
|
||||
self._client = client
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def save(self, sync=False):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_cleanup(self):
|
||||
...
|
||||
|
||||
@property
|
||||
def dc_id(self) -> int:
|
||||
return self.session_storage.dc_id
|
||||
@abc.abstractmethod
|
||||
def dc_id(self):
|
||||
...
|
||||
|
||||
@dc_id.setter
|
||||
@abc.abstractmethod
|
||||
def dc_id(self, val):
|
||||
self.session_storage.dc_id = val
|
||||
...
|
||||
|
||||
@property
|
||||
def test_mode(self) -> bool:
|
||||
return self.session_storage.test_mode
|
||||
@abc.abstractmethod
|
||||
def test_mode(self):
|
||||
...
|
||||
|
||||
@test_mode.setter
|
||||
@abc.abstractmethod
|
||||
def test_mode(self, val):
|
||||
self.session_storage.test_mode = val
|
||||
...
|
||||
|
||||
@property
|
||||
def auth_key(self) -> bytes:
|
||||
return self.session_storage.auth_key
|
||||
@abc.abstractmethod
|
||||
def auth_key(self):
|
||||
...
|
||||
|
||||
@auth_key.setter
|
||||
@abc.abstractmethod
|
||||
def auth_key(self, val):
|
||||
self.session_storage.auth_key = val
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def user_id(self):
|
||||
return self.session_storage.user_id
|
||||
...
|
||||
|
||||
@user_id.setter
|
||||
def user_id(self, val) -> int:
|
||||
self.session_storage.user_id = val
|
||||
@abc.abstractmethod
|
||||
def user_id(self, val):
|
||||
...
|
||||
|
||||
@property
|
||||
def date(self) -> int:
|
||||
return self.session_storage.date
|
||||
@abc.abstractmethod
|
||||
def date(self):
|
||||
...
|
||||
|
||||
@date.setter
|
||||
@abc.abstractmethod
|
||||
def date(self, val):
|
||||
self.session_storage.date = val
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_bot(self):
|
||||
return self.session_storage.is_bot
|
||||
...
|
||||
|
||||
@is_bot.setter
|
||||
def is_bot(self, val) -> int:
|
||||
self.session_storage.is_bot = val
|
||||
@abc.abstractmethod
|
||||
def is_bot(self, val):
|
||||
...
|
||||
|
||||
@property
|
||||
def peers_by_id(self) -> Dict[str, int]:
|
||||
return self.session_storage.peers_by_id
|
||||
@abc.abstractmethod
|
||||
def peers_by_id(self):
|
||||
...
|
||||
|
||||
@property
|
||||
def peers_by_username(self) -> Dict[str, int]:
|
||||
return self.session_storage.peers_by_username
|
||||
@abc.abstractmethod
|
||||
def peers_by_username(self):
|
||||
...
|
||||
|
||||
@property
|
||||
def peers_by_phone(self) -> Dict[str, int]:
|
||||
return self.session_storage.peers_by_phone
|
||||
@abc.abstractmethod
|
||||
def peers_by_phone(self):
|
||||
...
|
@ -1,60 +0,0 @@
|
||||
# Pyrogram - Telegram MTProto API Client Library for Python
|
||||
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
|
||||
#
|
||||
# This file is part of Pyrogram.
|
||||
#
|
||||
# Pyrogram is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Pyrogram is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import abc
|
||||
from typing import Type
|
||||
|
||||
import pyrogram
|
||||
|
||||
|
||||
class SessionDoesNotExist(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseSessionStorage(abc.ABC):
|
||||
def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
|
||||
self.client = client
|
||||
self.session_data = session_data
|
||||
self.dc_id = 1
|
||||
self.test_mode = None
|
||||
self.auth_key = None
|
||||
self.user_id = None
|
||||
self.date = 0
|
||||
self.is_bot = False
|
||||
self.peers_by_id = {}
|
||||
self.peers_by_username = {}
|
||||
self.peers_by_phone = {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_session(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_session(self, sync=False):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def sync_cleanup(self):
|
||||
...
|
||||
|
||||
|
||||
class BaseSessionConfig(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def session_storage_cls(self) -> Type[BaseSessionStorage]:
|
||||
...
|
@ -22,21 +22,26 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pyrogram
|
||||
from ..ext import utils
|
||||
from . import BaseSessionStorage, SessionDoesNotExist
|
||||
from . import MemorySessionStorage, SessionDoesNotExist
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonSessionStorage(BaseSessionStorage):
|
||||
class JsonSessionStorage(MemorySessionStorage):
|
||||
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str):
|
||||
super(JsonSessionStorage, self).__init__(client)
|
||||
self._session_name = session_name
|
||||
|
||||
def _get_file_name(self, name: str):
|
||||
if not name.endswith('.session'):
|
||||
name += '.session'
|
||||
return os.path.join(self.client.workdir, name)
|
||||
return os.path.join(self._client.workdir, name)
|
||||
|
||||
def load_session(self):
|
||||
file_path = self._get_file_name(self.session_data)
|
||||
def load(self):
|
||||
file_path = self._get_file_name(self._session_name)
|
||||
log.info('Loading JSON session from {}'.format(file_path))
|
||||
|
||||
try:
|
||||
@ -45,59 +50,59 @@ class JsonSessionStorage(BaseSessionStorage):
|
||||
except FileNotFoundError:
|
||||
raise SessionDoesNotExist()
|
||||
|
||||
self.dc_id = s["dc_id"]
|
||||
self.test_mode = s["test_mode"]
|
||||
self.auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
|
||||
self.user_id = s["user_id"]
|
||||
self.date = s.get("date", 0)
|
||||
self.is_bot = s.get('is_bot', self.client.is_bot)
|
||||
self._dc_id = s["dc_id"]
|
||||
self._test_mode = s["test_mode"]
|
||||
self._auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
|
||||
self._user_id = s["user_id"]
|
||||
self._date = s.get("date", 0)
|
||||
self._is_bot = s.get('is_bot', self._is_bot)
|
||||
|
||||
for k, v in s.get("peers_by_id", {}).items():
|
||||
self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
||||
self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
||||
|
||||
for k, v in s.get("peers_by_username", {}).items():
|
||||
peer = self.peers_by_id.get(v, None)
|
||||
peer = self._peers_by_id.get(v, None)
|
||||
|
||||
if peer:
|
||||
self.peers_by_username[k] = peer
|
||||
self._peers_by_username[k] = peer
|
||||
|
||||
for k, v in s.get("peers_by_phone", {}).items():
|
||||
peer = self.peers_by_id.get(v, None)
|
||||
peer = self._peers_by_id.get(v, None)
|
||||
|
||||
if peer:
|
||||
self.peers_by_phone[k] = peer
|
||||
self._peers_by_phone[k] = peer
|
||||
|
||||
def save_session(self, sync=False):
|
||||
file_path = self._get_file_name(self.session_data)
|
||||
def save(self, sync=False):
|
||||
file_path = self._get_file_name(self._session_name)
|
||||
|
||||
if sync:
|
||||
file_path += '.tmp'
|
||||
|
||||
log.info('Saving JSON session to {}, sync={}'.format(file_path, sync))
|
||||
|
||||
auth_key = base64.b64encode(self.auth_key).decode()
|
||||
auth_key = base64.b64encode(self._auth_key).decode()
|
||||
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars
|
||||
|
||||
os.makedirs(self.client.workdir, exist_ok=True)
|
||||
os.makedirs(self._client.workdir, exist_ok=True)
|
||||
|
||||
data = {
|
||||
'dc_id': self.dc_id,
|
||||
'test_mode': self.test_mode,
|
||||
'dc_id': self._dc_id,
|
||||
'test_mode': self._test_mode,
|
||||
'auth_key': auth_key,
|
||||
'user_id': self.user_id,
|
||||
'date': self.date,
|
||||
'is_bot': self.is_bot,
|
||||
'user_id': self._user_id,
|
||||
'date': self._date,
|
||||
'is_bot': self._is_bot,
|
||||
'peers_by_id': {
|
||||
k: getattr(v, "access_hash", None)
|
||||
for k, v in self.peers_by_id.copy().items()
|
||||
for k, v in self._peers_by_id.copy().items()
|
||||
},
|
||||
'peers_by_username': {
|
||||
k: utils.get_peer_id(v)
|
||||
for k, v in self.peers_by_username.copy().items()
|
||||
for k, v in self._peers_by_username.copy().items()
|
||||
},
|
||||
'peers_by_phone': {
|
||||
k: utils.get_peer_id(v)
|
||||
for k, v in self.peers_by_phone.copy().items()
|
||||
for k, v in self._peers_by_phone.copy().items()
|
||||
}
|
||||
}
|
||||
|
||||
@ -109,10 +114,10 @@ class JsonSessionStorage(BaseSessionStorage):
|
||||
|
||||
# execution won't be here if an error has occurred earlier
|
||||
if sync:
|
||||
shutil.move(file_path, self._get_file_name(self.session_data))
|
||||
shutil.move(file_path, self._get_file_name(self._session_name))
|
||||
|
||||
def sync_cleanup(self):
|
||||
try:
|
||||
os.remove(self._get_file_name(self.session_data) + '.tmp')
|
||||
os.remove(self._get_file_name(self._session_name) + '.tmp')
|
||||
except OSError:
|
||||
pass
|
85
pyrogram/client/session_storage/memory.py
Normal file
85
pyrogram/client/session_storage/memory.py
Normal file
@ -0,0 +1,85 @@
|
||||
import pyrogram
|
||||
from . import SessionStorage, SessionDoesNotExist
|
||||
|
||||
|
||||
class MemorySessionStorage(SessionStorage):
|
||||
def __init__(self, client: 'pyrogram.client.ext.BaseClient'):
|
||||
super(MemorySessionStorage, self).__init__(client)
|
||||
self._dc_id = 1
|
||||
self._test_mode = None
|
||||
self._auth_key = None
|
||||
self._user_id = None
|
||||
self._date = 0
|
||||
self._is_bot = False
|
||||
self._peers_by_id = {}
|
||||
self._peers_by_username = {}
|
||||
self._peers_by_phone = {}
|
||||
|
||||
def load(self):
|
||||
raise SessionDoesNotExist()
|
||||
|
||||
def save(self, sync=False):
|
||||
pass
|
||||
|
||||
def sync_cleanup(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def dc_id(self):
|
||||
return self._dc_id
|
||||
|
||||
@dc_id.setter
|
||||
def dc_id(self, val):
|
||||
self._dc_id = val
|
||||
|
||||
@property
|
||||
def test_mode(self):
|
||||
return self._test_mode
|
||||
|
||||
@test_mode.setter
|
||||
def test_mode(self, val):
|
||||
self._test_mode = val
|
||||
|
||||
@property
|
||||
def auth_key(self):
|
||||
return self._auth_key
|
||||
|
||||
@auth_key.setter
|
||||
def auth_key(self, val):
|
||||
self._auth_key = val
|
||||
|
||||
@property
|
||||
def user_id(self):
|
||||
return self._user_id
|
||||
|
||||
@user_id.setter
|
||||
def user_id(self, val):
|
||||
self._user_id = val
|
||||
|
||||
@property
|
||||
def date(self):
|
||||
return self._date
|
||||
|
||||
@date.setter
|
||||
def date(self, val):
|
||||
self._date = val
|
||||
|
||||
@property
|
||||
def is_bot(self):
|
||||
return self._is_bot
|
||||
|
||||
@is_bot.setter
|
||||
def is_bot(self, val):
|
||||
self._is_bot = val
|
||||
|
||||
@property
|
||||
def peers_by_id(self):
|
||||
return self._peers_by_id
|
||||
|
||||
@property
|
||||
def peers_by_username(self):
|
||||
return self._peers_by_username
|
||||
|
||||
@property
|
||||
def peers_by_phone(self):
|
||||
return self._peers_by_phone
|
@ -2,10 +2,11 @@ import base64
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
from . import BaseSessionStorage, SessionDoesNotExist
|
||||
import pyrogram
|
||||
from . import MemorySessionStorage, SessionDoesNotExist
|
||||
|
||||
|
||||
class StringSessionStorage(BaseSessionStorage):
|
||||
class StringSessionStorage(MemorySessionStorage):
|
||||
"""
|
||||
Packs session data as following (forcing little-endian byte order):
|
||||
Char dc_id (1 byte, unsigned)
|
||||
@ -18,22 +19,26 @@ class StringSessionStorage(BaseSessionStorage):
|
||||
"""
|
||||
PACK_FORMAT = '<B?q?256s'
|
||||
|
||||
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_string: str):
|
||||
super(StringSessionStorage, self).__init__(client)
|
||||
self._session_string = session_string
|
||||
|
||||
def _unpack(self, data):
|
||||
return struct.unpack(self.PACK_FORMAT, data)
|
||||
|
||||
def _pack(self):
|
||||
return struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.is_bot, self.auth_key)
|
||||
return struct.pack(self.PACK_FORMAT, self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key)
|
||||
|
||||
def load_session(self):
|
||||
def load(self):
|
||||
try:
|
||||
session_string = self.session_data[1:]
|
||||
session_string = self._session_string[1:]
|
||||
session_string += '=' * (4 - len(session_string) % 4) # restore padding
|
||||
decoded = base64.b64decode(session_string, b'-_')
|
||||
self.dc_id, self.test_mode, self.user_id, self.is_bot, self.auth_key = self._unpack(decoded)
|
||||
self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key = self._unpack(decoded)
|
||||
except (struct.error, binascii.Error):
|
||||
raise SessionDoesNotExist()
|
||||
|
||||
def save_session(self, sync=False):
|
||||
def save(self, sync=False):
|
||||
if not sync:
|
||||
packed = self._pack()
|
||||
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
|
@ -112,7 +112,8 @@ class Session:
|
||||
|
||||
def start(self):
|
||||
while True:
|
||||
self.connection = Connection(self.dc_id, self.client.test_mode, self.client.ipv6, self.client.proxy)
|
||||
self.connection = Connection(self.dc_id, self.client.session_storage.test_mode,
|
||||
self.client.ipv6, self.client.proxy)
|
||||
|
||||
try:
|
||||
self.connection.connect()
|
||||
|
Loading…
x
Reference in New Issue
Block a user