mirror of
https://github.com/pyrogram/pyrogram
synced 2025-08-29 13:27:47 +00:00
Merge develop -> asyncio
This commit is contained in:
commit
826885a821
@ -40,7 +40,6 @@ Welcome to Pyrogram
|
||||
topics/more-on-updates
|
||||
topics/config-file
|
||||
topics/smart-plugins
|
||||
topics/auto-auth
|
||||
topics/session-settings
|
||||
topics/tgcrypto
|
||||
topics/storage-engines
|
||||
|
@ -1,68 +0,0 @@
|
||||
Auto Authorization
|
||||
==================
|
||||
|
||||
Manually writing phone number, phone code and password on the terminal every time you want to login can be tedious.
|
||||
Pyrogram is able to automate both **Log In** and **Sign Up** processes, all you need to do is pass the relevant
|
||||
parameters when creating a new :class:`~pyrogram.Client`.
|
||||
|
||||
.. note:: If you omit any of the optional parameter required for the authorization, Pyrogram will ask you to
|
||||
manually write it. For instance, if you don't want to set a ``last_name`` when creating a new account you
|
||||
have to explicitly pass an empty string ""; the default value (None) will trigger the input() call.
|
||||
|
||||
Log In
|
||||
-------
|
||||
|
||||
To automate the **Log In** process, pass your ``phone_number`` and ``password`` (if you have one) in the Client parameters.
|
||||
If you want to retrieve the phone code programmatically, pass a callback function in the ``phone_code`` field — this
|
||||
function accepts a single positional argument (phone_number) and must return the correct phone code (e.g., "12345")
|
||||
— otherwise, ignore this parameter, Pyrogram will ask you to input the phone code manually.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pyrogram import Client
|
||||
|
||||
def phone_code_callback(phone_number):
|
||||
code = ... # Get your code programmatically
|
||||
return code # e.g., "12345"
|
||||
|
||||
|
||||
app = Client(
|
||||
session_name="example",
|
||||
phone_number="39**********",
|
||||
phone_code=phone_code_callback, # Note the missing parentheses
|
||||
password="password" # (if you have one)
|
||||
)
|
||||
|
||||
with app:
|
||||
print(app.get_me())
|
||||
|
||||
Sign Up
|
||||
-------
|
||||
|
||||
To automate the **Sign Up** process (i.e., automatically create a new Telegram account), simply fill **both**
|
||||
``first_name`` and ``last_name`` fields alongside the other parameters; they will be used to automatically create a new
|
||||
Telegram account in case the phone number you passed is not registered yet.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pyrogram import Client
|
||||
|
||||
def phone_code_callback(phone_number):
|
||||
code = ... # Get your code programmatically
|
||||
return code # e.g., "12345"
|
||||
|
||||
|
||||
app = Client(
|
||||
session_name="example",
|
||||
phone_number="39**********",
|
||||
phone_code=phone_code_callback, # Note the missing parentheses
|
||||
first_name="Pyrogram",
|
||||
last_name="" # Can be an empty string
|
||||
)
|
||||
|
||||
with app:
|
||||
print(app.get_me())
|
@ -266,13 +266,13 @@ class Client(Methods, BaseClient):
|
||||
self.load_config()
|
||||
await self.load_session()
|
||||
|
||||
self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
|
||||
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
|
||||
|
||||
await self.session.start()
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
return bool(self.storage.user_id)
|
||||
return bool(self.storage.user_id())
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect the client from Telegram servers.
|
||||
@ -402,9 +402,9 @@ class Client(Methods, BaseClient):
|
||||
except (PhoneMigrate, NetworkMigrate) as e:
|
||||
await self.session.stop()
|
||||
|
||||
self.storage.dc_id = e.x
|
||||
self.storage.auth_key = await Auth(self, self.storage.dc_id).create()
|
||||
self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
|
||||
self.storage.dc_id(e.x)
|
||||
self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
|
||||
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
|
||||
|
||||
await self.session.start()
|
||||
else:
|
||||
@ -480,8 +480,8 @@ class Client(Methods, BaseClient):
|
||||
|
||||
return False
|
||||
else:
|
||||
self.storage.user_id = r.user.id
|
||||
self.storage.is_bot = False
|
||||
self.storage.user_id(r.user.id)
|
||||
self.storage.is_bot(False)
|
||||
|
||||
return User._parse(self, r.user)
|
||||
|
||||
@ -518,8 +518,8 @@ class Client(Methods, BaseClient):
|
||||
)
|
||||
)
|
||||
|
||||
self.storage.user_id = r.user.id
|
||||
self.storage.is_bot = False
|
||||
self.storage.user_id(r.user.id)
|
||||
self.storage.is_bot(False)
|
||||
|
||||
return User._parse(self, r.user)
|
||||
|
||||
@ -549,14 +549,14 @@ class Client(Methods, BaseClient):
|
||||
except UserMigrate as e:
|
||||
await self.session.stop()
|
||||
|
||||
self.storage.dc_id = e.x
|
||||
self.storage.auth_key = await Auth(self, self.storage.dc_id).create()
|
||||
self.session = Session(self, self.storage.dc_id, self.storage.auth_key)
|
||||
self.storage.dc_id(e.x)
|
||||
self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
|
||||
self.session = Session(self, self.storage.dc_id(), self.storage.auth_key())
|
||||
|
||||
await self.session.start()
|
||||
else:
|
||||
self.storage.user_id = r.user.id
|
||||
self.storage.is_bot = True
|
||||
self.storage.user_id(r.user.id)
|
||||
self.storage.is_bot(True)
|
||||
|
||||
return User._parse(self, r.user)
|
||||
|
||||
@ -590,8 +590,8 @@ class Client(Methods, BaseClient):
|
||||
)
|
||||
)
|
||||
|
||||
self.storage.user_id = r.user.id
|
||||
self.storage.is_bot = False
|
||||
self.storage.user_id(r.user.id)
|
||||
self.storage.is_bot(False)
|
||||
|
||||
return User._parse(self, r.user)
|
||||
|
||||
@ -627,8 +627,8 @@ class Client(Methods, BaseClient):
|
||||
)
|
||||
)
|
||||
|
||||
self.storage.user_id = r.user.id
|
||||
self.storage.is_bot = False
|
||||
self.storage.user_id(r.user.id)
|
||||
self.storage.is_bot(False)
|
||||
|
||||
return User._parse(self, r.user)
|
||||
|
||||
@ -784,7 +784,7 @@ class Client(Methods, BaseClient):
|
||||
async def log_out(self):
|
||||
"""Log out from Telegram and delete the *\\*.session* file.
|
||||
|
||||
When you log out, the current client is stopped and the storage session destroyed.
|
||||
When you log out, the current client is stopped and the storage session deleted.
|
||||
No more API calls can be made until you start the client and re-authorize again.
|
||||
|
||||
Returns:
|
||||
@ -798,7 +798,7 @@ class Client(Methods, BaseClient):
|
||||
"""
|
||||
await self.send(functions.auth.LogOut())
|
||||
await self.stop()
|
||||
self.storage.destroy()
|
||||
self.storage.delete()
|
||||
|
||||
return True
|
||||
|
||||
@ -833,7 +833,7 @@ class Client(Methods, BaseClient):
|
||||
if not is_authorized:
|
||||
await self.authorize()
|
||||
|
||||
if not self.storage.is_bot and self.takeout:
|
||||
if not self.storage.is_bot() and self.takeout:
|
||||
self.takeout_id = (await self.send(functions.account.InitTakeoutSession())).id
|
||||
log.warning("Takeout session {} initiated".format(self.takeout_id))
|
||||
|
||||
@ -1176,41 +1176,24 @@ class Client(Methods, BaseClient):
|
||||
|
||||
self.parse_mode = parse_mode
|
||||
|
||||
def fetch_peers(
|
||||
self,
|
||||
peers: List[
|
||||
Union[
|
||||
types.User,
|
||||
types.Chat, types.ChatForbidden,
|
||||
types.Channel, types.ChannelForbidden
|
||||
]
|
||||
]
|
||||
) -> bool:
|
||||
def fetch_peers(self, peers: List[Union[types.User, types.Chat, types.Channel]]) -> bool:
|
||||
is_min = False
|
||||
parsed_peers = []
|
||||
|
||||
for peer in peers:
|
||||
if getattr(peer, "min", False):
|
||||
is_min = True
|
||||
continue
|
||||
|
||||
username = None
|
||||
phone_number = None
|
||||
|
||||
if isinstance(peer, types.User):
|
||||
peer_id = peer.id
|
||||
access_hash = peer.access_hash
|
||||
|
||||
username = peer.username
|
||||
username = (peer.username or "").lower() or None
|
||||
phone_number = peer.phone
|
||||
|
||||
if peer.bot:
|
||||
peer_type = "bot"
|
||||
else:
|
||||
peer_type = "user"
|
||||
|
||||
if access_hash is None:
|
||||
is_min = True
|
||||
continue
|
||||
|
||||
if username is not None:
|
||||
username = username.lower()
|
||||
peer_type = "bot" if peer.bot else "user"
|
||||
elif isinstance(peer, (types.Chat, types.ChatForbidden)):
|
||||
peer_id = -peer.id
|
||||
access_hash = 0
|
||||
@ -1218,20 +1201,8 @@ class Client(Methods, BaseClient):
|
||||
elif isinstance(peer, (types.Channel, types.ChannelForbidden)):
|
||||
peer_id = utils.get_channel_id(peer.id)
|
||||
access_hash = peer.access_hash
|
||||
|
||||
username = getattr(peer, "username", None)
|
||||
|
||||
if peer.broadcast:
|
||||
peer_type = "channel"
|
||||
else:
|
||||
peer_type = "supergroup"
|
||||
|
||||
if access_hash is None:
|
||||
is_min = True
|
||||
continue
|
||||
|
||||
if username is not None:
|
||||
username = username.lower()
|
||||
username = (getattr(peer, "username", None) or "").lower() or None
|
||||
peer_type = "channel" if peer.broadcast else "supergroup"
|
||||
else:
|
||||
continue
|
||||
|
||||
@ -1494,20 +1465,20 @@ class Client(Methods, BaseClient):
|
||||
self.storage.open()
|
||||
|
||||
session_empty = any([
|
||||
self.storage.test_mode is None,
|
||||
self.storage.auth_key is None,
|
||||
self.storage.user_id is None,
|
||||
self.storage.is_bot is None
|
||||
self.storage.test_mode() is None,
|
||||
self.storage.auth_key() is None,
|
||||
self.storage.user_id() is None,
|
||||
self.storage.is_bot() is None
|
||||
])
|
||||
|
||||
if session_empty:
|
||||
self.storage.dc_id = 2
|
||||
self.storage.date = 0
|
||||
self.storage.dc_id(2)
|
||||
self.storage.date(0)
|
||||
|
||||
self.storage.test_mode = self.test_mode
|
||||
self.storage.auth_key = await Auth(self, self.storage.dc_id).create()
|
||||
self.storage.user_id = None
|
||||
self.storage.is_bot = None
|
||||
self.storage.test_mode(self.test_mode)
|
||||
self.storage.auth_key(await Auth(self, self.storage.dc_id()).create())
|
||||
self.storage.user_id(None)
|
||||
self.storage.is_bot(None)
|
||||
|
||||
def load_plugins(self):
|
||||
if self.plugins:
|
||||
@ -1715,7 +1686,7 @@ class Client(Methods, BaseClient):
|
||||
except KeyError:
|
||||
raise PeerIdInvalid
|
||||
|
||||
peer_type = utils.get_type(peer_id)
|
||||
peer_type = utils.get_peer_type(peer_id)
|
||||
|
||||
if peer_type == "user":
|
||||
self.fetch_peers(
|
||||
@ -1836,7 +1807,7 @@ class Client(Methods, BaseClient):
|
||||
is_missing_part = file_id is not None
|
||||
file_id = file_id or self.rnd_id()
|
||||
md5_sum = md5() if not is_big and not is_missing_part else None
|
||||
pool = [Session(self, self.storage.dc_id, self.storage.auth_key, is_media=True) for _ in range(pool_size)]
|
||||
pool = [Session(self, self.storage.dc_id(), self.storage.auth_key(), is_media=True) for _ in range(pool_size)]
|
||||
workers = [asyncio.ensure_future(worker(session)) for session in pool for _ in range(workers_count)]
|
||||
queue = asyncio.Queue(16)
|
||||
|
||||
@ -1926,7 +1897,7 @@ class Client(Methods, BaseClient):
|
||||
session = self.media_sessions.get(dc_id, None)
|
||||
|
||||
if session is None:
|
||||
if dc_id != self.storage.dc_id:
|
||||
if dc_id != self.storage.dc_id():
|
||||
session = Session(self, dc_id, await Auth(self, dc_id).create(), is_media=True)
|
||||
await session.start()
|
||||
|
||||
@ -1952,7 +1923,7 @@ class Client(Methods, BaseClient):
|
||||
await session.stop()
|
||||
raise AuthBytesInvalid
|
||||
else:
|
||||
session = Session(self, dc_id, self.storage.auth_key, is_media=True)
|
||||
session = Session(self, dc_id, self.storage.auth_key(), is_media=True)
|
||||
await session.start()
|
||||
|
||||
self.media_sessions[dc_id] = session
|
||||
|
@ -227,7 +227,7 @@ def get_peer_id(peer: Union[PeerUser, PeerChat, PeerChannel]) -> int:
|
||||
raise ValueError("Peer type invalid: {}".format(peer))
|
||||
|
||||
|
||||
def get_type(peer_id: int) -> str:
|
||||
def get_peer_type(peer_id: int) -> str:
|
||||
if peer_id < 0:
|
||||
if MIN_CHAT_ID <= peer_id:
|
||||
return "chat"
|
||||
|
@ -22,34 +22,29 @@ import logging
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
from .memory_storage import MemoryStorage
|
||||
from .sqlite_storage import SQLiteStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStorage(MemoryStorage):
|
||||
class FileStorage(SQLiteStorage):
|
||||
FILE_EXTENSION = ".session"
|
||||
|
||||
def __init__(self, name: str, workdir: Path):
|
||||
super().__init__(name)
|
||||
|
||||
self.workdir = workdir
|
||||
self.database = workdir / (self.name + self.FILE_EXTENSION)
|
||||
self.conn = None # type: sqlite3.Connection
|
||||
self.lock = Lock()
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
def migrate_from_json(self, session_json: dict):
|
||||
self.open()
|
||||
|
||||
self.dc_id = session_json["dc_id"]
|
||||
self.test_mode = session_json["test_mode"]
|
||||
self.auth_key = base64.b64decode("".join(session_json["auth_key"]))
|
||||
self.user_id = session_json["user_id"]
|
||||
self.date = session_json.get("date", 0)
|
||||
self.is_bot = session_json.get("is_bot", False)
|
||||
self.dc_id(session_json["dc_id"])
|
||||
self.test_mode(session_json["test_mode"])
|
||||
self.auth_key(base64.b64decode("".join(session_json["auth_key"])))
|
||||
self.user_id(session_json["user_id"])
|
||||
self.date(session_json.get("date", 0))
|
||||
self.is_bot(session_json.get("is_bot", False))
|
||||
|
||||
peers_by_id = session_json.get("peers_by_id", {})
|
||||
peers_by_phone = session_json.get("peers_by_phone", {})
|
||||
@ -72,6 +67,17 @@ class FileStorage(MemoryStorage):
|
||||
# noinspection PyTypeChecker
|
||||
self.update_peers(peers.values())
|
||||
|
||||
def update(self):
|
||||
version = self.version()
|
||||
|
||||
if version == 1:
|
||||
with self.lock, self.conn:
|
||||
self.conn.execute("DELETE FROM peers")
|
||||
|
||||
version += 1
|
||||
|
||||
self.version(version)
|
||||
|
||||
def open(self):
|
||||
path = self.database
|
||||
file_exists = path.is_file()
|
||||
@ -98,14 +104,12 @@ class FileStorage(MemoryStorage):
|
||||
if Path(path.name + ".OLD").is_file():
|
||||
log.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name))
|
||||
|
||||
self.conn = sqlite3.connect(
|
||||
str(path),
|
||||
timeout=1,
|
||||
check_same_thread=False
|
||||
)
|
||||
self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False)
|
||||
|
||||
if not file_exists:
|
||||
self.create()
|
||||
else:
|
||||
self.update()
|
||||
|
||||
with self.conn:
|
||||
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
|
||||
@ -113,5 +117,5 @@ class FileStorage(MemoryStorage):
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
|
||||
def destroy(self):
|
||||
def delete(self):
|
||||
os.remove(self.database)
|
||||
|
@ -17,226 +17,37 @@
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
import logging
|
||||
import sqlite3
|
||||
import struct
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import List, Tuple
|
||||
|
||||
from pyrogram.api import types
|
||||
from pyrogram.client.storage.storage import Storage
|
||||
from .sqlite_storage import SQLiteStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryStorage(Storage):
|
||||
SCHEMA_VERSION = 1
|
||||
USERNAME_TTL = 8 * 60 * 60
|
||||
SESSION_STRING_FMT = ">B?256sI?"
|
||||
SESSION_STRING_SIZE = 351
|
||||
|
||||
class MemoryStorage(SQLiteStorage):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
|
||||
self.conn = None # type: sqlite3.Connection
|
||||
self.lock = Lock()
|
||||
|
||||
def create(self):
|
||||
with self.lock, self.conn:
|
||||
with open(str(Path(__file__).parent / "schema.sql"), "r") as schema:
|
||||
self.conn.executescript(schema.read())
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO version VALUES (?)",
|
||||
(self.SCHEMA_VERSION,)
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(1, None, None, 0, None, None)
|
||||
)
|
||||
|
||||
def _import_session_string(self, session_string: str):
|
||||
decoded = base64.urlsafe_b64decode(session_string + "=" * (-len(session_string) % 4))
|
||||
return struct.unpack(self.SESSION_STRING_FMT, decoded)
|
||||
|
||||
def export_session_string(self):
|
||||
packed = struct.pack(
|
||||
self.SESSION_STRING_FMT,
|
||||
self.dc_id,
|
||||
self.test_mode,
|
||||
self.auth_key,
|
||||
self.user_id,
|
||||
self.is_bot
|
||||
)
|
||||
|
||||
return base64.urlsafe_b64encode(packed).decode().rstrip("=")
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
def open(self):
|
||||
self.conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
self.create()
|
||||
|
||||
if self.name != ":memory:":
|
||||
imported_session_string = self._import_session_string(self.name)
|
||||
dc_id, test_mode, auth_key, user_id, is_bot = struct.unpack(
|
||||
self.SESSION_STRING_FORMAT,
|
||||
base64.urlsafe_b64decode(
|
||||
self.name + "=" * (-len(self.name) % 4)
|
||||
)
|
||||
)
|
||||
|
||||
self.dc_id, self.test_mode, self.auth_key, self.user_id, self.is_bot = imported_session_string
|
||||
self.date = 0
|
||||
self.dc_id(dc_id)
|
||||
self.test_mode(test_mode)
|
||||
self.auth_key(auth_key)
|
||||
self.user_id(user_id)
|
||||
self.is_bot(is_bot)
|
||||
self.date(0)
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
def save(self):
|
||||
self.date = int(time.time())
|
||||
|
||||
with self.lock:
|
||||
self.conn.commit()
|
||||
|
||||
def close(self):
|
||||
with self.lock:
|
||||
self.conn.close()
|
||||
|
||||
def destroy(self):
|
||||
def delete(self):
|
||||
pass
|
||||
|
||||
def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
|
||||
with self.lock:
|
||||
self.conn.executemany(
|
||||
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
peers
|
||||
)
|
||||
|
||||
def clear_peers(self):
|
||||
with self.lock, self.conn:
|
||||
self.conn.execute(
|
||||
"DELETE FROM peers"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
||||
if peer_type in ["user", "bot"]:
|
||||
return types.InputPeerUser(
|
||||
user_id=peer_id,
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
if peer_type == "group":
|
||||
return types.InputPeerChat(
|
||||
chat_id=-peer_id
|
||||
)
|
||||
|
||||
if peer_type in ["channel", "supergroup"]:
|
||||
return types.InputPeerChannel(
|
||||
channel_id=int(str(peer_id)[4:]),
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
raise ValueError("Invalid peer type: {}".format(peer_type))
|
||||
|
||||
def get_peer_by_id(self, peer_id: int):
|
||||
r = self.conn.execute(
|
||||
"SELECT id, access_hash, type FROM peers WHERE id = ?",
|
||||
(peer_id,)
|
||||
).fetchone()
|
||||
|
||||
if r is None:
|
||||
raise KeyError("ID not found: {}".format(peer_id))
|
||||
|
||||
return self._get_input_peer(*r)
|
||||
|
||||
def get_peer_by_username(self, username: str):
|
||||
r = self.conn.execute(
|
||||
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?",
|
||||
(username,)
|
||||
).fetchone()
|
||||
|
||||
if r is None:
|
||||
raise KeyError("Username not found: {}".format(username))
|
||||
|
||||
if abs(time.time() - r[3]) > self.USERNAME_TTL:
|
||||
raise KeyError("Username expired: {}".format(username))
|
||||
|
||||
return self._get_input_peer(*r[:3])
|
||||
|
||||
def get_peer_by_phone_number(self, phone_number: str):
|
||||
r = self.conn.execute(
|
||||
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
|
||||
(phone_number,)
|
||||
).fetchone()
|
||||
|
||||
if r is None:
|
||||
raise KeyError("Phone number not found: {}".format(phone_number))
|
||||
|
||||
return self._get_input_peer(*r)
|
||||
|
||||
@property
|
||||
def peers_count(self):
|
||||
return self.conn.execute(
|
||||
"SELECT COUNT(*) FROM peers"
|
||||
).fetchone()[0]
|
||||
|
||||
def _get(self):
|
||||
attr = inspect.stack()[1].function
|
||||
|
||||
return self.conn.execute(
|
||||
"SELECT {} FROM sessions".format(attr)
|
||||
).fetchone()[0]
|
||||
|
||||
def _set(self, value):
|
||||
attr = inspect.stack()[1].function
|
||||
|
||||
with self.lock, self.conn:
|
||||
self.conn.execute(
|
||||
"UPDATE sessions SET {} = ?".format(attr),
|
||||
(value,)
|
||||
)
|
||||
|
||||
@property
|
||||
def dc_id(self):
|
||||
return self._get()
|
||||
|
||||
@dc_id.setter
|
||||
def dc_id(self, value):
|
||||
self._set(value)
|
||||
|
||||
@property
|
||||
def test_mode(self):
|
||||
return self._get()
|
||||
|
||||
@test_mode.setter
|
||||
def test_mode(self, value):
|
||||
self._set(value)
|
||||
|
||||
@property
|
||||
def auth_key(self):
|
||||
return self._get()
|
||||
|
||||
@auth_key.setter
|
||||
def auth_key(self, value):
|
||||
self._set(value)
|
||||
|
||||
@property
|
||||
def date(self):
|
||||
return self._get()
|
||||
|
||||
@date.setter
|
||||
def date(self, value):
|
||||
self._set(value)
|
||||
|
||||
@property
|
||||
def user_id(self):
|
||||
return self._get()
|
||||
|
||||
@user_id.setter
|
||||
def user_id(self, value):
|
||||
self._set(value)
|
||||
|
||||
@property
|
||||
def is_bot(self):
|
||||
return self._get()
|
||||
|
||||
@is_bot.setter
|
||||
def is_bot(self, value):
|
||||
self._set(value)
|
||||
|
184
pyrogram/client/storage/sqlite_storage.py
Normal file
184
pyrogram/client/storage/sqlite_storage.py
Normal file
@ -0,0 +1,184 @@
|
||||
# Pyrogram - Telegram MTProto API Client Library for Python
|
||||
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
|
||||
#
|
||||
# This file is part of Pyrogram.
|
||||
#
|
||||
# Pyrogram is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Pyrogram is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import inspect
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
from pyrogram.api import types
|
||||
from pyrogram.client.ext import utils
|
||||
from .storage import Storage
|
||||
|
||||
|
||||
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
||||
if peer_type in ["user", "bot"]:
|
||||
return types.InputPeerUser(
|
||||
user_id=peer_id,
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
if peer_type == "group":
|
||||
return types.InputPeerChat(
|
||||
chat_id=-peer_id
|
||||
)
|
||||
|
||||
if peer_type in ["channel", "supergroup"]:
|
||||
return types.InputPeerChannel(
|
||||
channel_id=utils.get_channel_id(peer_id),
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
raise ValueError("Invalid peer type: {}".format(peer_type))
|
||||
|
||||
|
||||
class SQLiteStorage(Storage):
|
||||
VERSION = 2
|
||||
USERNAME_TTL = 8 * 60 * 60
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
|
||||
self.conn = None # type: sqlite3.Connection
|
||||
self.lock = Lock()
|
||||
|
||||
def create(self):
|
||||
with self.lock, self.conn:
|
||||
with open(str(Path(__file__).parent / "schema.sql"), "r") as schema:
|
||||
self.conn.executescript(schema.read())
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO version VALUES (?)",
|
||||
(self.VERSION,)
|
||||
)
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(2, None, None, 0, None, None)
|
||||
)
|
||||
|
||||
def open(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
self.date(int(time.time()))
|
||||
|
||||
with self.lock:
|
||||
self.conn.commit()
|
||||
|
||||
def close(self):
|
||||
with self.lock:
|
||||
self.conn.close()
|
||||
|
||||
def delete(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
|
||||
with self.lock:
|
||||
self.conn.executemany(
|
||||
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
peers
|
||||
)
|
||||
|
||||
def get_peer_by_id(self, peer_id: int):
|
||||
r = self.conn.execute(
|
||||
"SELECT id, access_hash, type FROM peers WHERE id = ?",
|
||||
(peer_id,)
|
||||
).fetchone()
|
||||
|
||||
if r is None:
|
||||
raise KeyError("ID not found: {}".format(peer_id))
|
||||
|
||||
return get_input_peer(*r)
|
||||
|
||||
def get_peer_by_username(self, username: str):
|
||||
r = self.conn.execute(
|
||||
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?",
|
||||
(username,)
|
||||
).fetchone()
|
||||
|
||||
if r is None:
|
||||
raise KeyError("Username not found: {}".format(username))
|
||||
|
||||
if abs(time.time() - r[3]) > self.USERNAME_TTL:
|
||||
raise KeyError("Username expired: {}".format(username))
|
||||
|
||||
return get_input_peer(*r[:3])
|
||||
|
||||
def get_peer_by_phone_number(self, phone_number: str):
|
||||
r = self.conn.execute(
|
||||
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
|
||||
(phone_number,)
|
||||
).fetchone()
|
||||
|
||||
if r is None:
|
||||
raise KeyError("Phone number not found: {}".format(phone_number))
|
||||
|
||||
return get_input_peer(*r)
|
||||
|
||||
def _get(self):
|
||||
attr = inspect.stack()[2].function
|
||||
|
||||
return self.conn.execute(
|
||||
"SELECT {} FROM sessions".format(attr)
|
||||
).fetchone()[0]
|
||||
|
||||
def _set(self, value: Any):
|
||||
attr = inspect.stack()[2].function
|
||||
|
||||
with self.lock, self.conn:
|
||||
self.conn.execute(
|
||||
"UPDATE sessions SET {} = ?".format(attr),
|
||||
(value,)
|
||||
)
|
||||
|
||||
def _accessor(self, value: Any = object):
|
||||
return self._get() if value == object else self._set(value)
|
||||
|
||||
def dc_id(self, value: int = object):
|
||||
return self._accessor(value)
|
||||
|
||||
def test_mode(self, value: bool = object):
|
||||
return self._accessor(value)
|
||||
|
||||
def auth_key(self, value: bytes = object):
|
||||
return self._accessor(value)
|
||||
|
||||
def date(self, value: int = object):
|
||||
return self._accessor(value)
|
||||
|
||||
def user_id(self, value: int = object):
|
||||
return self._accessor(value)
|
||||
|
||||
def is_bot(self, value: bool = object):
|
||||
return self._accessor(value)
|
||||
|
||||
def version(self, value: int = object):
|
||||
if value == object:
|
||||
return self.conn.execute(
|
||||
"SELECT number FROM version"
|
||||
).fetchone()[0]
|
||||
else:
|
||||
with self.lock, self.conn:
|
||||
self.conn.execute(
|
||||
"UPDATE version SET number = ?",
|
||||
(value,)
|
||||
)
|
@ -16,8 +16,15 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import base64
|
||||
import struct
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class Storage:
|
||||
SESSION_STRING_FORMAT = ">B?256sI?"
|
||||
SESSION_STRING_SIZE = 351
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
@ -30,72 +37,47 @@ class Storage:
|
||||
def close(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
def delete(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_peers(self, peers):
|
||||
def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_peer_by_id(self, peer_id):
|
||||
def get_peer_by_id(self, peer_id: int):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_peer_by_username(self, username):
|
||||
def get_peer_by_username(self, username: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_peer_by_phone_number(self, phone_number):
|
||||
def get_peer_by_phone_number(self, phone_number: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def dc_id(self, value: int = object):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_mode(self, value: bool = object):
|
||||
raise NotImplementedError
|
||||
|
||||
def auth_key(self, value: bytes = object):
|
||||
raise NotImplementedError
|
||||
|
||||
def date(self, value: int = object):
|
||||
raise NotImplementedError
|
||||
|
||||
def user_id(self, value: int = object):
|
||||
raise NotImplementedError
|
||||
|
||||
def is_bot(self, value: bool = object):
|
||||
raise NotImplementedError
|
||||
|
||||
def export_session_string(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def peers_count(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def dc_id(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@dc_id.setter
|
||||
def dc_id(self, value):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def test_mode(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@test_mode.setter
|
||||
def test_mode(self, value):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def auth_key(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@auth_key.setter
|
||||
def auth_key(self, value):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def date(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@date.setter
|
||||
def date(self, value):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def user_id(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@user_id.setter
|
||||
def user_id(self, value):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_bot(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@is_bot.setter
|
||||
def is_bot(self, value):
|
||||
raise NotImplementedError
|
||||
return base64.urlsafe_b64encode(
|
||||
struct.pack(
|
||||
self.SESSION_STRING_FORMAT,
|
||||
self.dc_id(),
|
||||
self.test_mode(),
|
||||
self.auth_key(),
|
||||
self.user_id(),
|
||||
self.is_bot()
|
||||
)
|
||||
).decode().rstrip("=")
|
||||
|
@ -162,7 +162,7 @@ class Chat(Object):
|
||||
username=user.username,
|
||||
first_name=user.first_name,
|
||||
last_name=user.last_name,
|
||||
photo=ChatPhoto._parse(client, user.photo, peer_id),
|
||||
photo=ChatPhoto._parse(client, user.photo, peer_id, user.access_hash),
|
||||
restrictions=pyrogram.List([Restriction._parse(r) for r in user.restriction_reason]) or None,
|
||||
client=client
|
||||
)
|
||||
@ -175,7 +175,7 @@ class Chat(Object):
|
||||
id=peer_id,
|
||||
type="group",
|
||||
title=chat.title,
|
||||
photo=ChatPhoto._parse(client, getattr(chat, "photo", None), peer_id),
|
||||
photo=ChatPhoto._parse(client, getattr(chat, "photo", None), peer_id, 0),
|
||||
permissions=ChatPermissions._parse(getattr(chat, "default_banned_rights", None)),
|
||||
members_count=getattr(chat, "participants_count", None),
|
||||
client=client
|
||||
@ -194,7 +194,7 @@ class Chat(Object):
|
||||
is_scam=getattr(channel, "scam", None),
|
||||
title=channel.title,
|
||||
username=getattr(channel, "username", None),
|
||||
photo=ChatPhoto._parse(client, getattr(channel, "photo", None), peer_id),
|
||||
photo=ChatPhoto._parse(client, getattr(channel, "photo", None), peer_id, channel.access_hash),
|
||||
restrictions=pyrogram.List([Restriction._parse(r) for r in restriction_reason]) or None,
|
||||
permissions=ChatPermissions._parse(getattr(channel, "default_banned_rights", None)),
|
||||
members_count=getattr(channel, "participants_count", None),
|
||||
|
@ -20,6 +20,7 @@ from struct import pack
|
||||
|
||||
import pyrogram
|
||||
from pyrogram.api import types
|
||||
from pyrogram.client.ext import utils
|
||||
from ..object import Object
|
||||
from ...ext.utils import encode
|
||||
|
||||
@ -50,7 +51,7 @@ class ChatPhoto(Object):
|
||||
self.big_file_id = big_file_id
|
||||
|
||||
@staticmethod
|
||||
def _parse(client, chat_photo: types.UserProfilePhoto or types.ChatPhoto, peer_id: int):
|
||||
def _parse(client, chat_photo: types.UserProfilePhoto or types.ChatPhoto, peer_id: int, peer_access_hash: int):
|
||||
if not isinstance(chat_photo, (types.UserProfilePhoto, types.ChatPhoto)):
|
||||
return None
|
||||
|
||||
@ -58,24 +59,14 @@ class ChatPhoto(Object):
|
||||
loc_small = chat_photo.photo_small
|
||||
loc_big = chat_photo.photo_big
|
||||
|
||||
try:
|
||||
# We just want a local storage lookup by id, whose method is not async.
|
||||
# Otherwise we have to turn this _parse method async and also all the other methods that use this one.
|
||||
peer = client.storage.get_peer_by_id(peer_id)
|
||||
except KeyError:
|
||||
return None
|
||||
peer_type = utils.get_peer_type(peer_id)
|
||||
|
||||
if isinstance(peer, types.InputPeerUser):
|
||||
peer_id = peer.user_id
|
||||
peer_access_hash = peer.access_hash
|
||||
if peer_type == "user":
|
||||
x = 0
|
||||
elif isinstance(peer, types.InputPeerChat):
|
||||
peer_id = -peer.chat_id
|
||||
peer_access_hash = 0
|
||||
elif peer_type == "chat":
|
||||
x = -1
|
||||
else:
|
||||
peer_id += 1000727379968
|
||||
peer_access_hash = peer.access_hash
|
||||
x = -234
|
||||
|
||||
return ChatPhoto(
|
||||
|
@ -187,7 +187,7 @@ class User(Object, Update):
|
||||
language_code=user.lang_code,
|
||||
dc_id=getattr(user.photo, "dc_id", None),
|
||||
phone_number=user.phone,
|
||||
photo=ChatPhoto._parse(client, user.photo, user.id),
|
||||
photo=ChatPhoto._parse(client, user.photo, user.id, user.access_hash),
|
||||
restrictions=pyrogram.List([Restriction._parse(r) for r in user.restriction_reason]) or None,
|
||||
client=client
|
||||
)
|
||||
|
@ -38,7 +38,7 @@ class Auth:
|
||||
|
||||
def __init__(self, client: "pyrogram.Client", dc_id: int):
|
||||
self.dc_id = dc_id
|
||||
self.test_mode = client.storage.test_mode
|
||||
self.test_mode = client.storage.test_mode()
|
||||
self.ipv6 = client.ipv6
|
||||
self.proxy = client.proxy
|
||||
|
||||
|
@ -114,7 +114,7 @@ class Session:
|
||||
while True:
|
||||
self.connection = Connection(
|
||||
self.dc_id,
|
||||
self.client.storage.test_mode,
|
||||
self.client.storage.test_mode(),
|
||||
self.client.ipv6,
|
||||
self.client.proxy
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user