From 03b92b3302a9d316d4e693efa8b9d87b0b991fd0 Mon Sep 17 00:00:00 2001 From: bakatrouble Date: Tue, 26 Feb 2019 21:06:30 +0300 Subject: [PATCH] Implement SQLite session storage --- pyrogram/client/client.py | 2 +- pyrogram/client/session_storage/__init__.py | 1 + pyrogram/client/session_storage/json.py | 6 +- .../client/session_storage/sqlite/0001.sql | 21 +++ .../client/session_storage/sqlite/__init__.py | 132 ++++++++++++++++++ 5 files changed, 159 insertions(+), 3 deletions(-) create mode 100644 pyrogram/client/session_storage/sqlite/0001.sql create mode 100644 pyrogram/client/session_storage/sqlite/__init__.py diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index ad755977..5fc805c0 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -57,7 +57,7 @@ from .ext import utils, Syncer, BaseClient from .methods import Methods from .session_storage import ( SessionDoesNotExist, SessionStorage, MemorySessionStorage, JsonSessionStorage, - StringSessionStorage + StringSessionStorage, SQLiteSessionStorage ) log = logging.getLogger(__name__) diff --git a/pyrogram/client/session_storage/__init__.py b/pyrogram/client/session_storage/__init__.py index ad2d8900..adfcf813 100644 --- a/pyrogram/client/session_storage/__init__.py +++ b/pyrogram/client/session_storage/__init__.py @@ -20,3 +20,4 @@ from .abstract import SessionStorage, SessionDoesNotExist from .memory import MemorySessionStorage from .json import JsonSessionStorage from .string import StringSessionStorage +from .sqlite import SQLiteSessionStorage diff --git a/pyrogram/client/session_storage/json.py b/pyrogram/client/session_storage/json.py index aaa6b96f..570e1525 100644 --- a/pyrogram/client/session_storage/json.py +++ b/pyrogram/client/session_storage/json.py @@ -29,6 +29,8 @@ from . import MemorySessionStorage, SessionDoesNotExist log = logging.getLogger(__name__) +EXTENSION = '.session' + class JsonSessionStorage(MemorySessionStorage): def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str): @@ -36,8 +38,8 @@ class JsonSessionStorage(MemorySessionStorage): self._session_name = session_name def _get_file_name(self, name: str): - if not name.endswith('.session'): - name += '.session' + if not name.endswith(EXTENSION): + name += EXTENSION return os.path.join(self._client.workdir, name) def load(self): diff --git a/pyrogram/client/session_storage/sqlite/0001.sql b/pyrogram/client/session_storage/sqlite/0001.sql new file mode 100644 index 00000000..d81e9554 --- /dev/null +++ b/pyrogram/client/session_storage/sqlite/0001.sql @@ -0,0 +1,21 @@ +create table sessions ( + dc_id integer primary key, + test_mode integer, + auth_key blob, + user_id integer, + date integer, + is_bot integer +); + +create table peers_cache ( + id integer primary key, + hash integer, + username text, + phone integer +); + +create table migrations ( + name text primary key +); + +insert into migrations (name) values ('0001'); diff --git a/pyrogram/client/session_storage/sqlite/__init__.py b/pyrogram/client/session_storage/sqlite/__init__.py new file mode 100644 index 00000000..75931109 --- /dev/null +++ b/pyrogram/client/session_storage/sqlite/__init__.py @@ -0,0 +1,132 @@ +# Pyrogram - Telegram MTProto API Client Library for Python +# Copyright (C) 2017-2019 Dan Tès +# +# This file is part of Pyrogram. +# +# Pyrogram is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Pyrogram is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Pyrogram. If not, see . + +import logging +import os +import sqlite3 + +import pyrogram +from ....api import types +from ...ext import utils +from .. import MemorySessionStorage, SessionDoesNotExist + + +log = logging.getLogger(__name__) + +EXTENSION = '.session.sqlite3' +MIGRATIONS = ['0001'] + + +class SQLiteSessionStorage(MemorySessionStorage): + def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str): + super(SQLiteSessionStorage, self).__init__(client) + self._session_name = session_name + self._conn = None # type: sqlite3.Connection + + def _get_file_name(self, name: str): + if not name.endswith(EXTENSION): + name += EXTENSION + return os.path.join(self._client.workdir, name) + + def _apply_migrations(self, new_db=False): + migrations = MIGRATIONS.copy() + if not new_db: + cursor = self._conn.cursor() + cursor.execute('select name from migrations') + for row in cursor.fetchone(): + migrations.remove(row) + for name in migrations: + with open(os.path.join(os.path.dirname(__file__), '{}.sql'.format(name))) as script: + self._conn.executescript(script.read()) + + def load(self): + file_path = self._get_file_name(self._session_name) + log.info('Loading SQLite session from {}'.format(file_path)) + + if os.path.isfile(file_path): + self._conn = sqlite3.connect(file_path) + self._apply_migrations() + else: + self._conn = sqlite3.connect(file_path) + self._apply_migrations(new_db=True) + + cursor = self._conn.cursor() + cursor.execute('select dc_id, test_mode, auth_key, user_id, "date", is_bot from sessions') + row = cursor.fetchone() + if not row: + raise SessionDoesNotExist() + + self._dc_id = row[0] + self._test_mode = bool(row[1]) + self._auth_key = row[2] + self._user_id = row[3] + self._date = row[4] + self._is_bot = bool(row[5]) + + def cache_peer(self, entity): + peer_id = username = phone = access_hash = None + + if isinstance(entity, types.User): + peer_id = entity.id + username = entity.username.lower() if entity.username else None + phone = entity.phone or None + access_hash = entity.access_hash + elif isinstance(entity, (types.Chat, types.ChatForbidden)): + peer_id = -entity.id + # input_peer = types.InputPeerChat(chat_id=entity.id) + elif isinstance(entity, (types.Channel, types.ChannelForbidden)): + peer_id = int('-100' + str(entity.id)) + username = entity.username.lower() if hasattr(entity, 'username') and entity.username else None + access_hash = entity.access_hash + + self._conn.execute('insert or replace into peers_cache values (?, ?, ?, ?)', + (peer_id, access_hash, username, phone)) + + def get_peer_by_id(self, val): + cursor = self._conn.cursor() + cursor.execute('select id, hash from peers_cache where id = ?', (val,)) + row = cursor.fetchone() + if not row: + raise KeyError(val) + return utils.get_input_peer(row[0], row[1]) + + def get_peer_by_username(self, val): + cursor = self._conn.cursor() + cursor.execute('select id, hash from peers_cache where username = ?', (val,)) + row = cursor.fetchone() + if not row: + raise KeyError(val) + return utils.get_input_peer(row[0], row[1]) + + def get_peer_by_phone(self, val): + cursor = self._conn.cursor() + cursor.execute('select id, hash from peers_cache where phone = ?', (val,)) + row = cursor.fetchone() + if not row: + raise KeyError(val) + return utils.get_input_peer(row[0], row[1]) + + def save(self, sync=False): + log.info('Committing SQLite session') + self._conn.execute('delete from sessions') + self._conn.execute('insert into sessions values (?, ?, ?, ?, ?, ?)', + (self._dc_id, self._test_mode, self._auth_key, self._user_id, self._date, self._is_bot)) + self._conn.commit() + + def sync_cleanup(self): + pass