diff --git a/bin/tests/system/dispatch/ans3/ans.py b/bin/tests/system/dispatch/ans3/ans.py index 4e4ebacb0b..653232f991 100644 --- a/bin/tests/system/dispatch/ans3/ans.py +++ b/bin/tests/system/dispatch/ans3/ans.py @@ -9,91 +9,37 @@ # See the COPYRIGHT file distributed with this work for additional # information regarding copyright ownership. -import os -import select -import signal -import socket -import sys -import time +from typing import AsyncGenerator -import dns.flags -import dns.message +import dns + +from isctest.asyncserver import ( + AsyncDnsServer, + ConnectionReset, + DnsProtocol, + DnsResponseSend, + QueryContext, + ResponseAction, + ResponseHandler, +) -def port(): - env_port = os.getenv("PORT") - if env_port is None: - env_port = 5300 - else: - env_port = int(env_port) - - return env_port +class TruncateOnUdpHandler(ResponseHandler): + async def get_responses( + self, qctx: QueryContext + ) -> AsyncGenerator[ResponseAction, None]: + assert qctx.protocol == DnsProtocol.UDP, "This server only supports UDP" + qctx.response.set_rcode(dns.rcode.NOERROR) + qctx.response.flags |= dns.flags.TC + yield DnsResponseSend(qctx.response) -def udp_listen(port): - udp = socket.socket(type=socket.SOCK_DGRAM) - udp.bind(("10.53.0.3", port)) - - return udp +def main() -> None: + server = AsyncDnsServer() + server.install_connection_handler(ConnectionReset(delay=1.0)) + server.install_response_handler(TruncateOnUdpHandler()) + server.run() -def tcp_listen(port): - tcp = socket.socket(type=socket.SOCK_STREAM) - tcp.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - tcp.bind(("10.53.0.3", port)) - tcp.listen(100) - - return tcp - - -def udp_tc_once(udp): - qrybytes, clientaddr = udp.recvfrom(65535) - qry = dns.message.from_wire(qrybytes) - answ = dns.message.make_response(qry) - answ.flags |= dns.flags.TC - answbytes = answ.to_wire() - udp.sendto(answbytes, clientaddr) - - -def tcp_once(tcp): - csock, _clientaddr = tcp.accept() - time.sleep(5) - csock.close() - - -def sigterm(signum, frame): - os.remove("ans.pid") - sys.exit(0) - - -def write_pid(): - with open("ans.pid", "w") as f: - pid = os.getpid() - f.write("{}".format(pid)) - - -signal.signal(signal.SIGTERM, sigterm) -write_pid() - -udp = udp_listen(port()) -tcp = tcp_listen(port()) - -input = [udp, tcp] - -while True: - try: - inputready, outputready, exceptready = select.select(input, [], []) - except select.error: - break - except socket.error: - break - except KeyboardInterrupt: - break - - for s in inputready: - if s == udp: - udp_tc_once(udp) - if s == tcp: - tcp_once(tcp) - -sigterm(signal.SIGTERM, 0) +if __name__ == "__main__": + main() diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py index 784ee8efd2..3f9b4c2c5a 100644 --- a/bin/tests/system/isctest/asyncserver.py +++ b/bin/tests/system/isctest/asyncserver.py @@ -373,6 +373,118 @@ class ResponseDrop(ResponseAction): return None +class _ConnectionTeardownRequested(Exception): + pass + + +@dataclass +class ResponseDropAndCloseConnection(ResponseAction): + """ + Action which makes the server close the connection after the DNS query is + received by the server (TCP only). + + The connection may be closed with a delay if requested. + """ + + delay: float = 0.0 + + async def perform(self) -> Optional[Union[dns.message.Message, bytes]]: + if self.delay > 0: + logging.info("Waiting %.1fs before closing TCP connection", self.delay) + await asyncio.sleep(self.delay) + raise _ConnectionTeardownRequested + + +class ConnectionHandler(abc.ABC): + """ + Base class for TCP connection handlers. + + An installed connection handler is called when a new TCP connection is + established. It may be used to perform arbitrary actions before + AsyncDnsServer processes DNS queries. + """ + + @abc.abstractmethod + async def handle( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer + ) -> None: + """ + Handle the connection with the provided reader and writer. + """ + raise NotImplementedError + + +@dataclass +class ConnectionReset(ConnectionHandler): + """ + A connection handler that makes the server close the connection without + reading anything from the client socket. + + The connection may be closed with a delay if requested. + + The sole purpose of this handler is to trigger a connection reset, i.e. to + make the server send an RST segment; this happens when the server closes a + client's socket while there is still unread data in that socket's buffer. + If closing the connection _after_ the query is read by the server is enough + for a given use case, the ResponseDropAndCloseConnection response handler + should be used instead. + """ + + delay: float = 0.0 + + async def handle( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, peer: Peer + ) -> None: + try: + # Python >= 3.7 + loop = asyncio.get_running_loop() + except AttributeError: + # Python < 3.7 + loop = asyncio.get_event_loop() + + logging.info("Blocking reads from %s", peer) + + # This is MichaƂ's submission for the Ugliest Hack of the Year contest. + # (The alternative was implementing an asyncio transport from scratch.) + # + # In order to prevent the client socket from being read from, simply + # not calling `reader.read()` is not enough, because asyncio buffers + # incoming data itself on the transport level. However, `StreamReader` + # does not expose the underlying transport as a property. Therefore, + # cheat by extracting it from `StreamWriter` as it is the same + # bidirectional transport as for the read side (a `Transport`, which is + # a subclass of both `ReadTransport` and `WriteTransport`) and call + # `ReadTransport.pause_reading()` to remove the underlying socket from + # the set of descriptors monitored by the selector, thereby preventing + # any reads from happening on the client socket. However... + loop.call_soon(writer.transport.pause_reading) # type: ignore + + # ...due to `AsyncDnsServer._handle_tcp()` being a coroutine, by the + # time it gets executed, asyncio transport code will already have added + # the client socket to the set of descriptors monitored by the + # selector. Therefore, if the client starts sending data immediately, + # a read from the socket will have already been scheduled by the time + # this handler gets executed. There is no way to prevent that from + # happening, so work around it by abusing the fact that the transport + # at hand is specifically an instance of `_SelectorSocketTransport` + # (from asyncio.selector_events) and set the size of its read buffer to + # just a single byte. This does give asyncio enough time to read that + # single byte from the client socket's buffer before that socket is + # removed from the set of monitored descriptors, but prevents the + # one-off read from emptying the client socket buffer _entirely_, which + # is enough to trigger sending an RST segment when the connection is + # closed shortly afterwards. + writer.transport.max_size = 1 # type: ignore + + if self.delay > 0: + logging.info( + "Waiting %.1fs before closing TCP connection from %s", self.delay, peer + ) + await asyncio.sleep(self.delay) + + raise _ConnectionTeardownRequested + + class ResponseHandler(abc.ABC): """ Base class for generic response handlers. @@ -605,6 +717,7 @@ class AsyncDnsServer(AsyncServer): super().__init__(self._handle_udp, self._handle_tcp, "ans.pid") self._zone_tree: _ZoneTree = _ZoneTree() + self._connection_handler: Optional[ConnectionHandler] = None self._response_handlers: List[ResponseHandler] = [] self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling self._acknowledge_tsig_dnspython_hacks = acknowledge_tsig_dnspython_hacks @@ -637,6 +750,15 @@ class AsyncDnsServer(AsyncServer): logging.info("Uninstalling response handler: %s", handler) self._response_handlers.remove(handler) + def install_connection_handler(self, handler: ConnectionHandler) -> None: + """ + Install a connection handler that will be called when a new TCP + connection is established. + """ + if self._connection_handler: + raise RuntimeError("Only one connection handler can be installed") + self._connection_handler = handler + def _load_zones(self) -> None: for entry in os.scandir(): entry_path = pathlib.Path(entry.path) @@ -684,15 +806,19 @@ class AsyncDnsServer(AsyncServer): peer = Peer(peer_info[0], peer_info[1]) logging.debug("Accepted TCP connection from %s", peer) - while True: - try: + try: + if self._connection_handler: + await self._connection_handler.handle(reader, writer, peer) + while True: wire = await self._read_tcp_query(reader, peer) if not wire: break await self._send_tcp_response(writer, peer, wire) - except ConnectionResetError: - logging.error("TCP connection from %s reset by peer", peer) - return + except _ConnectionTeardownRequested: + pass + except ConnectionResetError: + logging.error("TCP connection from %s reset by peer", peer) + return logging.debug("Closing TCP connection from %s", peer) writer.close()