diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py index dd2cf0c04c..fed86f6f33 100644 --- a/bin/tests/system/isctest/asyncserver.py +++ b/bin/tests/system/isctest/asyncserver.py @@ -17,9 +17,11 @@ from typing import ( AsyncGenerator, Callable, Coroutine, + Dict, List, Optional, Tuple, + Type, Union, cast, ) @@ -160,6 +162,7 @@ class AsyncServer: loop.run_until_complete(coroutine()) async def _run(self) -> None: + self._setup_exception_handler() self._setup_signals() assert self._work_done await self._listen_udp() @@ -177,9 +180,20 @@ class AsyncServer: loop = asyncio.get_event_loop() return loop - def _setup_signals(self) -> None: + def _setup_exception_handler(self) -> None: loop = self._get_asyncio_loop() self._work_done = loop.create_future() + loop.set_exception_handler(self._handle_exception) + + def _handle_exception( + self, _: asyncio.AbstractEventLoop, context: Dict[str, Any] + ) -> None: + assert self._work_done + exception = context.get("exception", RuntimeError(context["message"])) + self._work_done.set_exception(exception) + + def _setup_signals(self) -> None: + loop = self._get_asyncio_loop() loop.add_signal_handler(signal.SIGINT, functools.partial(self._signal_done)) loop.add_signal_handler(signal.SIGTERM, functools.partial(self._signal_done)) @@ -382,6 +396,9 @@ class ResponseHandler(abc.ABC): """ yield DnsResponseSend(qctx.response) + def __str__(self) -> str: + return self.__class__.__name__ + class IgnoreAllQueries(ResponseHandler): """ @@ -520,15 +537,31 @@ class AsyncDnsServer(AsyncServer): if load_zones: self._load_zones() - def install_response_handler(self, handler: ResponseHandler) -> None: + def install_response_handler( + self, handler: ResponseHandler, prepend: bool = False + ) -> None: """ - Add a response handler which will be used to handle matching queries. + Add a response handler that will be used to handle matching queries. Response handlers can modify, replace, or suppress the answers prepared from zone file contents. + + The provided handler is installed at the end of the response handler + list unless `prepend` is set to True, in which case it is installed at + the beginning of the response handler list. """ logging.info("Installing response handler: %s", handler) - self._response_handlers.append(handler) + if prepend: + self._response_handlers.insert(0, handler) + else: + self._response_handlers.append(handler) + + def uninstall_response_handler(self, handler: ResponseHandler) -> None: + """ + Remove the specified handler from the list of response handlers. + """ + logging.info("Uninstalling response handler: %s", handler) + self._response_handlers.remove(handler) def _load_zones(self) -> None: for entry in os.scandir(): @@ -568,7 +601,12 @@ class AsyncDnsServer(AsyncServer): logging.debug("Closing TCP connection from %s", peer) writer.close() - await writer.wait_closed() + try: + # Python >= 3.7 + await writer.wait_closed() + except AttributeError: + # Python < 3.7 + pass async def _read_tcp_query( self, reader: asyncio.StreamReader, peer: Peer @@ -711,7 +749,11 @@ class AsyncDnsServer(AsyncServer): """ Yield wire data to send as a response over the established transport. """ - query = dns.message.from_wire(wire) + try: + query = dns.message.from_wire(wire) + except dns.exception.DNSException as exc: + logging.error("Invalid query from %s (%s): %s", peer, wire.hex(), exc) + return response_stub = dns.message.make_response(query) qctx = QueryContext(query, response_stub, peer, protocol) self._log_query(qctx, peer, protocol) @@ -741,6 +783,7 @@ class AsyncDnsServer(AsyncServer): response_handled = True if not response_handled: + logging.debug("Responding based on zone data") yield qctx.response def _prepare_response_from_zone_data(self, qctx: QueryContext) -> None: @@ -874,6 +917,194 @@ class AsyncDnsServer(AsyncServer): """ for handler in self._response_handlers: if handler.match(qctx): + logging.debug("Matched response handler: %s", handler) async for response in handler.get_responses(qctx): yield response return + + +class ControllableAsyncDnsServer(AsyncDnsServer): + """ + An AsyncDnsServer whose behavior can be dynamically changed by sending TXT + queries to a "magic" domain. + """ + + _CONTROL_DOMAIN = "_control." + + def __init__(self, commands: List[Type["ControlCommand"]]): + super().__init__() + self._control_domain = dns.name.from_text(self._CONTROL_DOMAIN) + self._commands: Dict[dns.name.Name, "ControlCommand"] = {} + for command_class in commands: + command = command_class() + command_subdomain = dns.name.Name([command.control_subdomain]) + control_subdomain = command_subdomain.concatenate(self._control_domain) + try: + existing_command = self._commands[control_subdomain] + except KeyError: + self._commands[control_subdomain] = command + else: + raise RuntimeError( + f"{control_subdomain} already handled by {existing_command}" + ) + + async def _prepare_responses( + self, qctx: QueryContext + ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]: + """ + Detect and handle control queries, falling back to normal processing + for non-control queries. + """ + control_response = self._handle_control_command(qctx) + if control_response: + yield await DnsResponseSend(response=control_response).perform() + return + + async for response in super()._prepare_responses(qctx): + yield response + + def _handle_control_command( + self, qctx: QueryContext + ) -> Optional[dns.message.Message]: + """ + Detect and handle control queries. + + A control query must be of type TXT; if it is not, a FORMERR response + is sent back. + + The list of commands that the server should respond to is passed to its + constructor. If the server is unable to handle the control query using + any of the enabled commands, an NXDOMAIN response is sent. + + Otherwise, the relevant command's handler is expected to provide the + response via qctx.response and/or return a string that is converted to + a TXT RRset inserted into the ANSWER section of the response to the + control query. The RCODE for a command-provided response defaults to + NOERROR, but can be overridden by the command's handler. + """ + if not qctx.qname.is_subdomain(self._control_domain): + return None + + if qctx.qtype != dns.rdatatype.TXT: + logging.error("Non-TXT control query %s from %s", qctx.qname, qctx.peer) + qctx.response.set_rcode(dns.rcode.FORMERR) + return qctx.response + + control_subdomain = dns.name.Name(qctx.qname.labels[-3:]) + try: + command = self._commands[control_subdomain] + except KeyError: + logging.error("Unhandled control query %s from %s", qctx.qname, qctx.peer) + qctx.response.set_rcode(dns.rcode.NXDOMAIN) + return qctx.response + + logging.info("Received control query %s from %s", qctx.qname, qctx.peer) + logging.debug("Handling control query %s using %s", qctx.qname, command) + qctx.response.set_rcode(dns.rcode.NOERROR) + qctx.response.flags |= dns.flags.AA + + command_qname = qctx.qname.relativize(control_subdomain) + try: + command_args = [l.decode("ascii") for l in command_qname.labels] + except UnicodeDecodeError: + logging.error("Non-ASCII control query %s from %s", qctx.qname, qctx.peer) + qctx.response.set_rcode(dns.rcode.FORMERR) + return qctx.response + + command_response = command.handle(command_args, self, qctx) + if command_response: + command_response_rrset = dns.rrset.from_text( + qctx.qname, 0, qctx.qclass, dns.rdatatype.TXT, f'"{command_response}"' + ) + qctx.response.answer.append(command_response_rrset) + + return qctx.response + + +class ControlCommand(abc.ABC): + """ + Base class for control commands. + + The derived class must define the control query subdomain that it handles + and the callback that handles the control queries. + """ + + @property + @abc.abstractmethod + def control_subdomain(self) -> str: + """ + The subdomain of the control domain handled by this command. Needs to + be defined as a string by the derived class. + """ + raise NotImplementedError + + @abc.abstractmethod + def handle( + self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext + ) -> Optional[str]: + """ + This method is expected to carry out arbitrary actions in response to a + control query. Note that it is invoked synchronously (it is not a + coroutine). + + `args` is a list of arguments for the command extracted from the + control query's QNAME; these arguments (and therefore the QNAME as + well) must only contain ASCII characters. For example, if a command's + subdomain is `my-command`, control query `foo.bar.my-command._control.` + causes `args` to be set to `["foo", "bar"]` while control query + `my-command._control.` causes `args` to be set to `[]`. + + `server` is the server instance that received the control query. This + method can change the server's behavior by altering its response + handler list using the appropriate methods. + + `qctx` is the query context for the control query. By operating on + qctx.response, this method can prepare the DNS response sent to + the client in response to the control query. Alternatively (or in + addition to the above), it can also return a string; if it does, the + returned string is converted to a TXT RRset that is inserted into the + ANSWER section of the response to the control query. + """ + raise NotImplementedError + + def __str__(self) -> str: + return self.__class__.__name__ + + +class ToggleResponsesCommand(ControlCommand): + """ + Disable/enable sending responses from the server. + """ + + control_subdomain = "send-responses" + + def __init__(self) -> None: + self._current_handler: Optional[IgnoreAllQueries] = None + + def handle( + self, args: List[str], server: ControllableAsyncDnsServer, qctx: QueryContext + ) -> Optional[str]: + if len(args) != 1: + logging.error("Invalid %s query %s", self, qctx.qname) + qctx.response.set_rcode(dns.rcode.SERVFAIL) + return "invalid query; use exactly one of 'enable' or 'disable' in QNAME" + + mode = args[0] + + if mode == "disable": + if self._current_handler: + return "sending responses already disabled" + self._current_handler = IgnoreAllQueries() + server.install_response_handler(self._current_handler, prepend=True) + return "sending responses disabled" + + if mode == "enable": + if not self._current_handler: + return "sending responses already enabled" + server.uninstall_response_handler(self._current_handler) + self._current_handler = None + return "sending responses enabled" + + logging.error("Unrecognized response sending mode '%s'", mode) + qctx.response.set_rcode(dns.rcode.SERVFAIL) + return f"unrecognized response sending mode '{mode}'" diff --git a/bin/tests/system/qmin/ans2/ans.py b/bin/tests/system/qmin/ans2/ans.py index 7fa6a6c2c5..18f077781e 100644 --- a/bin/tests/system/qmin/ans2/ans.py +++ b/bin/tests/system/qmin/ans2/ans.py @@ -101,7 +101,7 @@ class StaleHandler(DomainHandler): yield send_delegation(qctx, b_stale, "10.53.0.4") -if __name__ == "__main__": +def main() -> None: server = AsyncDnsServer() server.install_response_handler(QueryLogger()) server.install_response_handler(BadHandler()) @@ -109,3 +109,7 @@ if __name__ == "__main__": server.install_response_handler(SlowHandler()) server.install_response_handler(StaleHandler()) server.run() + + +if __name__ == "__main__": + main() diff --git a/bin/tests/system/qmin/ans3/ans.py b/bin/tests/system/qmin/ans3/ans.py index 057bbb34d5..6547dd2f9b 100644 --- a/bin/tests/system/qmin/ans3/ans.py +++ b/bin/tests/system/qmin/ans3/ans.py @@ -37,10 +37,14 @@ class ZoopBoingSlowHandler(DelayedResponseHandler): delay = 0.4 -if __name__ == "__main__": +def main() -> None: server = AsyncDnsServer() server.install_response_handler(QueryLogger()) server.install_response_handler(ZoopBoingBadHandler()) server.install_response_handler(ZoopBoingUglyHandler()) server.install_response_handler(ZoopBoingSlowHandler()) server.run() + + +if __name__ == "__main__": + main() diff --git a/bin/tests/system/qmin/ans4/ans.py b/bin/tests/system/qmin/ans4/ans.py index ca43845a1d..ebe500bad6 100644 --- a/bin/tests/system/qmin/ans4/ans.py +++ b/bin/tests/system/qmin/ans4/ans.py @@ -83,7 +83,7 @@ class IckyPtangZoopBoingSlowHandler(DelayedResponseHandler): delay = 0.4 -if __name__ == "__main__": +def main() -> None: server = AsyncDnsServer() server.install_response_handler(QueryLogger()) server.install_response_handler(StaleHandler()) @@ -91,3 +91,7 @@ if __name__ == "__main__": server.install_response_handler(IckyPtangZoopBoingUglyHandler()) server.install_response_handler(IckyPtangZoopBoingSlowHandler()) server.run() + + +if __name__ == "__main__": + main() diff --git a/bin/tests/system/upforwd/ans4/ans.py b/bin/tests/system/upforwd/ans4/ans.py index bd6e863bd7..9c5f940b5c 100644 --- a/bin/tests/system/upforwd/ans4/ans.py +++ b/bin/tests/system/upforwd/ans4/ans.py @@ -14,7 +14,11 @@ information regarding copyright ownership. from isctest.asyncserver import AsyncDnsServer, IgnoreAllQueries -if __name__ == "__main__": +def main() -> None: server = AsyncDnsServer() server.install_response_handler(IgnoreAllQueries()) server.run() + + +if __name__ == "__main__": + main()