mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-28 21:07:59 +00:00
add in-memory session storage, refactor session storages, remove mixin
This commit is contained in:
parent
9c4e9e166e
commit
5dc33c6337
@ -50,15 +50,15 @@ from pyrogram.api.errors import (
|
|||||||
from pyrogram.client.handlers import DisconnectHandler
|
from pyrogram.client.handlers import DisconnectHandler
|
||||||
from pyrogram.client.handlers.handler import Handler
|
from pyrogram.client.handlers.handler import Handler
|
||||||
from pyrogram.client.methods.password.utils import compute_check
|
from pyrogram.client.methods.password.utils import compute_check
|
||||||
from pyrogram.client.session_storage import BaseSessionConfig
|
|
||||||
from pyrogram.crypto import AES
|
from pyrogram.crypto import AES
|
||||||
from pyrogram.session import Auth, Session
|
from pyrogram.session import Auth, Session
|
||||||
from .dispatcher import Dispatcher
|
from .dispatcher import Dispatcher
|
||||||
from .ext import utils, Syncer, BaseClient
|
from .ext import utils, Syncer, BaseClient
|
||||||
from .methods import Methods
|
from .methods import Methods
|
||||||
from .session_storage import SessionDoesNotExist
|
from .session_storage import (
|
||||||
from .session_storage.json_session_storage import JsonSessionStorage
|
SessionDoesNotExist, SessionStorage, MemorySessionStorage, JsonSessionStorage,
|
||||||
from .session_storage.string_session_storage import StringSessionStorage
|
StringSessionStorage
|
||||||
|
)
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -183,7 +183,7 @@ class Client(Methods, BaseClient):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
session_name: Union[str, BaseSessionConfig],
|
session_name: Union[str, SessionStorage],
|
||||||
api_id: Union[int, str] = None,
|
api_id: Union[int, str] = None,
|
||||||
api_hash: str = None,
|
api_hash: str = None,
|
||||||
app_version: str = None,
|
app_version: str = None,
|
||||||
@ -209,14 +209,16 @@ class Client(Methods, BaseClient):
|
|||||||
takeout: bool = None):
|
takeout: bool = None):
|
||||||
|
|
||||||
if isinstance(session_name, str):
|
if isinstance(session_name, str):
|
||||||
if session_name.startswith(':'):
|
if session_name == ':memory:':
|
||||||
|
session_storage = MemorySessionStorage(self)
|
||||||
|
elif session_name.startswith(':'):
|
||||||
session_storage = StringSessionStorage(self, session_name)
|
session_storage = StringSessionStorage(self, session_name)
|
||||||
else:
|
else:
|
||||||
session_storage = JsonSessionStorage(self, session_name)
|
session_storage = JsonSessionStorage(self, session_name)
|
||||||
elif isinstance(session_name, BaseSessionConfig):
|
elif isinstance(session_name, SessionStorage):
|
||||||
session_storage = session_name.session_storage_cls(self, session_name)
|
session_storage = session_name
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
|
raise RuntimeError('Wrong session_name passed, expected str or SessionConfig subclass')
|
||||||
|
|
||||||
super().__init__(session_storage)
|
super().__init__(session_storage)
|
||||||
|
|
||||||
@ -230,7 +232,7 @@ class Client(Methods, BaseClient):
|
|||||||
self.ipv6 = ipv6
|
self.ipv6 = ipv6
|
||||||
# TODO: Make code consistent, use underscore for private/protected fields
|
# TODO: Make code consistent, use underscore for private/protected fields
|
||||||
self._proxy = proxy
|
self._proxy = proxy
|
||||||
self.test_mode = test_mode
|
self.session_storage.test_mode = test_mode
|
||||||
self.phone_number = phone_number
|
self.phone_number = phone_number
|
||||||
self.phone_code = phone_code
|
self.phone_code = phone_code
|
||||||
self.password = password
|
self.password = password
|
||||||
@ -282,10 +284,10 @@ class Client(Methods, BaseClient):
|
|||||||
raise ConnectionError("Client has already been started")
|
raise ConnectionError("Client has already been started")
|
||||||
|
|
||||||
if isinstance(self.session_storage, JsonSessionStorage):
|
if isinstance(self.session_storage, JsonSessionStorage):
|
||||||
if self.BOT_TOKEN_RE.match(self.session_storage.session_data):
|
if self.BOT_TOKEN_RE.match(self.session_storage._session_name):
|
||||||
self.is_bot = True
|
self.session_storage.is_bot = True
|
||||||
self.bot_token = self.session_storage.session_data
|
self.bot_token = self.session_storage._session_name
|
||||||
self.session_storage.session_data = self.session_storage.session_data.split(":")[0]
|
self.session_storage._session_name = self.session_storage._session_name.split(":")[0]
|
||||||
warnings.warn('\nYou are using a bot token as session name.\n'
|
warnings.warn('\nYou are using a bot token as session name.\n'
|
||||||
'It will be deprecated in next update, please use session file name to load '
|
'It will be deprecated in next update, please use session file name to load '
|
||||||
'existing sessions and bot_token argument to create new sessions.',
|
'existing sessions and bot_token argument to create new sessions.',
|
||||||
@ -297,33 +299,33 @@ class Client(Methods, BaseClient):
|
|||||||
|
|
||||||
self.session = Session(
|
self.session = Session(
|
||||||
self,
|
self,
|
||||||
self.dc_id,
|
self.session_storage.dc_id,
|
||||||
self.auth_key
|
self.session_storage.auth_key
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session.start()
|
self.session.start()
|
||||||
self.is_started = True
|
self.is_started = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.user_id is None:
|
if self.session_storage.user_id is None:
|
||||||
if self.bot_token is None:
|
if self.bot_token is None:
|
||||||
self.authorize_user()
|
self.authorize_user()
|
||||||
else:
|
else:
|
||||||
self.is_bot = True
|
self.session_storage.is_bot = True
|
||||||
self.authorize_bot()
|
self.authorize_bot()
|
||||||
|
|
||||||
self.save_session()
|
self.save_session()
|
||||||
|
|
||||||
if not self.is_bot:
|
if not self.session_storage.is_bot:
|
||||||
if self.takeout:
|
if self.takeout:
|
||||||
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
|
self.takeout_id = self.send(functions.account.InitTakeoutSession()).id
|
||||||
log.warning("Takeout session {} initiated".format(self.takeout_id))
|
log.warning("Takeout session {} initiated".format(self.takeout_id))
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
if abs(now - self.date) > Client.OFFLINE_SLEEP:
|
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
|
||||||
self.peers_by_username.clear()
|
self.session_storage.peers_by_username.clear()
|
||||||
self.peers_by_phone.clear()
|
self.session_storage.peers_by_phone.clear()
|
||||||
|
|
||||||
self.get_initial_dialogs()
|
self.get_initial_dialogs()
|
||||||
self.get_contacts()
|
self.get_contacts()
|
||||||
@ -512,19 +514,20 @@ class Client(Methods, BaseClient):
|
|||||||
except UserMigrate as e:
|
except UserMigrate as e:
|
||||||
self.session.stop()
|
self.session.stop()
|
||||||
|
|
||||||
self.dc_id = e.x
|
self.session_storage.dc_id = e.x
|
||||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
|
||||||
|
self.ipv6, self._proxy).create()
|
||||||
|
|
||||||
self.session = Session(
|
self.session = Session(
|
||||||
self,
|
self,
|
||||||
self.dc_id,
|
self.session_storage.dc_id,
|
||||||
self.auth_key
|
self.session_storage.auth_key
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session.start()
|
self.session.start()
|
||||||
self.authorize_bot()
|
self.authorize_bot()
|
||||||
else:
|
else:
|
||||||
self.user_id = r.user.id
|
self.session_storage.user_id = r.user.id
|
||||||
|
|
||||||
print("Logged in successfully as @{}".format(r.user.username))
|
print("Logged in successfully as @{}".format(r.user.username))
|
||||||
|
|
||||||
@ -564,19 +567,19 @@ class Client(Methods, BaseClient):
|
|||||||
except (PhoneMigrate, NetworkMigrate) as e:
|
except (PhoneMigrate, NetworkMigrate) as e:
|
||||||
self.session.stop()
|
self.session.stop()
|
||||||
|
|
||||||
self.dc_id = e.x
|
self.session_storage.dc_id = e.x
|
||||||
|
|
||||||
self.auth_key = Auth(
|
self.session_storage.auth_key = Auth(
|
||||||
self.dc_id,
|
self.session_storage.dc_id,
|
||||||
self.test_mode,
|
self.session_storage.test_mode,
|
||||||
self.ipv6,
|
self.ipv6,
|
||||||
self._proxy
|
self._proxy
|
||||||
).create()
|
).create()
|
||||||
|
|
||||||
self.session = Session(
|
self.session = Session(
|
||||||
self,
|
self,
|
||||||
self.dc_id,
|
self.session_storage.dc_id,
|
||||||
self.auth_key
|
self.session_storage.auth_key
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session.start()
|
self.session.start()
|
||||||
@ -752,7 +755,7 @@ class Client(Methods, BaseClient):
|
|||||||
assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id))
|
assert self.send(functions.help.AcceptTermsOfService(terms_of_service.id))
|
||||||
|
|
||||||
self.password = None
|
self.password = None
|
||||||
self.user_id = r.user.id
|
self.session_storage.user_id = r.user.id
|
||||||
|
|
||||||
print("Logged in successfully as {}".format(r.user.first_name))
|
print("Logged in successfully as {}".format(r.user.first_name))
|
||||||
|
|
||||||
@ -776,13 +779,13 @@ class Client(Methods, BaseClient):
|
|||||||
access_hash=access_hash
|
access_hash=access_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
self.peers_by_id[user_id] = input_peer
|
self.session_storage.peers_by_id[user_id] = input_peer
|
||||||
|
|
||||||
if username is not None:
|
if username is not None:
|
||||||
self.peers_by_username[username.lower()] = input_peer
|
self.session_storage.peers_by_username[username.lower()] = input_peer
|
||||||
|
|
||||||
if phone is not None:
|
if phone is not None:
|
||||||
self.peers_by_phone[phone] = input_peer
|
self.session_storage.peers_by_phone[phone] = input_peer
|
||||||
|
|
||||||
if isinstance(entity, (types.Chat, types.ChatForbidden)):
|
if isinstance(entity, (types.Chat, types.ChatForbidden)):
|
||||||
chat_id = entity.id
|
chat_id = entity.id
|
||||||
@ -792,7 +795,7 @@ class Client(Methods, BaseClient):
|
|||||||
chat_id=chat_id
|
chat_id=chat_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.peers_by_id[peer_id] = input_peer
|
self.session_storage.peers_by_id[peer_id] = input_peer
|
||||||
|
|
||||||
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
|
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
|
||||||
channel_id = entity.id
|
channel_id = entity.id
|
||||||
@ -810,10 +813,10 @@ class Client(Methods, BaseClient):
|
|||||||
access_hash=access_hash
|
access_hash=access_hash
|
||||||
)
|
)
|
||||||
|
|
||||||
self.peers_by_id[peer_id] = input_peer
|
self.session_storage.peers_by_id[peer_id] = input_peer
|
||||||
|
|
||||||
if username is not None:
|
if username is not None:
|
||||||
self.peers_by_username[username.lower()] = input_peer
|
self.session_storage.peers_by_username[username.lower()] = input_peer
|
||||||
|
|
||||||
def download_worker(self):
|
def download_worker(self):
|
||||||
name = threading.current_thread().name
|
name = threading.current_thread().name
|
||||||
@ -1127,10 +1130,11 @@ class Client(Methods, BaseClient):
|
|||||||
|
|
||||||
def load_session(self):
|
def load_session(self):
|
||||||
try:
|
try:
|
||||||
self.session_storage.load_session()
|
self.session_storage.load()
|
||||||
except SessionDoesNotExist:
|
except SessionDoesNotExist:
|
||||||
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
|
log.info('Could not load session "{}", initiate new one'.format(self.session_name))
|
||||||
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
|
self.session_storage.auth_key = Auth(self.session_storage.dc_id, self.session_storage.test_mode,
|
||||||
|
self.ipv6, self._proxy).create()
|
||||||
|
|
||||||
def load_plugins(self):
|
def load_plugins(self):
|
||||||
if self.plugins.get("enabled", False):
|
if self.plugins.get("enabled", False):
|
||||||
@ -1237,7 +1241,7 @@ class Client(Methods, BaseClient):
|
|||||||
log.warning('No plugin loaded from "{}"'.format(root))
|
log.warning('No plugin loaded from "{}"'.format(root))
|
||||||
|
|
||||||
def save_session(self):
|
def save_session(self):
|
||||||
self.session_storage.save_session()
|
self.session_storage.save()
|
||||||
|
|
||||||
def get_initial_dialogs_chunk(self,
|
def get_initial_dialogs_chunk(self,
|
||||||
offset_date: int = 0):
|
offset_date: int = 0):
|
||||||
@ -1257,7 +1261,7 @@ class Client(Methods, BaseClient):
|
|||||||
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
|
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
|
||||||
time.sleep(e.x)
|
time.sleep(e.x)
|
||||||
else:
|
else:
|
||||||
log.info("Total peers: {}".format(len(self.peers_by_id)))
|
log.info("Total peers: {}".format(len(self.session_storage.peers_by_id)))
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def get_initial_dialogs(self):
|
def get_initial_dialogs(self):
|
||||||
@ -1293,7 +1297,7 @@ class Client(Methods, BaseClient):
|
|||||||
``KeyError`` in case the peer doesn't exist in the internal database.
|
``KeyError`` in case the peer doesn't exist in the internal database.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self.peers_by_id[peer_id]
|
return self.session_storage.peers_by_id[peer_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if type(peer_id) is str:
|
if type(peer_id) is str:
|
||||||
if peer_id in ("self", "me"):
|
if peer_id in ("self", "me"):
|
||||||
@ -1304,17 +1308,17 @@ class Client(Methods, BaseClient):
|
|||||||
try:
|
try:
|
||||||
int(peer_id)
|
int(peer_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
if peer_id not in self.peers_by_username:
|
if peer_id not in self.session_storage.peers_by_username:
|
||||||
self.send(
|
self.send(
|
||||||
functions.contacts.ResolveUsername(
|
functions.contacts.ResolveUsername(
|
||||||
username=peer_id
|
username=peer_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.peers_by_username[peer_id]
|
return self.session_storage.peers_by_username[peer_id]
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return self.peers_by_phone[peer_id]
|
return self.session_storage.peers_by_phone[peer_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise PeerIdInvalid
|
raise PeerIdInvalid
|
||||||
|
|
||||||
@ -1341,7 +1345,7 @@ class Client(Methods, BaseClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.peers_by_id[peer_id]
|
return self.session_storage.peers_by_id[peer_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise PeerIdInvalid
|
raise PeerIdInvalid
|
||||||
|
|
||||||
@ -1411,7 +1415,7 @@ class Client(Methods, BaseClient):
|
|||||||
file_id = file_id or self.rnd_id()
|
file_id = file_id or self.rnd_id()
|
||||||
md5_sum = md5() if not is_big and not is_missing_part else None
|
md5_sum = md5() if not is_big and not is_missing_part else None
|
||||||
|
|
||||||
session = Session(self, self.dc_id, self.auth_key, is_media=True)
|
session = Session(self, self.session_storage.dc_id, self.session_storage.auth_key, is_media=True)
|
||||||
session.start()
|
session.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1492,7 +1496,7 @@ class Client(Methods, BaseClient):
|
|||||||
session = self.media_sessions.get(dc_id, None)
|
session = self.media_sessions.get(dc_id, None)
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
if dc_id != self.dc_id:
|
if dc_id != self.session_storage.dc_id:
|
||||||
exported_auth = self.send(
|
exported_auth = self.send(
|
||||||
functions.auth.ExportAuthorization(
|
functions.auth.ExportAuthorization(
|
||||||
dc_id=dc_id
|
dc_id=dc_id
|
||||||
@ -1502,7 +1506,7 @@ class Client(Methods, BaseClient):
|
|||||||
session = Session(
|
session = Session(
|
||||||
self,
|
self,
|
||||||
dc_id,
|
dc_id,
|
||||||
Auth(dc_id, self.test_mode, self.ipv6, self._proxy).create(),
|
Auth(dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
|
||||||
is_media=True
|
is_media=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1520,7 +1524,7 @@ class Client(Methods, BaseClient):
|
|||||||
session = Session(
|
session = Session(
|
||||||
self,
|
self,
|
||||||
dc_id,
|
dc_id,
|
||||||
self.auth_key,
|
self.session_storage.auth_key,
|
||||||
is_media=True
|
is_media=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1588,7 +1592,7 @@ class Client(Methods, BaseClient):
|
|||||||
cdn_session = Session(
|
cdn_session = Session(
|
||||||
self,
|
self,
|
||||||
r.dc_id,
|
r.dc_id,
|
||||||
Auth(r.dc_id, self.test_mode, self.ipv6, self._proxy).create(),
|
Auth(r.dc_id, self.session_storage.test_mode, self.ipv6, self._proxy).create(),
|
||||||
is_media=True,
|
is_media=True,
|
||||||
is_cdn=True
|
is_cdn=True
|
||||||
)
|
)
|
||||||
|
@ -24,10 +24,10 @@ from threading import Lock
|
|||||||
from pyrogram import __version__
|
from pyrogram import __version__
|
||||||
from ..style import Markdown, HTML
|
from ..style import Markdown, HTML
|
||||||
from ...session.internals import MsgId
|
from ...session.internals import MsgId
|
||||||
from ..session_storage import SessionStorageMixin, BaseSessionStorage
|
from ..session_storage import SessionStorage
|
||||||
|
|
||||||
|
|
||||||
class BaseClient(SessionStorageMixin):
|
class BaseClient:
|
||||||
class StopTransmission(StopIteration):
|
class StopTransmission(StopIteration):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -68,14 +68,14 @@ class BaseClient(SessionStorageMixin):
|
|||||||
13: "video_note"
|
13: "video_note"
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, session_storage: BaseSessionStorage):
|
def __init__(self, session_storage: SessionStorage):
|
||||||
self.session_storage = session_storage
|
self.session_storage = session_storage
|
||||||
|
|
||||||
self.rnd_id = MsgId
|
self.rnd_id = MsgId
|
||||||
self.channels_pts = {}
|
self.channels_pts = {}
|
||||||
|
|
||||||
self.markdown = Markdown(self.peers_by_id)
|
self.markdown = Markdown(self.session_storage.peers_by_id)
|
||||||
self.html = HTML(self.peers_by_id)
|
self.html = HTML(self.session_storage.peers_by_id)
|
||||||
|
|
||||||
self.session = None
|
self.session = None
|
||||||
self.media_sessions = {}
|
self.media_sessions = {}
|
||||||
|
@ -81,9 +81,9 @@ class Syncer:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sync(cls, client):
|
def sync(cls, client):
|
||||||
client.date = int(time.time())
|
client.session_storage.date = int(time.time())
|
||||||
try:
|
try:
|
||||||
client.session_storage.save_session(sync=True)
|
client.session_storage.save(sync=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.critical(e, exc_info=True)
|
log.critical(e, exc_info=True)
|
||||||
else:
|
else:
|
||||||
|
@ -44,5 +44,5 @@ class GetContacts(BaseClient):
|
|||||||
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
|
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
|
||||||
time.sleep(e.x)
|
time.sleep(e.x)
|
||||||
else:
|
else:
|
||||||
log.info("Total contacts: {}".format(len(self.peers_by_phone)))
|
log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone)))
|
||||||
return [pyrogram.User._parse(self, user) for user in contacts.users]
|
return [pyrogram.User._parse(self, user) for user in contacts.users]
|
||||||
|
@ -16,5 +16,7 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from .session_storage_mixin import SessionStorageMixin
|
from .abstract import SessionStorage, SessionDoesNotExist
|
||||||
from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist
|
from .memory import MemorySessionStorage
|
||||||
|
from .json import JsonSessionStorage
|
||||||
|
from .string import StringSessionStorage
|
||||||
|
@ -16,66 +16,103 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from typing import Dict
|
import abc
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import pyrogram
|
||||||
|
|
||||||
|
|
||||||
class SessionStorageMixin:
|
class SessionDoesNotExist(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SessionStorage(abc.ABC):
|
||||||
|
def __init__(self, client: 'pyrogram.client.BaseClient'):
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def save(self, sync=False):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def sync_cleanup(self):
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dc_id(self) -> int:
|
@abc.abstractmethod
|
||||||
return self.session_storage.dc_id
|
def dc_id(self):
|
||||||
|
...
|
||||||
|
|
||||||
@dc_id.setter
|
@dc_id.setter
|
||||||
|
@abc.abstractmethod
|
||||||
def dc_id(self, val):
|
def dc_id(self, val):
|
||||||
self.session_storage.dc_id = val
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def test_mode(self) -> bool:
|
@abc.abstractmethod
|
||||||
return self.session_storage.test_mode
|
def test_mode(self):
|
||||||
|
...
|
||||||
|
|
||||||
@test_mode.setter
|
@test_mode.setter
|
||||||
|
@abc.abstractmethod
|
||||||
def test_mode(self, val):
|
def test_mode(self, val):
|
||||||
self.session_storage.test_mode = val
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def auth_key(self) -> bytes:
|
@abc.abstractmethod
|
||||||
return self.session_storage.auth_key
|
def auth_key(self):
|
||||||
|
...
|
||||||
|
|
||||||
@auth_key.setter
|
@auth_key.setter
|
||||||
|
@abc.abstractmethod
|
||||||
def auth_key(self, val):
|
def auth_key(self, val):
|
||||||
self.session_storage.auth_key = val
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def user_id(self):
|
def user_id(self):
|
||||||
return self.session_storage.user_id
|
...
|
||||||
|
|
||||||
@user_id.setter
|
@user_id.setter
|
||||||
def user_id(self, val) -> int:
|
@abc.abstractmethod
|
||||||
self.session_storage.user_id = val
|
def user_id(self, val):
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def date(self) -> int:
|
@abc.abstractmethod
|
||||||
return self.session_storage.date
|
def date(self):
|
||||||
|
...
|
||||||
|
|
||||||
@date.setter
|
@date.setter
|
||||||
|
@abc.abstractmethod
|
||||||
def date(self, val):
|
def date(self, val):
|
||||||
self.session_storage.date = val
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
def is_bot(self):
|
def is_bot(self):
|
||||||
return self.session_storage.is_bot
|
...
|
||||||
|
|
||||||
@is_bot.setter
|
@is_bot.setter
|
||||||
def is_bot(self, val) -> int:
|
@abc.abstractmethod
|
||||||
self.session_storage.is_bot = val
|
def is_bot(self, val):
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def peers_by_id(self) -> Dict[str, int]:
|
@abc.abstractmethod
|
||||||
return self.session_storage.peers_by_id
|
def peers_by_id(self):
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def peers_by_username(self) -> Dict[str, int]:
|
@abc.abstractmethod
|
||||||
return self.session_storage.peers_by_username
|
def peers_by_username(self):
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def peers_by_phone(self) -> Dict[str, int]:
|
@abc.abstractmethod
|
||||||
return self.session_storage.peers_by_phone
|
def peers_by_phone(self):
|
||||||
|
...
|
@ -1,60 +0,0 @@
|
|||||||
# Pyrogram - Telegram MTProto API Client Library for Python
|
|
||||||
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
|
|
||||||
#
|
|
||||||
# This file is part of Pyrogram.
|
|
||||||
#
|
|
||||||
# Pyrogram is free software: you can redistribute it and/or modify
|
|
||||||
# it under the terms of the GNU Lesser General Public License as published
|
|
||||||
# by the Free Software Foundation, either version 3 of the License, or
|
|
||||||
# (at your option) any later version.
|
|
||||||
#
|
|
||||||
# Pyrogram is distributed in the hope that it will be useful,
|
|
||||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
# GNU Lesser General Public License for more details.
|
|
||||||
#
|
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
import abc
|
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import pyrogram
|
|
||||||
|
|
||||||
|
|
||||||
class SessionDoesNotExist(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSessionStorage(abc.ABC):
|
|
||||||
def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
|
|
||||||
self.client = client
|
|
||||||
self.session_data = session_data
|
|
||||||
self.dc_id = 1
|
|
||||||
self.test_mode = None
|
|
||||||
self.auth_key = None
|
|
||||||
self.user_id = None
|
|
||||||
self.date = 0
|
|
||||||
self.is_bot = False
|
|
||||||
self.peers_by_id = {}
|
|
||||||
self.peers_by_username = {}
|
|
||||||
self.peers_by_phone = {}
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def load_session(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def save_session(self, sync=False):
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def sync_cleanup(self):
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSessionConfig(abc.ABC):
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def session_storage_cls(self) -> Type[BaseSessionStorage]:
|
|
||||||
...
|
|
@ -22,21 +22,26 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
import pyrogram
|
||||||
from ..ext import utils
|
from ..ext import utils
|
||||||
from . import BaseSessionStorage, SessionDoesNotExist
|
from . import MemorySessionStorage, SessionDoesNotExist
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class JsonSessionStorage(BaseSessionStorage):
|
class JsonSessionStorage(MemorySessionStorage):
|
||||||
|
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str):
|
||||||
|
super(JsonSessionStorage, self).__init__(client)
|
||||||
|
self._session_name = session_name
|
||||||
|
|
||||||
def _get_file_name(self, name: str):
|
def _get_file_name(self, name: str):
|
||||||
if not name.endswith('.session'):
|
if not name.endswith('.session'):
|
||||||
name += '.session'
|
name += '.session'
|
||||||
return os.path.join(self.client.workdir, name)
|
return os.path.join(self._client.workdir, name)
|
||||||
|
|
||||||
def load_session(self):
|
def load(self):
|
||||||
file_path = self._get_file_name(self.session_data)
|
file_path = self._get_file_name(self._session_name)
|
||||||
log.info('Loading JSON session from {}'.format(file_path))
|
log.info('Loading JSON session from {}'.format(file_path))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -45,59 +50,59 @@ class JsonSessionStorage(BaseSessionStorage):
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise SessionDoesNotExist()
|
raise SessionDoesNotExist()
|
||||||
|
|
||||||
self.dc_id = s["dc_id"]
|
self._dc_id = s["dc_id"]
|
||||||
self.test_mode = s["test_mode"]
|
self._test_mode = s["test_mode"]
|
||||||
self.auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
|
self._auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
|
||||||
self.user_id = s["user_id"]
|
self._user_id = s["user_id"]
|
||||||
self.date = s.get("date", 0)
|
self._date = s.get("date", 0)
|
||||||
self.is_bot = s.get('is_bot', self.client.is_bot)
|
self._is_bot = s.get('is_bot', self._is_bot)
|
||||||
|
|
||||||
for k, v in s.get("peers_by_id", {}).items():
|
for k, v in s.get("peers_by_id", {}).items():
|
||||||
self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
||||||
|
|
||||||
for k, v in s.get("peers_by_username", {}).items():
|
for k, v in s.get("peers_by_username", {}).items():
|
||||||
peer = self.peers_by_id.get(v, None)
|
peer = self._peers_by_id.get(v, None)
|
||||||
|
|
||||||
if peer:
|
if peer:
|
||||||
self.peers_by_username[k] = peer
|
self._peers_by_username[k] = peer
|
||||||
|
|
||||||
for k, v in s.get("peers_by_phone", {}).items():
|
for k, v in s.get("peers_by_phone", {}).items():
|
||||||
peer = self.peers_by_id.get(v, None)
|
peer = self._peers_by_id.get(v, None)
|
||||||
|
|
||||||
if peer:
|
if peer:
|
||||||
self.peers_by_phone[k] = peer
|
self._peers_by_phone[k] = peer
|
||||||
|
|
||||||
def save_session(self, sync=False):
|
def save(self, sync=False):
|
||||||
file_path = self._get_file_name(self.session_data)
|
file_path = self._get_file_name(self._session_name)
|
||||||
|
|
||||||
if sync:
|
if sync:
|
||||||
file_path += '.tmp'
|
file_path += '.tmp'
|
||||||
|
|
||||||
log.info('Saving JSON session to {}, sync={}'.format(file_path, sync))
|
log.info('Saving JSON session to {}, sync={}'.format(file_path, sync))
|
||||||
|
|
||||||
auth_key = base64.b64encode(self.auth_key).decode()
|
auth_key = base64.b64encode(self._auth_key).decode()
|
||||||
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars
|
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] # split key in lines of 43 chars
|
||||||
|
|
||||||
os.makedirs(self.client.workdir, exist_ok=True)
|
os.makedirs(self._client.workdir, exist_ok=True)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'dc_id': self.dc_id,
|
'dc_id': self._dc_id,
|
||||||
'test_mode': self.test_mode,
|
'test_mode': self._test_mode,
|
||||||
'auth_key': auth_key,
|
'auth_key': auth_key,
|
||||||
'user_id': self.user_id,
|
'user_id': self._user_id,
|
||||||
'date': self.date,
|
'date': self._date,
|
||||||
'is_bot': self.is_bot,
|
'is_bot': self._is_bot,
|
||||||
'peers_by_id': {
|
'peers_by_id': {
|
||||||
k: getattr(v, "access_hash", None)
|
k: getattr(v, "access_hash", None)
|
||||||
for k, v in self.peers_by_id.copy().items()
|
for k, v in self._peers_by_id.copy().items()
|
||||||
},
|
},
|
||||||
'peers_by_username': {
|
'peers_by_username': {
|
||||||
k: utils.get_peer_id(v)
|
k: utils.get_peer_id(v)
|
||||||
for k, v in self.peers_by_username.copy().items()
|
for k, v in self._peers_by_username.copy().items()
|
||||||
},
|
},
|
||||||
'peers_by_phone': {
|
'peers_by_phone': {
|
||||||
k: utils.get_peer_id(v)
|
k: utils.get_peer_id(v)
|
||||||
for k, v in self.peers_by_phone.copy().items()
|
for k, v in self._peers_by_phone.copy().items()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,10 +114,10 @@ class JsonSessionStorage(BaseSessionStorage):
|
|||||||
|
|
||||||
# execution won't be here if an error has occurred earlier
|
# execution won't be here if an error has occurred earlier
|
||||||
if sync:
|
if sync:
|
||||||
shutil.move(file_path, self._get_file_name(self.session_data))
|
shutil.move(file_path, self._get_file_name(self._session_name))
|
||||||
|
|
||||||
def sync_cleanup(self):
|
def sync_cleanup(self):
|
||||||
try:
|
try:
|
||||||
os.remove(self._get_file_name(self.session_data) + '.tmp')
|
os.remove(self._get_file_name(self._session_name) + '.tmp')
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
85
pyrogram/client/session_storage/memory.py
Normal file
85
pyrogram/client/session_storage/memory.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import pyrogram
|
||||||
|
from . import SessionStorage, SessionDoesNotExist
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySessionStorage(SessionStorage):
|
||||||
|
def __init__(self, client: 'pyrogram.client.ext.BaseClient'):
|
||||||
|
super(MemorySessionStorage, self).__init__(client)
|
||||||
|
self._dc_id = 1
|
||||||
|
self._test_mode = None
|
||||||
|
self._auth_key = None
|
||||||
|
self._user_id = None
|
||||||
|
self._date = 0
|
||||||
|
self._is_bot = False
|
||||||
|
self._peers_by_id = {}
|
||||||
|
self._peers_by_username = {}
|
||||||
|
self._peers_by_phone = {}
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
raise SessionDoesNotExist()
|
||||||
|
|
||||||
|
def save(self, sync=False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sync_cleanup(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dc_id(self):
|
||||||
|
return self._dc_id
|
||||||
|
|
||||||
|
@dc_id.setter
|
||||||
|
def dc_id(self, val):
|
||||||
|
self._dc_id = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_mode(self):
|
||||||
|
return self._test_mode
|
||||||
|
|
||||||
|
@test_mode.setter
|
||||||
|
def test_mode(self, val):
|
||||||
|
self._test_mode = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_key(self):
|
||||||
|
return self._auth_key
|
||||||
|
|
||||||
|
@auth_key.setter
|
||||||
|
def auth_key(self, val):
|
||||||
|
self._auth_key = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_id(self):
|
||||||
|
return self._user_id
|
||||||
|
|
||||||
|
@user_id.setter
|
||||||
|
def user_id(self, val):
|
||||||
|
self._user_id = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def date(self):
|
||||||
|
return self._date
|
||||||
|
|
||||||
|
@date.setter
|
||||||
|
def date(self, val):
|
||||||
|
self._date = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_bot(self):
|
||||||
|
return self._is_bot
|
||||||
|
|
||||||
|
@is_bot.setter
|
||||||
|
def is_bot(self, val):
|
||||||
|
self._is_bot = val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def peers_by_id(self):
|
||||||
|
return self._peers_by_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def peers_by_username(self):
|
||||||
|
return self._peers_by_username
|
||||||
|
|
||||||
|
@property
|
||||||
|
def peers_by_phone(self):
|
||||||
|
return self._peers_by_phone
|
@ -2,10 +2,11 @@ import base64
|
|||||||
import binascii
|
import binascii
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
from . import BaseSessionStorage, SessionDoesNotExist
|
import pyrogram
|
||||||
|
from . import MemorySessionStorage, SessionDoesNotExist
|
||||||
|
|
||||||
|
|
||||||
class StringSessionStorage(BaseSessionStorage):
|
class StringSessionStorage(MemorySessionStorage):
|
||||||
"""
|
"""
|
||||||
Packs session data as following (forcing little-endian byte order):
|
Packs session data as following (forcing little-endian byte order):
|
||||||
Char dc_id (1 byte, unsigned)
|
Char dc_id (1 byte, unsigned)
|
||||||
@ -18,22 +19,26 @@ class StringSessionStorage(BaseSessionStorage):
|
|||||||
"""
|
"""
|
||||||
PACK_FORMAT = '<B?q?256s'
|
PACK_FORMAT = '<B?q?256s'
|
||||||
|
|
||||||
|
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_string: str):
|
||||||
|
super(StringSessionStorage, self).__init__(client)
|
||||||
|
self._session_string = session_string
|
||||||
|
|
||||||
def _unpack(self, data):
|
def _unpack(self, data):
|
||||||
return struct.unpack(self.PACK_FORMAT, data)
|
return struct.unpack(self.PACK_FORMAT, data)
|
||||||
|
|
||||||
def _pack(self):
|
def _pack(self):
|
||||||
return struct.pack(self.PACK_FORMAT, self.dc_id, self.test_mode, self.user_id, self.is_bot, self.auth_key)
|
return struct.pack(self.PACK_FORMAT, self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key)
|
||||||
|
|
||||||
def load_session(self):
|
def load(self):
|
||||||
try:
|
try:
|
||||||
session_string = self.session_data[1:]
|
session_string = self._session_string[1:]
|
||||||
session_string += '=' * (4 - len(session_string) % 4) # restore padding
|
session_string += '=' * (4 - len(session_string) % 4) # restore padding
|
||||||
decoded = base64.b64decode(session_string, b'-_')
|
decoded = base64.b64decode(session_string, b'-_')
|
||||||
self.dc_id, self.test_mode, self.user_id, self.is_bot, self.auth_key = self._unpack(decoded)
|
self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key = self._unpack(decoded)
|
||||||
except (struct.error, binascii.Error):
|
except (struct.error, binascii.Error):
|
||||||
raise SessionDoesNotExist()
|
raise SessionDoesNotExist()
|
||||||
|
|
||||||
def save_session(self, sync=False):
|
def save(self, sync=False):
|
||||||
if not sync:
|
if not sync:
|
||||||
packed = self._pack()
|
packed = self._pack()
|
||||||
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
|
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
|
@ -112,7 +112,8 @@ class Session:
|
|||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
while True:
|
while True:
|
||||||
self.connection = Connection(self.dc_id, self.client.test_mode, self.client.ipv6, self.client.proxy)
|
self.connection = Connection(self.dc_id, self.client.session_storage.test_mode,
|
||||||
|
self.client.ipv6, self.client.proxy)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.connection.connect()
|
self.connection.connect()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user