From b16bdbb36684c63df3eab5522af7ed90ebc851b0 Mon Sep 17 00:00:00 2001 From: Nick80835 Date: Mon, 24 Aug 2020 12:07:23 -0400 Subject: [PATCH] add command disabling and a database --- .gitignore | 2 +- ubot/command_handler.py | 10 ++++++++- ubot/database.py | 48 +++++++++++++++++++++++++++++++++++++++++ ubot/loader.py | 3 +++ ubot/modules/system.py | 41 +++++++++++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 ubot/database.py diff --git a/.gitignore b/.gitignore index 547d38b..d2f1244 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ __pycache__ settings.ini *.session *.session-journal -*.db +*.sqlite testing.py cache \ No newline at end of file diff --git a/ubot/command_handler.py b/ubot/command_handler.py index 20351ca..0cfdf09 100644 --- a/ubot/command_handler.py +++ b/ubot/command_handler.py @@ -49,10 +49,15 @@ class CommandHandler(): if value["pass_nsfw"]: event.nsfw_disabled = str(event.chat.id) in self.settings.get_list("nsfw_blacklist") + event.command = pattern_match.groups()[1] + + if event.command in self.loader.db.get_disabled_commands(event.chat.id): + print(f"Attempted command ({event.raw_text}) in chat which disabled it ({event.chat.id}) from ID {event.from_id}") + return + event.pattern_match = pattern_match event.args = pattern_match.groups()[-1].strip() event.other_args = pattern_match.groups()[2:-1] - event.command = pattern_match.groups()[1] event.extra = value["extra"] await self.execute_command(event, value) @@ -231,6 +236,9 @@ class CommandHandler(): return bool(str(event.from_id) in self.settings.get_list("sudo_users")) async def is_admin(self, event): + if event.is_private: + return True + channel_participant = await event.client(functions.channels.GetParticipantRequest(event.chat, event.from_id)) return bool(isinstance(channel_participant.participant, (types.ChannelParticipantAdmin, types.ChannelParticipantCreator))) diff --git a/ubot/database.py b/ubot/database.py new file mode 100644 index 0000000..dd2102d --- /dev/null +++ b/ubot/database.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: GPL-2.0-or-later + +import json +import sqlite3 + + +class Database(): + db_conn = sqlite3.connect("database.sqlite") + + def __init__(self): + cur = self.db_conn.cursor() + cur.execute( + """CREATE TABLE IF NOT EXISTS chats ( + id integer PRIMARY KEY, + disabled_commands text NOT NULL + );""" + ) + + def ensure_chat_table(self, chat_id: int): + cur = self.db_conn.cursor() + cur.execute("INSERT OR IGNORE INTO chats (id, disabled_commands) VALUES (?, ?);", [str(chat_id), "[]"]) + self.db_conn.commit() + + def get_disabled_commands(self, chat_id: int) -> list: + self.ensure_chat_table(chat_id) + cur = self.db_conn.cursor() + disabled_raw = cur.execute("SELECT disabled_commands FROM chats WHERE id = ?;", [str(chat_id)]).fetchone() + return json.loads(disabled_raw[0] if disabled_raw else "[]") + + def disable_command(self, chat_id: int, command: str): + disabled_commands = self.get_disabled_commands(chat_id) + + if command not in disabled_commands: + disabled_commands.append(command) + new_disabled_commands = json.dumps(disabled_commands) + cur = self.db_conn.cursor() + cur.execute("UPDATE chats SET disabled_commands = ? WHERE id = ?;", [new_disabled_commands, str(chat_id)]) + self.db_conn.commit() + + def enable_command(self, chat_id: int, command: str): + disabled_commands = self.get_disabled_commands(chat_id) + + if command in disabled_commands: + disabled_commands.remove(command) + new_disabled_commands = json.dumps(disabled_commands) + cur = self.db_conn.cursor() + cur.execute("UPDATE chats SET disabled_commands = ? WHERE id = ?;", [new_disabled_commands, str(chat_id)]) + self.db_conn.commit() diff --git a/ubot/loader.py b/ubot/loader.py index 38be17b..9f30e7f 100644 --- a/ubot/loader.py +++ b/ubot/loader.py @@ -6,16 +6,19 @@ from importlib import import_module, reload from os.path import basename, dirname, isfile from aiohttp import ClientSession + from telethon.tl.types import DocumentAttributeFilename from .cache import Cache from .command_handler import CommandHandler +from .database import Database class Loader(): aioclient = ClientSession() thread_pool = ThreadPoolExecutor() cache = Cache(aioclient) + db = Database() help_dict = {} help_hidden_dict = {} diff --git a/ubot/modules/system.py b/ubot/modules/system.py index 8169254..9728947 100644 --- a/ubot/modules/system.py +++ b/ubot/modules/system.py @@ -75,6 +75,47 @@ async def bot_repo(event): await event.reply("https://github.com/Nick80835/microbot") +@ldr.add("disable", admin=True, help="Disables commands in the current chat, requires admin.") +async def disable_command(event): + if event.args: + for value in ldr.help_dict.values(): + for info in [i[0] for i in value]: + if event.args == info: + await event.reply(f"Disabling **{info}** in chat **{event.chat.id}**!") + ldr.db.disable_command(event.chat.id, info) + return + + await event.reply(f"**{event.args}** is not a command!") + else: + await event.reply(f"Specify a command to disable!") + + +@ldr.add("enable", admin=True, help="Enables commands in the current chat, requires admin.") +async def enable_command(event): + if event.args: + for value in ldr.help_dict.values(): + for info in [i[0] for i in value]: + if event.args == info: + await event.reply(f"Enabling **{info}** in chat **{event.chat.id}**!") + ldr.db.enable_command(event.chat.id, info) + return + + await event.reply(f"**{event.args}** is not a command!") + else: + await event.reply(f"Specify a command to enable!") + + +@ldr.add("showdisabled", admin=True, help="Shows disabled commands in the current chat.") +async def show_disabled(event): + disabled_list = ldr.db.get_disabled_commands(event.chat.id) + + if disabled_list: + disabled_commands = "\n".join(ldr.db.get_disabled_commands(event.chat.id)) + await event.reply(f"Disabled commands in **{event.chat.id}**:\n\n{disabled_commands}") + else: + await event.reply(f"There are no disabled commands in **{event.chat.id}**!") + + @ldr.add("nsfw", admin=True, help="Enables or disables NSFW commands for a chat, requires admin.") async def nsfw_toggle(event): if not event.args or event.args not in ("on", "off"):