diff --git a/bin/tests/system/isctest/asyncserver.py b/bin/tests/system/isctest/asyncserver.py index fed86f6f33..2341f03110 100644 --- a/bin/tests/system/isctest/asyncserver.py +++ b/bin/tests/system/isctest/asyncserver.py @@ -266,11 +266,16 @@ class QueryContext: soa: Optional[dns.rrset.RRset] = None node: Optional[dns.node.Node] = None answer: Optional[dns.rdataset.Rdataset] = None + alias: Optional[dns.name.Name] = None @property def qname(self) -> dns.name.Name: return self.query.question[0].name + @property + def current_qname(self) -> dns.name.Name: + return self.alias or self.qname + @property def qclass(self) -> RdataClass: return self.query.question[0].rdclass @@ -528,14 +533,14 @@ class AsyncDnsServer(AsyncServer): response from scratch, without using zone data at all. """ - def __init__(self, load_zones: bool = True): + def __init__(self, acknowledge_manual_dname_handling: bool = False) -> None: super().__init__(self._handle_udp, self._handle_tcp, "ans.pid") self._zone_tree: _ZoneTree = _ZoneTree() self._response_handlers: List[ResponseHandler] = [] + self._acknowledge_manual_dname_handling = acknowledge_manual_dname_handling - if load_zones: - self._load_zones() + self._load_zones() def install_response_handler( self, handler: ResponseHandler, prepend: bool = False @@ -568,11 +573,31 @@ class AsyncDnsServer(AsyncServer): entry_path = pathlib.Path(entry.path) if entry_path.suffix != ".db": continue - origin = dns.name.from_text(entry_path.stem) - logging.info("Loading zone file %s", entry_path) - zone = dns.zone.from_file(entry.path, origin, relativize=False) + zone = self._load_zone(entry_path) self._zone_tree.add(zone) + def _load_zone(self, zone_file_path: pathlib.Path) -> dns.zone.Zone: + origin = dns.name.from_text(zone_file_path.stem) + logging.info("Loading zone file %s", zone_file_path) + with open(zone_file_path, encoding="utf-8") as zone_file: + zone = dns.zone.from_file(zone_file, origin, relativize=False) + self._abort_if_dname_found_unless_acknowledged(zone) + return zone + + def _abort_if_dname_found_unless_acknowledged(self, zone: dns.zone.Zone) -> None: + if self._acknowledge_manual_dname_handling: + return + + error = f'DNAME records found in zone "{zone.origin}"; ' + error += "this server does not handle DNAME in a standards-compliant way; " + error += "pass `acknowledge_manual_dname_handling=True` to the " + error += "AsyncDnsServer constructor to acknowledge this and load zone anyway" + + for node in zone.nodes.values(): + for rdataset in node: + if rdataset.rdtype == dns.rdatatype.DNAME: + raise ValueError(error) + async def _handle_udp( self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport ) -> None: @@ -580,6 +605,7 @@ class AsyncDnsServer(AsyncServer): peer = Peer(addr[0], addr[1]) responses = self._handle_query(wire, peer, DnsProtocol.UDP) async for response in responses: + logging.debug("Sending UDP message: %s", response.hex()) transport.sendto(response, addr) async def _handle_tcp( @@ -672,6 +698,7 @@ class AsyncDnsServer(AsyncServer): ) -> None: responses = self._handle_query(wire, peer, DnsProtocol.TCP) async for response in responses: + logging.debug("Sending TCP response: %s", response.hex()) writer.write(response) await writer.drain() @@ -807,23 +834,28 @@ class AsyncDnsServer(AsyncServer): if self._nxdomain_response(qctx): return + if self._cname_response(qctx): + return + if self._nodata_response(qctx): return self._noerror_response(qctx) def _refused_response(self, qctx: QueryContext) -> bool: - qctx.zone = self._zone_tree.find_best_zone(qctx.qname) - if qctx.zone: + zone = self._zone_tree.find_best_zone(qctx.current_qname) + if zone: + qctx.zone = zone return False - qctx.response.set_rcode(dns.rcode.REFUSED) + if not qctx.response.answer: + qctx.response.set_rcode(dns.rcode.REFUSED) return True def _delegation_response(self, qctx: QueryContext) -> bool: assert qctx.zone - name = qctx.qname + name = qctx.current_qname delegation = None while name != qctx.zone.origin: @@ -868,9 +900,9 @@ class AsyncDnsServer(AsyncServer): qctx.soa = qctx.zone.find_rrset(qctx.zone.origin, dns.rdatatype.SOA) assert qctx.soa - qctx.node = qctx.zone.get_node(qctx.qname) + qctx.node = qctx.zone.get_node(qctx.current_qname) if qctx.node or not any( - n for n in qctx.zone.nodes if n.is_subdomain(qctx.qname) + n for n in qctx.zone.nodes if n.is_subdomain(qctx.current_qname) ): return False @@ -888,6 +920,21 @@ class AsyncDnsServer(AsyncServer): qctx.response.authority.append(qctx.soa) return True + def _cname_response(self, qctx: QueryContext) -> bool: + assert qctx.node + + cname = qctx.node.get_rdataset(qctx.qclass, dns.rdatatype.CNAME) + if not cname: + return False + + cname_rrset = dns.rrset.RRset(qctx.current_qname, qctx.qclass, cname.rdtype) + cname_rrset.update(cname) + qctx.response.answer.append(cname_rrset) + + qctx.alias = cname[0].target + self._prepare_response_from_zone_data(qctx) + return True + def _nodata_response(self, qctx: QueryContext) -> bool: assert qctx.node assert qctx.soa @@ -897,13 +944,14 @@ class AsyncDnsServer(AsyncServer): return False qctx.response.set_rcode(dns.rcode.NOERROR) - qctx.response.authority.append(qctx.soa) + if not qctx.response.answer: + qctx.response.authority.append(qctx.soa) return True def _noerror_response(self, qctx: QueryContext) -> None: assert qctx.answer - answer_rrset = dns.rrset.RRset(qctx.qname, qctx.qclass, qctx.qtype) + answer_rrset = dns.rrset.RRset(qctx.current_qname, qctx.qclass, qctx.qtype) answer_rrset.update(qctx.answer) qctx.response.set_rcode(dns.rcode.NOERROR)