diff --git a/ubot/command_handler.py b/ubot/command_handler.py index 77bd5ed..037d3ac 100644 --- a/ubot/command_handler.py +++ b/ubot/command_handler.py @@ -1,4 +1,5 @@ import asyncio +from functools import partial from inspect import isawaitable from random import randint from re import DOTALL, IGNORECASE, escape, search @@ -25,6 +26,7 @@ MODERATION_COMMAND_COOLDOWN_SEC = 3 class CommandHandler(): incoming_commands = [] + incoming_lenient_commands = [] inline_photo_commands = [] inline_article_commands = [] callback_queries = [] @@ -36,22 +38,23 @@ class CommandHandler(): self.logger = loader.logger self.db = loader.db self.hard_prefix = self.settings.get_list("hard_cmd_prefix") or ["/"] - self.micro_bot.client.add_event_handler(self.report_incoming_excepts, events.NewMessage(incoming=True, forwards=False, func=lambda e: e.raw_text)) + self.micro_bot.client.add_event_handler(partial(self.report_incoming_excepts, self.incoming_commands), events.NewMessage(incoming=True, forwards=False, func=lambda e: e.raw_text)) + self.micro_bot.client.add_event_handler(partial(self.report_incoming_excepts, self.incoming_lenient_commands), events.NewMessage(incoming=True)) self.micro_bot.client.add_event_handler(self.handle_inline, events.InlineQuery()) self.micro_bot.client.add_event_handler(self.handle_callback_query, events.CallbackQuery()) - async def report_incoming_excepts(self, event): + async def report_incoming_excepts(self, command_list, event): try: - await self.handle_incoming(event) + await self.handle_incoming(event, command_list) except Exception as exception: if not isinstance(exception, (ChatAdminRequiredError, ChatWriteForbiddenError)): await event.client.send_message(int(self.settings.get_list("owner_id")[0]), str(format_exc())) - async def handle_incoming(self, event: ExtendedNewMessage): + async def handle_incoming(self, event: ExtendedNewMessage, command_list: list[Command]): chat_db = self.db.get_chat((await event.get_chat()).id) chat_prefix = chat_db.prefix - for command in self.incoming_commands: + for command in command_list: if command.simple_pattern: pattern_match = search(SIMPLE_PATTERN_TEMPLATE.format(command.pattern + command.pattern_extra), event.raw_text, IGNORECASE|DOTALL) elif command.raw_pattern: diff --git a/ubot/loader.py b/ubot/loader.py index d71c6f2..408373d 100644 --- a/ubot/loader.py +++ b/ubot/loader.py @@ -56,6 +56,7 @@ class Loader(): def reload_all_modules(self): self.command_handler.incoming_commands = [] + self.command_handler.incoming_lenient_commands = [] self.command_handler.inline_photo_commands = [] self.command_handler.inline_article_commands = [] self.command_handler.callback_queries = [] @@ -82,6 +83,15 @@ class Loader(): return decorator + def add_lenient(self, pattern: str = None, **args): + def decorator(func): + args["pattern"] = args.get("pattern", pattern) + self.command_handler.incoming_lenient_commands.append(Command(func, args)) + + return func + + return decorator + def add_list(self, pattern: list = None, **args): pattern_list = args.get("pattern", pattern)