diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 576dea9a..fe4de3ff 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -157,10 +157,8 @@ class Client(Methods, BaseClient): config_file (``str``, *optional*): Path of the configuration file. Defaults to ./config.ini - plugins_dir (``str``, *optional*): - Define a custom directory for your plugins. The plugins directory is the location in your - filesystem where Pyrogram will automatically load your update handlers. - Defaults to None (plugins disabled). + plugins (``dict``, *optional*): + TODO: doctrings no_updates (``bool``, *optional*): Pass True to completely disable incoming updates for the current session. @@ -197,7 +195,7 @@ class Client(Methods, BaseClient): workers: int = BaseClient.WORKERS, workdir: str = BaseClient.WORKDIR, config_file: str = BaseClient.CONFIG_FILE, - plugins_dir: str = None, + plugins: dict = None, no_updates: bool = None, takeout: bool = None): super().__init__() @@ -223,7 +221,7 @@ class Client(Methods, BaseClient): self.workers = workers self.workdir = workdir self.config_file = config_file - self.plugins_dir = plugins_dir + self.plugins = plugins self.no_updates = no_updates self.takeout = takeout @@ -1074,6 +1072,38 @@ class Client(Methods, BaseClient): self._proxy["username"] = parser.get("proxy", "username", fallback=None) or None self._proxy["password"] = parser.get("proxy", "password", fallback=None) or None + if self.plugins: + self.plugins["enabled"] = bool(self.plugins.get("enabled", True)) + else: + self.plugins = {} + + try: + section = parser["plugins"] + + include = section.get("include") or None + exclude = section.get("exclude") or None + + if include is not None: + include = [ + (i.split()[0], i.split()[1:] or None) + for i in include.strip().split("\n") + ] + + if exclude is not None: + exclude = [ + (i.split()[0], i.split()[1:] or None) + for i in exclude.strip().split("\n") + ] + + self.plugins["enabled"] = section.getboolean("enabled", True) + self.plugins["root"] = section.get("root") + self.plugins["include"] = include + self.plugins["exclude"] = exclude + except KeyError: + pass + else: + print(self.plugins) + def load_session(self): try: with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f: @@ -1105,43 +1135,112 @@ class Client(Methods, BaseClient): self.peers_by_phone[k] = peer def load_plugins(self): - if self.plugins_dir is not None: + if self.plugins.get("enabled", False): + root = self.plugins["root"] + include = self.plugins["include"] + exclude = self.plugins["exclude"] + plugins_count = 0 - for path in Path(self.plugins_dir).rglob("*.py"): - file_path = os.path.splitext(str(path))[0] - import_path = [] + if include is None: + for path in sorted(Path(root).rglob("*.py")): + module_path = os.path.splitext(str(path))[0].replace("/", ".") + module = import_module(module_path) - while file_path: - file_path, tail = os.path.split(file_path) - import_path.insert(0, tail) + for name in vars(module).keys(): + # noinspection PyBroadException + try: + handler, group = getattr(module, name) - import_path = ".".join(import_path) - module = import_module(import_path) + if isinstance(handler, Handler) and isinstance(group, int): + self.add_handler(handler, group) + + log.info('[LOAD] {}("{}") in group {} from "{}"'.format( + type(handler).__name__, name, group, module_path)) + + plugins_count += 1 + except Exception: + pass + else: + for path, handlers in include: + module_path = root + "." + path + warn_non_existent_functions = True - for name in dir(module): - # noinspection PyBroadException try: - handler, group = getattr(module, name) + module = import_module(module_path) + except ModuleNotFoundError: + log.warning('[LOAD] Ignoring non-existent module "{}"'.format(module_path)) + continue - if isinstance(handler, Handler) and isinstance(group, int): - self.add_handler(handler, group) + if "__path__" in dir(module): + log.warning('[LOAD] Ignoring namespace "{}"'.format(module_path)) + continue - log.info('{}("{}") from "{}" loaded in group {}'.format( - type(handler).__name__, name, import_path, group)) + if handlers is None: + handlers = vars(module).keys() + warn_non_existent_functions = False - plugins_count += 1 - except Exception: - pass + for name in handlers: + # noinspection PyBroadException + try: + handler, group = getattr(module, name) + + if isinstance(handler, Handler) and isinstance(group, int): + self.add_handler(handler, group) + + log.info('[LOAD] {}("{}") in group {} from "{}"'.format( + type(handler).__name__, name, group, module_path)) + + plugins_count += 1 + except Exception: + if warn_non_existent_functions: + log.warning('[LOAD] Ignoring non-existent function "{}" from "{}"'.format( + name, module_path)) + + if exclude is not None: + for path, handlers in exclude: + module_path = root + "." + path + warn_non_existent_functions = True + + try: + module = import_module(module_path) + except ModuleNotFoundError: + log.warning('[UNLOAD] Ignoring non-existent module "{}"'.format(module_path)) + continue + + if "__path__" in dir(module): + log.warning('[UNLOAD] Ignoring namespace "{}"'.format(module_path)) + continue + + if handlers is None: + handlers = vars(module).keys() + warn_non_existent_functions = False + + for name in handlers: + # noinspection PyBroadException + try: + handler, group = getattr(module, name) + + if isinstance(handler, Handler) and isinstance(group, int): + self.remove_handler(handler, group) + + log.info('[UNLOAD] {}("{}") from group {} in "{}"'.format( + type(handler).__name__, name, group, module_path)) + + plugins_count -= 1 + except Exception: + if warn_non_existent_functions: + log.warning('[UNLOAD] Ignoring non-existent function "{}" from "{}"'.format( + name, module_path)) if plugins_count > 0: log.warning('Successfully loaded {} plugin{} from "{}"'.format( plugins_count, "s" if plugins_count > 1 else "", - self.plugins_dir + root )) else: - log.warning('No plugin loaded: "{}" doesn\'t contain any valid plugin'.format(self.plugins_dir)) + log.warning('No plugin loaded: "{}" doesn\'t contain any valid plugin'.format(root)) def save_session(self): auth_key = base64.b64encode(self.auth_key).decode()