diff --git a/ubot/__init__.py b/ubot/__init__.py index 5e9108d..6679629 100644 --- a/ubot/__init__.py +++ b/ubot/__init__.py @@ -4,13 +4,15 @@ from logging import INFO, basicConfig, getLogger from time import time import telethon +from telethon import TelegramClient from telethon.errors.rpcerrorlist import (AccessTokenExpiredError, AccessTokenInvalidError, TokenInvalidError) from telethon.network.connection.tcpabridged import \ ConnectionTcpAbridged as CTA -from .custom import ExtendedEvent +from .custom import (ExtendedCallbackQuery, ExtendedInlineQuery, + ExtendedNewMessage) from .loader import Loader from .settings import Settings @@ -26,9 +28,9 @@ loop = asyncio.get_event_loop() class MicroBot(): settings = Settings() - client = None - logger = None - loader = None + logger = logger + client: TelegramClient + loader: Loader def __init__(self): loop.run_until_complete(self._initialize_bot()) @@ -37,7 +39,7 @@ class MicroBot(): global ldr try: - self.client = await telethon.TelegramClient( + self.client = await TelegramClient( self.settings.get_config("session_name", "bot0") or "bot0", self.settings.get_config("api_id"), self.settings.get_config("api_hash"), @@ -76,7 +78,10 @@ class MicroBot(): sys.exit(0) -telethon.events.NewMessage.Event = ExtendedEvent +telethon.events.NewMessage.Event = ExtendedNewMessage +telethon.events.CallbackQuery.Event = ExtendedCallbackQuery +telethon.events.InlineQuery.Event = ExtendedInlineQuery + micro_bot = MicroBot() try: diff --git a/ubot/command_handler.py b/ubot/command_handler.py index 807d334..3ab934c 100644 --- a/ubot/command_handler.py +++ b/ubot/command_handler.py @@ -8,6 +8,11 @@ from telethon import events from telethon.errors.rpcerrorlist import (ChatAdminRequiredError, ChatWriteForbiddenError) +from ubot.command import CallbackQueryCommand, Command +from ubot.custom import (ExtendedCallbackQuery, ExtendedInlineQuery, + ExtendedNewMessage) +from ubot.database import ChatWrapper + from .fixes import inline_photos @@ -39,7 +44,7 @@ class CommandHandler(): if not isinstance(exception, (ChatAdminRequiredError, ChatWriteForbiddenError)): await event.client.send_message(int(self.settings.get_list("owner_id")[0]), str(format_exc())) - async def handle_incoming(self, event): + async def handle_incoming(self, event: ExtendedNewMessage): chat_db = self.db.get_chat((await event.get_chat()).id) chat_prefix = chat_db.prefix @@ -100,7 +105,7 @@ class CommandHandler(): await self.execute_command(event, command) - async def handle_inline(self, event): + async def handle_inline(self, event: ExtendedInlineQuery): for command in self.inline_photo_commands: pattern_match = search(self.simple_pattern_template.format(command.pattern + command.pattern_extra), event.text, IGNORECASE|DOTALL) @@ -123,7 +128,7 @@ class CommandHandler(): await self.fallback_inline(event) - async def handle_inline_photo(self, event, pattern_match, command): + async def handle_inline_photo(self, event: ExtendedInlineQuery, pattern_match, command): builder = event.builder event.pattern_match = pattern_match event.args = pattern_match.groups()[-1] @@ -166,7 +171,7 @@ class CommandHandler(): except: print_exc() - async def handle_inline_article(self, event, pattern_match, command): + async def handle_inline_article(self, event: ExtendedInlineQuery, pattern_match, command): builder = event.builder event.pattern_match = pattern_match event.args = pattern_match.groups()[-1] @@ -204,7 +209,7 @@ class CommandHandler(): except: print_exc() - async def handle_callback_query(self, event): + async def handle_callback_query(self, event: ExtendedCallbackQuery): data_str = event.data.decode("utf-8") data_id = data_str.split("*")[0] data_data = data_str.removeprefix(data_id + "*") @@ -283,7 +288,7 @@ class CommandHandler(): # returns True if the command can be used, False if not, and an optional error string together in a tuple # for normal commands, this will be passed to event.reply; for callback queries this will call event.answer - async def check_privs(self, event, command, chat_db = None, callback_query = False) -> tuple[bool, str|None]: + async def check_privs(self, event, command: Command|CallbackQueryCommand, chat_db: ChatWrapper|None = None, callback_query = False) -> tuple[bool, str|None]: if self.is_blacklisted(event) and not self.is_owner(event) and not self.is_sudo(event): return (False, None) @@ -301,10 +306,10 @@ class CommandHandler(): if event.is_private or not (await event.client.get_permissions(event.chat, event.sender_id)).is_admin and not self.is_sudo(event) and not self.is_owner(event): return (False, None if command.silent_bail else "You lack the permissions to use that command!") - if not callback_query and event.chat and command.nsfw and not chat_db.nsfw_enabled: + if not callback_query and event.chat and command.nsfw and (chat_db and not chat_db.nsfw_enabled): return (False, None if command.silent_bail else command.nsfw_warning or "NSFW commands are disabled in this chat!") - if not callback_query and event.chat and command.fun and not chat_db.fun_enabled: + if not callback_query and event.chat and command.fun and (chat_db and not chat_db.fun_enabled): return (False, None) return (True, None) diff --git a/ubot/custom.py b/ubot/custom.py index e51e110..dbb8a02 100644 --- a/ubot/custom.py +++ b/ubot/custom.py @@ -1,10 +1,29 @@ +from re import Match +from typing import Any + +from telethon.events.callbackquery import CallbackQuery +from telethon.events.inlinequery import InlineQuery from telethon.events.newmessage import NewMessage from telethon.tl.types import (DocumentAttributeFilename, DocumentAttributeImageSize, DocumentAttributeSticker) +from ubot.command import (CallbackQueryCommand, Command, InlineArticleCommand, + InlinePhotoCommand) +from ubot.database import ChatWrapper + + +class ExtendedNewMessage(NewMessage.Event): + pattern_match: Match[str] # pattern match as returned by re.search when it's used in the command handler + chat_db: ChatWrapper # database reference for the chat this command was executed in + object: Command # the object constructed when the command associated with this event was added + command: str # the base command with no prefix, no args and no other_args; the whole pattern if raw_pattern is used + prefix: str # prefix used to call this command, such as "/" or "g."; not set if simple_pattern is used + extra: Any # any object you set to extra when registering the command associated with this event + args: str # anything after the command itself and any groups caught in other_args, such as booru tags + other_args: tuple # any groups between the args group and the command itself + nsfw_disabled: bool # only set if pass_nsfw is True; this value is the opposite of nsfw_enabled in chat_db -class ExtendedEvent(NewMessage.Event): async def get_text(self, return_msg=False, default=""): if self.args: if return_msg: @@ -58,3 +77,23 @@ class ExtendedEvent(NewMessage.Event): return await self.message.respond(*args, **kwargs|{"reply_to": self.reply_to.reply_to_msg_id}) return await self.message.respond(*args, **kwargs) + + +class ExtendedCallbackQuery(CallbackQuery.Event): + chat_db: ChatWrapper|None + object: CallbackQueryCommand + command: str + extra: Any + args: str + + +class ExtendedInlineQuery(InlineQuery.Event): + pattern_match: Match[str] + parse_mode: str + object: InlineArticleCommand|InlinePhotoCommand + command: str + extra: Any + args: str + other_args: tuple + nsfw_disabled: bool + link_preview: bool