diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 3232e0b0..9e28d66b 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -51,22 +51,15 @@ from pyrogram.session import Auth, Session from pyrogram.session.internals import MsgId from . import message_parser from .dispatcher import Dispatcher +from . import utils from .input_media import InputMedia from .style import Markdown, HTML +from .syncer import Syncer from .utils import decode log = logging.getLogger(__name__) -class Proxy: - def __init__(self, enabled: bool, hostname: str, port: int, username: str, password: str): - self.enabled = enabled - self.hostname = hostname - self.port = port - self.username = username - self.password = password - - class Client: """This class represents a Client, the main mean for interacting with Telegram. It exposes bot-like methods for an easy access to the API as well as a simple way to @@ -133,6 +126,7 @@ class Client: DIALOGS_AT_ONCE = 100 UPDATES_WORKERS = 1 DOWNLOAD_WORKERS = 1 + OFFLINE_SLEEP = 300 MEDIA_TYPE_ID = { 0: "Thumbnail", @@ -150,7 +144,7 @@ class Client: session_name: str, api_id: int or str = None, api_hash: str = None, - proxy: dict or Proxy = None, + proxy: dict = None, test_mode: bool = False, phone_number: str = None, phone_code: str or callable = None, @@ -179,6 +173,7 @@ class Client: self.dc_id = None self.auth_key = None self.user_id = None + self.date = None self.rnd_id = MsgId @@ -269,7 +264,7 @@ class Client: self.session_name = self.session_name.split(":")[0] self.load_config() - self.load_session(self.session_name) + self.load_session() self.session = Session( self.dc_id, @@ -292,8 +287,14 @@ class Client: self.save_session() if self.token is None: - self.get_dialogs() - self.get_contacts() + now = time.time() + + if abs(now - self.date) > Client.OFFLINE_SLEEP: + self.get_dialogs() + self.get_contacts() + else: + self.send(functions.messages.GetPinnedDialogs()) + self.get_dialogs_chunk(0) else: self.send(functions.updates.GetState()) @@ -306,6 +307,7 @@ class Client: self.dispatcher.start() mimetypes.init() + Syncer.add(self) def stop(self): """Use this method to manually stop the Client. @@ -325,6 +327,8 @@ class Client: self.dispatcher.stop() + Syncer.remove(self) + def authorize_bot(self): try: r = self.send( @@ -835,35 +839,47 @@ class Client: "More info: https://docs.pyrogram.ml/start/ProjectSetup#configuration" ) - if self.proxy is not None: - self.proxy = Proxy( - enabled=True, - hostname=self.proxy["hostname"], - port=int(self.proxy["port"]), - username=self.proxy.get("username", None), - password=self.proxy.get("password", None) - ) - elif parser.has_section("proxy"): - self.proxy = Proxy( - enabled=parser.getboolean("proxy", "enabled"), - hostname=parser.get("proxy", "hostname"), - port=parser.getint("proxy", "port"), - username=parser.get("proxy", "username", fallback=None) or None, - password=parser.get("proxy", "password", fallback=None) or None - ) + if self.proxy: + pass + else: + self.proxy = {} - def load_session(self, session_name): + if parser.has_section("proxy"): + self.proxy["enabled"] = parser.getboolean("proxy", "enabled") + self.proxy["hostname"] = parser.get("proxy", "hostname") + self.proxy["port"] = parser.getint("proxy", "port") + self.proxy["username"] = parser.get("proxy", "username", fallback=None) or None + self.proxy["password"] = parser.get("proxy", "password", fallback=None) or None + + def load_session(self): try: - with open("{}.session".format(session_name), encoding="utf-8") as f: + with open("{}.session".format(self.session_name), encoding="utf-8") as f: s = json.load(f) except FileNotFoundError: self.dc_id = 1 + self.date = 0 self.auth_key = Auth(self.dc_id, self.test_mode, self.proxy).create() else: self.dc_id = s["dc_id"] self.test_mode = s["test_mode"] self.auth_key = base64.b64decode("".join(s["auth_key"])) self.user_id = s["user_id"] + self.date = s.get("date", 0) + + for k, v in s.get("peers_by_id", {}).items(): + self.peers_by_id[int(k)] = utils.get_input_peer(int(k), v) + + for k, v in s.get("peers_by_username", {}).items(): + peer = self.peers_by_id.get(v, None) + + if peer: + self.peers_by_username[k] = peer + + for k, v in s.get("peers_by_phone", {}).items(): + peer = self.peers_by_id.get(v, None) + + if peer: + self.peers_by_phone[k] = peer def save_session(self): auth_key = base64.b64encode(self.auth_key).decode() @@ -876,58 +892,40 @@ class Client: test_mode=self.test_mode, auth_key=auth_key, user_id=self.user_id, + date=self.date ), f, indent=4 ) - def get_dialogs(self): - def parse_dialogs(d): - for m in reversed(d.messages): - if isinstance(m, types.MessageEmpty): - continue - else: - return m.date - else: - return 0 - - pinned_dialogs = self.send(functions.messages.GetPinnedDialogs()) - parse_dialogs(pinned_dialogs) - - dialogs = self.send( + def get_dialogs_chunk(self, offset_date): + r = self.send( functions.messages.GetDialogs( - 0, 0, types.InputPeerEmpty(), + offset_date, 0, types.InputPeerEmpty(), self.DIALOGS_AT_ONCE, True ) ) + log.info("Total peers: {}".format(len(self.peers_by_id))) - offset_date = parse_dialogs(dialogs) - log.info("Entities count: {}".format(len(self.peers_by_id))) + return r + + def get_dialogs(self): + self.send(functions.messages.GetPinnedDialogs()) + + dialogs = self.get_dialogs_chunk(0) + offset_date = utils.get_offset_date(dialogs) while len(dialogs.dialogs) == self.DIALOGS_AT_ONCE: try: - dialogs = self.send( - functions.messages.GetDialogs( - offset_date, 0, types.InputPeerEmpty(), - self.DIALOGS_AT_ONCE, True - ) - ) + dialogs = self.get_dialogs_chunk(offset_date) except FloodWait as e: log.warning("get_dialogs flood: waiting {} seconds".format(e.x)) time.sleep(e.x) continue - offset_date = parse_dialogs(dialogs) - log.info("Entities count: {}".format(len(self.peers_by_id))) + offset_date = utils.get_offset_date(dialogs) - self.send( - functions.messages.GetDialogs( - 0, 0, types.InputPeerEmpty(), - self.DIALOGS_AT_ONCE, True - ) - ) - - log.info("Entities count: {}".format(len(self.peers_by_id))) + self.get_dialogs_chunk(0) def resolve_peer(self, peer_id: int or str): """Use this method to get the *InputPeer* of a known *peer_id*. @@ -2927,7 +2925,7 @@ class Client: continue else: if isinstance(contacts, types.contacts.Contacts): - log.info("Contacts count: {}".format(len(contacts.users))) + log.info("Total contacts: {}".format(len(self.peers_by_phone))) return contacts diff --git a/pyrogram/client/syncer.py b/pyrogram/client/syncer.py new file mode 100644 index 00000000..fd2a4959 --- /dev/null +++ b/pyrogram/client/syncer.py @@ -0,0 +1,107 @@ +import base64 +import json +import logging +import os +import shutil +import time +from threading import Thread, Event, Lock +from . import utils + +log = logging.getLogger(__name__) + + +class Syncer: + INTERVAL = 20 + + clients = {} + thread = None + event = Event() + lock = Lock() + + @classmethod + def add(cls, client): + with cls.lock: + cls.sync(client) + + cls.clients[id(client)] = client + + if len(cls.clients) == 1: + cls.start() + + @classmethod + def remove(cls, client): + with cls.lock: + cls.sync(client) + + del cls.clients[id(client)] + + if len(cls.clients) == 0: + cls.stop() + + @classmethod + def start(cls): + cls.event.clear() + cls.thread = Thread(target=cls.worker, name=cls.__name__) + cls.thread.start() + + @classmethod + def stop(cls): + cls.event.set() + + @classmethod + def worker(cls): + while True: + cls.event.wait(cls.INTERVAL) + + if cls.event.is_set(): + break + + with cls.lock: + for client in cls.clients.values(): + cls.sync(client) + + @classmethod + def sync(cls, client): + try: + auth_key = base64.b64encode(client.auth_key).decode() + auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] + + data = dict( + dc_id=client.dc_id, + test_mode=client.test_mode, + auth_key=auth_key, + user_id=client.user_id, + date=int(time.time()), + peers_by_id={ + k: getattr(v, "access_hash", None) + for k, v in client.peers_by_id.items() + }, + peers_by_username={ + k: utils.get_peer_id(v) + for k, v in client.peers_by_username.items() + }, + peers_by_phone={ + k: utils.get_peer_id(v) + for k, v in client.peers_by_phone.items() + } + ) + + with open("{}.sync".format(client.session_name), "w", encoding="utf-8") as f: + json.dump(data, f, indent=4) + + f.flush() + os.fsync(f.fileno()) + except Exception as e: + log.critical(e, exc_info=True) + else: + shutil.move( + "{}.sync".format(client.session_name), + "{}.session".format(client.session_name) + ) + + log.info("Synced {}".format(client.session_name)) + finally: + try: + os.remove("{}.sync".format(client.session_name)) + except OSError: + pass diff --git a/pyrogram/client/utils.py b/pyrogram/client/utils.py index dc258993..d4b1c38e 100644 --- a/pyrogram/client/utils.py +++ b/pyrogram/client/utils.py @@ -18,6 +18,35 @@ from base64 import b64decode, b64encode +from pyrogram.api import types + + +def get_peer_id(input_peer) -> int: + return ( + input_peer.user_id if isinstance(input_peer, types.InputPeerUser) + else -input_peer.chat_id if isinstance(input_peer, types.InputPeerChat) + else int("-100" + str(input_peer.channel_id)) + ) + + +def get_input_peer(peer_id: int, access_hash: int): + return ( + types.InputPeerUser(peer_id, access_hash) if peer_id > 0 + else types.InputPeerChannel(int(str(peer_id)[4:]), access_hash) + if (str(peer_id).startswith("-100") and access_hash) + else types.InputPeerChat(-peer_id) + ) + + +def get_offset_date(dialogs): + for m in reversed(dialogs.messages): + if isinstance(m, types.MessageEmpty): + continue + else: + return m.date + else: + return 0 + def decode(s: str) -> bytes: s = b64decode(s + "=" * (-len(s) % 4), "-_") diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 509cd10d..118f0d83 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -32,7 +32,7 @@ class Connection: 2: TCPIntermediate } - def __init__(self, address: tuple, proxy: type, mode: int = 1): + def __init__(self, address: tuple, proxy: dict, mode: int = 1): self.address = address self.proxy = proxy self.mode = self.MODES.get(mode, TCPAbridged) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 22b953c1..38005e57 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -18,7 +18,6 @@ import logging import socket -from collections import namedtuple try: import socks @@ -32,29 +31,25 @@ except ImportError as e: log = logging.getLogger(__name__) -Proxy = namedtuple("Proxy", ["enabled", "hostname", "port", "username", "password"]) - class TCP(socks.socksocket): - def __init__(self, proxy: Proxy): + def __init__(self, proxy: dict): super().__init__() self.settimeout(10) - self.proxy_enabled = False - - if proxy and proxy.enabled: - self.proxy_enabled = True + self.proxy_enabled = proxy.get("enabled", False) + if proxy and self.proxy_enabled: self.set_proxy( proxy_type=socks.SOCKS5, - addr=proxy.hostname, - port=proxy.port, - username=proxy.username, - password=proxy.password + addr=proxy["hostname"], + port=proxy["port"], + username=proxy["username"], + password=proxy["password"] ) log.info("Using proxy {}:{}".format( - proxy.hostname, - proxy.port + proxy["hostname"], + proxy["port"] )) def close(self): diff --git a/pyrogram/connection/transport/tcp/tcp_abridged.py b/pyrogram/connection/transport/tcp/tcp_abridged.py index acb837af..ad682fed 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged.py @@ -24,7 +24,7 @@ log = logging.getLogger(__name__) class TCPAbridged(TCP): - def __init__(self, proxy: type): + def __init__(self, proxy: dict): super().__init__(proxy) self.is_first_packet = None diff --git a/pyrogram/connection/transport/tcp/tcp_full.py b/pyrogram/connection/transport/tcp/tcp_full.py index 5c1dc2c7..1b131678 100644 --- a/pyrogram/connection/transport/tcp/tcp_full.py +++ b/pyrogram/connection/transport/tcp/tcp_full.py @@ -26,7 +26,7 @@ log = logging.getLogger(__name__) class TCPFull(TCP): - def __init__(self, proxy: type): + def __init__(self, proxy: dict): super().__init__(proxy) self.seq_no = None diff --git a/pyrogram/connection/transport/tcp/tcp_intermediate.py b/pyrogram/connection/transport/tcp/tcp_intermediate.py index 301a88f6..55a7d071 100644 --- a/pyrogram/connection/transport/tcp/tcp_intermediate.py +++ b/pyrogram/connection/transport/tcp/tcp_intermediate.py @@ -25,7 +25,7 @@ log = logging.getLogger(__name__) class TCPIntermediate(TCP): - def __init__(self, proxy: type): + def __init__(self, proxy: dict): super().__init__(proxy) self.is_first_packet = None diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index a1d8fd76..74d45845 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -46,7 +46,7 @@ class Auth: 16 ) - def __init__(self, dc_id: int, test_mode: bool, proxy: type): + def __init__(self, dc_id: int, test_mode: bool, proxy: dict): self.dc_id = dc_id self.test_mode = test_mode diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 2bd59908..5be2eaec 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -86,7 +86,7 @@ class Session: def __init__(self, dc_id: int, test_mode: bool, - proxy: type, + proxy: dict, auth_key: bytes, api_id: int, is_cdn: bool = False,