mirror of
https://gitlab.isc.org/isc-projects/bind9
synced 2025-08-22 01:59:26 +00:00
Enforcing pylint standards and defaults for our test code seems counterproductive. Since most of the newly added code is tests or test-related, these checks rarely make us refactor the code in other ways; we just end up disabling them individually. Code that is too complex or convoluted will be pointed out in reviews anyway.
798 lines
25 KiB
Python
798 lines
25 KiB
Python
"""
|
|
Copyright (C) Internet Systems Consortium, Inc. ("ISC")
|
|
|
|
SPDX-License-Identifier: MPL-2.0
|
|
|
|
This Source Code Form is subject to the terms of the Mozilla Public
|
|
License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
file, you can obtain one at https://mozilla.org/MPL/2.0/.
|
|
|
|
See the COPYRIGHT file distributed with this work for additional
|
|
information regarding copyright ownership.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Coroutine,
    List,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

import abc
import asyncio
import enum
import functools
import logging
import os
import pathlib
import re
import signal
import struct
import sys

import dns.exception
import dns.flags
import dns.message
import dns.name
import dns.node
import dns.rcode
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.rrset
import dns.zone
|
|
|
|
# dnspython >= 2.0.0 exposes enum types for RR types and classes; older
# releases only use plain integers, so fall back to `int` for both when
# either enum is missing.
if hasattr(dns.rdatatype, "RdataType") and hasattr(dns.rdataclass, "RdataClass"):
    RdataType = dns.rdatatype.RdataType
    RdataClass = dns.rdataclass.RdataClass
else:  # dnspython < 2.0.0 compat
    RdataType = int  # type: ignore
    RdataClass = int  # type: ignore
|
|
|
|
|
|
# Building blocks shared by the handler signatures below.
_PeerAddress = Tuple[str, int]
_HandlerResult = Coroutine[Any, Any, None]

# Coroutine function invoked for every received datagram:
# (wire data, peer address, transport to reply through).
_UdpHandler = Callable[
    [bytes, _PeerAddress, asyncio.DatagramTransport],
    _HandlerResult,
]

# Coroutine function invoked for every accepted TCP connection.
_TcpHandler = Callable[
    [asyncio.StreamReader, asyncio.StreamWriter],
    _HandlerResult,
]
|
|
|
|
|
|
class _AsyncUdpHandler(asyncio.DatagramProtocol):
    """
    Protocol implementation for handling UDP traffic using asyncio.

    Every received datagram is handed to the configured coroutine handler,
    which is scheduled as a separate asyncio task so that a slow handler does
    not hold up processing of further datagrams.
    """

    def __init__(
        self,
        handler: "_UdpHandler",
    ) -> None:
        self._transport: Optional[asyncio.DatagramTransport] = None
        self._handler: _UdpHandler = handler
        # Strong references to in-flight handler tasks. The event loop only
        # keeps weak references to tasks, so a task nobody awaits could be
        # garbage-collected before it finishes. Each task removes itself
        # from this set once it is done.
        self._tasks: "Set[asyncio.Task]" = set()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Called by asyncio when a connection is made.

        Remembers the transport so that handlers can reply through it.
        """
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:
        """
        Called by asyncio when a datagram is received.

        Schedules the handler coroutine for (data, addr, transport) as a task.
        """
        assert self._transport
        handler_coroutine = self._handler(data, addr, self._transport)
        create_task = getattr(asyncio, "create_task", None)
        if create_task is not None:
            # Python >= 3.7
            task = create_task(handler_coroutine)
        else:
            # Python < 3.7
            task = asyncio.get_event_loop().create_task(handler_coroutine)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
|
|
|
|
|
|
class AsyncServer:
    """
    A generic asynchronous server which may handle UDP and/or TCP traffic.

    Once the server is executed as asyncio coroutine, it will keep running
    until a SIGINT/SIGTERM signal is received.

    The IPv4 listening address is taken from the first command-line argument
    or, if absent, derived from the name of the current working directory
    (``ansN`` -> ``10.53.0.N``). The IPv6 address reuses the last octet of the
    IPv4 address. The port is taken from the second command-line argument,
    falling back to the PORT environment variable and then to 5300.
    """

    def __init__(
        self,
        udp_handler: Optional["_UdpHandler"],
        tcp_handler: Optional["_TcpHandler"],
        pidfile: Optional[str] = None,
    ) -> None:
        logging.basicConfig(
            format="%(asctime)s %(levelname)8s %(message)s",
            level=os.environ.get("ANS_LOG_LEVEL", "INFO").upper(),
        )

        try:
            ipv4_address = sys.argv[1]
        except IndexError:
            ipv4_address = self._get_ipv4_address_from_directory_name()

        last_ipv4_address_octet = ipv4_address.split(".")[-1]
        ipv6_address = f"fd92:7065:b8e:ffff::{last_ipv4_address_octet}"

        try:
            port = int(sys.argv[2])
        except IndexError:
            port = int(os.environ.get("PORT", 5300))

        logging.info("Setting up IPv4 listener at %s:%d", ipv4_address, port)
        logging.info("Setting up IPv6 listener at [%s]:%d", ipv6_address, port)

        self._ip_addresses: Tuple[str, str] = (ipv4_address, ipv6_address)
        self._port: int = port
        self._udp_handler: Optional[_UdpHandler] = udp_handler
        self._tcp_handler: Optional[_TcpHandler] = tcp_handler
        self._pidfile: Optional[str] = pidfile
        # Resolved by _signal_done(); created in _setup_signals().
        self._work_done: Optional[asyncio.Future] = None

    def _get_ipv4_address_from_directory_name(self) -> str:
        """
        Derive 10.53.0.N from a working directory whose name starts with ansN.

        Raises:
            RuntimeError: if the directory name does not start with "ans<digits>".
        """
        containing_directory = pathlib.Path().absolute().stem
        match_result = re.match(r"ans(?P<index>\d+)", containing_directory)
        if not match_result:
            raise RuntimeError("Unable to auto-determine the IPv4 address to use")

        return f"10.53.0.{match_result.group('index')}"

    def run(self) -> None:
        """
        Start the server in an asynchronous coroutine and block until it exits.
        """
        # Probe for asyncio.run() explicitly rather than catching
        # AttributeError around the call: otherwise an AttributeError raised
        # anywhere inside the running server would be mistaken for
        # "Python < 3.7" and the server would be started a second time.
        if hasattr(asyncio, "run"):
            # Python >= 3.7
            asyncio.run(self._run())
        else:
            # Python < 3.7
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self._run())

    async def _run(self) -> None:
        self._setup_signals()
        assert self._work_done
        await self._listen_udp()
        await self._listen_tcp()
        self._write_pidfile()
        try:
            await self._work_done
        finally:
            # Remove the PID file even if the wait is interrupted (e.g. the
            # task is cancelled), so a stale PID does not outlive the server.
            self._cleanup_pidfile()

    def _get_asyncio_loop(self) -> asyncio.AbstractEventLoop:
        get_running_loop = getattr(asyncio, "get_running_loop", None)
        if get_running_loop is not None:
            # Python >= 3.7
            return get_running_loop()
        # Python < 3.7
        return asyncio.get_event_loop()

    def _setup_signals(self) -> None:
        """
        Create the "work done" future and resolve it on SIGINT/SIGTERM.
        """
        loop = self._get_asyncio_loop()
        self._work_done = loop.create_future()
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, self._signal_done)

    def _signal_done(self) -> None:
        assert self._work_done
        # A second signal (e.g. SIGINT followed by SIGTERM) can arrive before
        # _run() resumes; setting the result twice raises InvalidStateError.
        if not self._work_done.done():
            self._work_done.set_result(True)

    async def _listen_udp(self) -> None:
        if not self._udp_handler:
            return
        loop = self._get_asyncio_loop()
        for ip_address in self._ip_addresses:
            await loop.create_datagram_endpoint(
                lambda: _AsyncUdpHandler(cast(_UdpHandler, self._udp_handler)),
                (ip_address, self._port),
            )

    async def _listen_tcp(self) -> None:
        if not self._tcp_handler:
            return
        for ip_address in self._ip_addresses:
            await asyncio.start_server(
                self._tcp_handler, host=ip_address, port=self._port
            )

    def _write_pidfile(self) -> None:
        if not self._pidfile:
            return
        logging.info("Writing PID to %s", self._pidfile)
        with open(self._pidfile, "w", encoding="ascii") as pidfile:
            print(f"{os.getpid()}", file=pidfile)

    def _cleanup_pidfile(self) -> None:
        if not self._pidfile:
            return
        logging.info("Removing %s", self._pidfile)
        os.unlink(self._pidfile)
|
|
|
|
|
|
class DnsProtocol(enum.Enum):
    """Transport over which a DNS query arrived."""

    # Explicit values matching what enum.auto() assigns (starting at 1).
    UDP = 1
    TCP = 2
|
|
|
|
|
|
@dataclass
class QueryContext:
    """
    State describing one incoming query while its response is being built.

    The zone-data lookup steps fill in the optional fields one by one, and
    response handlers may inspect all of them when preparing the response.
    """

    # The parsed incoming query message.
    query: dns.message.Message
    # The response under construction.
    response: dns.message.Message
    # Address of the client that sent the query.
    peer: Tuple[str, int]
    # Transport the query arrived over.
    protocol: DnsProtocol
    # Closest enclosing zone served for QNAME, if any.
    zone: Optional[dns.zone.Zone] = None
    # SOA RRset at the apex of `zone`.
    soa: Optional[dns.rrset.RRset] = None
    # Zone node owning QNAME, if one exists.
    node: Optional[dns.node.Node] = None
    # Rdataset at `node` matching QCLASS/QTYPE, if one exists.
    answer: Optional[dns.rdataset.Rdataset] = None

    @property
    def qname(self) -> dns.name.Name:
        """Owner name from the first entry of the question section."""
        first_question = self.query.question[0]
        return first_question.name

    @property
    def qclass(self) -> RdataClass:
        """Class from the first entry of the question section."""
        first_question = self.query.question[0]
        return first_question.rdclass

    @property
    def qtype(self) -> RdataType:
        """Type from the first entry of the question section."""
        first_question = self.query.question[0]
        return first_question.rdtype
|
|
|
|
|
|
@dataclass
class ResponseAction(abc.ABC):
    """
    Abstract base for the steps a response handler can yield for a query.

    Each concrete subclass decides what, if anything, is sent back.
    """

    @abc.abstractmethod
    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
        """
        Carry out the action and return what should be sent to the client.

        Implementations may do arbitrary work first (sleep, adjust the
        message, and so on). The result is a dns.message.Message, a raw
        bytes payload, or None to suppress the response entirely.
        """
        raise NotImplementedError
|
|
|
|
|
|
@dataclass
class DnsResponseSend(ResponseAction):
    """
    Action yielding a dns.message.Message, optionally after a delay.

    `authoritative` controls the AA bit of the response: True sets it, False
    clears it, and None leaves the message's flags as they are.
    """

    response: dns.message.Message
    authoritative: Optional[bool] = None
    delay: float = 0.0

    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
        """
        Adjust the AA bit if requested, wait `delay` seconds, then return the
        (in-place modified) message.
        """
        message = self.response
        assert isinstance(message, dns.message.Message)

        wanted_aa = self.authoritative
        if wanted_aa is not None:
            current_flags = message.flags
            message.flags = (
                (current_flags | dns.flags.AA)
                if wanted_aa
                else (current_flags & ~dns.flags.AA)
            )

        if self.delay > 0:
            logging.info(
                "Delaying response (ID=%d) by %d ms", message.id, self.delay * 1000
            )
            await asyncio.sleep(self.delay)

        return message
|
|
|
|
|
|
@dataclass
class BytesResponseSend(ResponseAction):
    """
    Action yielding an arbitrary raw bytes payload, optionally after a delay.
    """

    response: bytes
    delay: float = 0.0

    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
        """
        Wait `delay` seconds (if positive), then return the raw payload.
        """
        payload = self.response
        assert isinstance(payload, bytes)

        pause = self.delay
        if pause > 0:
            logging.info("Delaying raw response by %d ms", pause * 1000)
            await asyncio.sleep(pause)

        return payload
|
|
|
|
|
|
@dataclass
class ResponseDrop(ResponseAction):
    """
    Action that produces no response at all, emulating a dropped packet.
    """

    async def perform(self) -> Optional[Union[dns.message.Message, bytes]]:
        """Return None so that nothing is sent back to the client."""
        nothing_to_send = None
        return nothing_to_send
|
|
|
|
|
|
class ResponseHandler(abc.ABC):
    """
    Abstract base for pluggable query handlers.

    When `match()` accepts a query, this handler takes over and the
    action(s) produced by `get_responses()` determine what is sent back.
    """

    @abc.abstractmethod
    def match(self, qctx: QueryContext) -> bool:
        """
        Decide whether this handler is responsible for the query in `qctx`.

        The base implementation accepts every query.
        """
        return True

    @abc.abstractmethod
    async def get_responses(
        self, qctx: QueryContext
    ) -> AsyncGenerator[ResponseAction, None]:
        """
        Produce response action(s) for a query accepted by `match()`.

        `qctx.response` already contains the response prepared from zone
        data; the base implementation sends it unchanged.
        """
        yield DnsResponseSend(qctx.response)
|
|
|
|
|
|
class DomainHandler(ResponseHandler):
    """
    Base class for handlers responsible for a fixed set of domains.

    Subclasses list the domain names they care about in `domains`; queries
    whose QNAME equals or lies below any of them are matched, and the
    subclass's `get_responses()` is then used to answer them.
    """

    @property
    @abc.abstractmethod
    def domains(self) -> List[str]:
        """
        Textual domain names handled by this class.
        """
        raise NotImplementedError

    def __init__(self) -> None:
        # Parse the configured names once so match() only compares names.
        self._domains: List[dns.name.Name] = list(
            map(dns.name.from_text, self.domains)
        )

    def __str__(self) -> str:
        domain_list = ", ".join(self.domains)
        return f"{type(self).__name__}(domains: {domain_list})"

    def match(self, qctx: QueryContext) -> bool:
        """
        Accept queries whose QNAME is one of, or a subdomain of, `domains`.
        """
        qname = qctx.qname
        return any(qname.is_subdomain(domain) for domain in self._domains)
|
|
|
|
|
|
@dataclass
class _ZoneTreeNode:
    """
    One node of a _ZoneTree: a single zone plus the zones nested beneath it.

    The root node of a tree carries no zone.
    """

    # Zone held by this node; None for the tree's root node.
    zone: Optional[dns.zone.Zone]
    # Nodes for zones whose origins are subdomains of this zone's origin.
    children: "List[_ZoneTreeNode]" = field(default_factory=list)
|
|
|
|
|
|
class _ZoneTree:
    """
    Tree with independent zones.

    This zone tree is used as a backing structure for the DNS server. The
    individual zones are independent to allow the (single) server to serve both
    the parent zone and a child zone if needed.
    """

    def __init__(self) -> None:
        # Sentinel root without a zone; every real zone hangs below it.
        self._root: _ZoneTreeNode = _ZoneTreeNode(None)

    def add(self, zone: dns.zone.Zone) -> None:
        """
        Add a zone to the tree and rearrange sub-zones if necessary.

        Existing zones below the new zone's origin are re-parented under it.
        """
        assert zone.origin
        best_match = self._find_best_match(zone.origin, self._root)
        added_node = _ZoneTreeNode(zone)
        self._move_children(best_match, added_node)
        best_match.children.append(added_node)

    def _find_best_match(
        self, name: dns.name.Name, start_node: _ZoneTreeNode
    ) -> _ZoneTreeNode:
        """
        Descend from `start_node` to the deepest node whose zone origin
        encloses `name` (returns `start_node` itself if no child does).
        """
        node = start_node
        while True:
            for child in node.children:
                assert child.zone
                assert child.zone.origin
                if name.is_subdomain(child.zone.origin):
                    node = child
                    break
            else:
                return node

    def _move_children(self, node_from: _ZoneTreeNode, node_to: _ZoneTreeNode) -> None:
        """
        Move every child of `node_from` whose origin lies within `node_to`'s
        zone under `node_to`, preserving relative order.
        """
        assert node_to.zone
        assert node_to.zone.origin

        # Partition by identity instead of calling list.remove(): the
        # dataclass __eq__ compares zone contents, so remove() could drop a
        # different-but-equal node, and it is O(n) per call.
        kept: List[_ZoneTreeNode] = []
        moved: List[_ZoneTreeNode] = []
        for child in node_from.children:
            assert child.zone
            assert child.zone.origin
            if child.zone.origin.is_subdomain(node_to.zone.origin):
                moved.append(child)
            else:
                kept.append(child)

        node_from.children[:] = kept
        node_to.children.extend(moved)

    def find_best_zone(self, name: dns.name.Name) -> Optional[dns.zone.Zone]:
        """
        Return the closest matching zone (if any) for the domain name.
        """
        node = self._find_best_match(name, self._root)
        # Identity check: the root sentinel must not be compared with the
        # dataclass-generated (deep, field-by-field) equality.
        return node.zone if node is not self._root else None
|
|
|
|
|
|
class AsyncDnsServer(AsyncServer):
    """
    DNS server which responds to queries based on zone data and/or custom
    handlers.

    The server may use custom handlers which allow arbitrary query processing.
    These don't need to be standards-compliant and can be used for testing all
    sorts of scenarios, including delaying responses, synthesizing them based
    on query contents etc.

    The server also loads any zone files (*.db) found in its directory and
    serves them. Responses prepared using zone data can then be modified,
    replaced, or suppressed by query handlers. Query handlers can also generate
    responses from scratch, without using zone data at all.
    """

    def __init__(self, load_zones: bool = True) -> None:
        super().__init__(self._handle_udp, self._handle_tcp, "ans.pid")

        self._zone_tree: _ZoneTree = _ZoneTree()
        self._response_handlers: List[ResponseHandler] = []

        if load_zones:
            self._load_zones()

    def install_response_handler(self, handler: ResponseHandler) -> None:
        """
        Add a response handler which will be used to handle matching queries.

        Response handlers can modify, replace, or suppress the answers prepared
        from zone file contents. Handlers are tried in installation order and
        only the first one whose match() returns True is used.
        """
        logging.info("Installing response handler: %s", handler)
        self._response_handlers.append(handler)

    def _load_zones(self) -> None:
        """
        Load every ``*.db`` file in the current directory as a zone whose
        origin is the file name without the ``.db`` suffix.
        """
        # Use scandir() as a context manager so the directory handle is
        # released deterministically instead of when it is garbage-collected.
        with os.scandir() as entries:
            for entry in entries:
                entry_path = pathlib.Path(entry.path)
                if entry_path.suffix != ".db":
                    continue
                origin = dns.name.from_text(entry_path.stem)
                logging.info("Loading zone file %s", entry_path)
                zone = dns.zone.from_file(entry.path, origin, relativize=False)
                self._zone_tree.add(zone)

    async def _handle_udp(
        self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport
    ) -> None:
        """
        Answer one UDP datagram, sending every prepared response to the peer.
        """
        logging.debug("Received UDP message: %s", wire.hex())
        responses = self._handle_query(wire, peer, DnsProtocol.UDP)
        async for response in responses:
            transport.sendto(response, peer)

    async def _handle_tcp(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """
        Answer a single length-prefixed DNS message received over TCP.

        The connection is always closed when this coroutine finishes.
        """
        peer = writer.get_extra_info("peername")
        try:
            # StreamReader.read(n) may return fewer than n bytes (short read
            # or EOF); readexactly() reassembles split messages and reports
            # truncation explicitly instead of crashing in struct.unpack().
            try:
                wire_length_bytes = await reader.readexactly(2)
                (wire_length,) = struct.unpack("!H", wire_length_bytes)
                logging.debug("Receiving TCP message (%d octets)...", wire_length)
                wire = await reader.readexactly(wire_length)
            except asyncio.IncompleteReadError:
                logging.error(
                    "TCP connection from %s closed before a complete message was received",
                    self._format_peer(peer),
                )
                return

            full_message = wire_length_bytes + wire
            logging.debug("Received complete TCP message: %s", full_message.hex())

            responses = self._handle_query(wire, peer, DnsProtocol.TCP)
            async for response in responses:
                writer.write(response)
                try:
                    await writer.drain()
                except ConnectionResetError:
                    logging.error(
                        "TCP connection from %s reset by peer", self._format_peer(peer)
                    )
                    return
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except ConnectionError:
                # The peer may already have torn the connection down.
                pass

    def _format_peer(self, peer: Tuple[str, int]) -> str:
        """
        Format a peer address as host:port, bracketing IPv6 literals.
        """
        host = peer[0]
        port = peer[1]
        # Any colon marks an IPv6 literal; checking only for "::" missed
        # addresses without a compressed run, e.g. 2001:db8:1:2:3:4:5:6.
        if ":" in host:
            host = f"[{host}]"
        return f"{host}:{port}"

    def _log_query(
        self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol
    ) -> None:
        logging.info(
            "Received %s/%s/%s (ID=%d) query from %s (%s)",
            qctx.qname.to_text(omit_final_dot=True),
            dns.rdataclass.to_text(qctx.qclass),
            dns.rdatatype.to_text(qctx.qtype),
            qctx.query.id,
            self._format_peer(peer),
            protocol.name,
        )
        logging.debug(
            "\n".join([f"[IN] {line}" for line in [""] + str(qctx.query).splitlines()])
        )

    def _log_response(
        self,
        qctx: QueryContext,
        response: Optional[Union[dns.message.Message, bytes]],
        peer: Tuple[str, int],
        protocol: DnsProtocol,
    ) -> None:
        """
        Log what is (or is not) about to be sent in reply to `qctx.query`.
        """
        if not response:
            logging.info(
                "Not sending a response to query (ID=%d) from %s (%s)",
                qctx.query.id,
                self._format_peer(peer),
                protocol.name,
            )
            return

        if isinstance(response, dns.message.Message):
            try:
                qname = response.question[0].name.to_text(omit_final_dot=True)
                qclass = dns.rdataclass.to_text(response.question[0].rdclass)
                qtype = dns.rdatatype.to_text(response.question[0].rdtype)
            except IndexError:
                # Handlers may produce responses with an empty question section.
                qname = "<empty>"
                qclass = "-"
                qtype = "-"

            logging.info(
                "Sending %s/%s/%s (ID=%d) response (%d/%d/%d/%d) to a query (ID=%d) from %s (%s)",
                qname,
                qclass,
                qtype,
                response.id,
                len(response.question),
                len(response.answer),
                len(response.authority),
                len(response.additional),
                qctx.query.id,
                self._format_peer(peer),
                protocol.name,
            )
            logging.debug(
                "\n".join([f"[OUT] {line}" for line in [""] + str(response).splitlines()])
            )
            return

        logging.info(
            "Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
            len(response),
            qctx.query.id,
            self._format_peer(peer),
            protocol.name,
        )
        logging.debug("[OUT] %s", response.hex())

    async def _handle_query(
        self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol
    ) -> AsyncGenerator[bytes, None]:
        """
        Yield wire data to send as a response over the established transport.

        Over TCP every response is prefixed with its two-octet length.
        Messages that cannot be parsed are logged and produce no response.
        """
        try:
            query = dns.message.from_wire(wire)
        except dns.exception.DNSException as exc:
            # Otherwise the parse error would propagate out of the UDP/TCP
            # handler coroutine as an unhandled exception.
            logging.error(
                "Ignoring malformed message from %s (%s): %s",
                self._format_peer(peer),
                protocol.name,
                exc,
            )
            return

        response_stub = dns.message.make_response(query)
        qctx = QueryContext(query, response_stub, peer, protocol)
        self._log_query(qctx, peer, protocol)
        responses = self._prepare_responses(qctx)
        async for response in responses:
            self._log_response(qctx, response, peer, protocol)
            if response:
                if isinstance(response, dns.message.Message):
                    response = response.to_wire(max_size=65535)
                if protocol == DnsProtocol.UDP:
                    yield response
                else:
                    response_length = struct.pack("!H", len(response))
                    yield response_length + response

    async def _prepare_responses(
        self, qctx: QueryContext
    ) -> AsyncGenerator[Optional[Union[dns.message.Message, bytes]], None]:
        """
        Yield response(s) either from response handlers or zone data.

        The zone-data response is always prepared first (into qctx.response);
        it is sent as-is only if no response handler produced any action.
        """
        self._prepare_response_from_zone_data(qctx)

        response_handled = False
        async for action in self._run_response_handlers(qctx):
            yield await action.perform()
            response_handled = True

        if not response_handled:
            yield qctx.response

    def _prepare_response_from_zone_data(self, qctx: QueryContext) -> None:
        """
        Prepare a response to the query based on the available zone data.

        The functionality is split across smaller functions that modify the
        query context until a proper response is formed. Each step returns
        True once it has fully formed the response.
        """
        if self._refused_response(qctx):
            return

        if self._delegation_response(qctx):
            return

        # Everything past a delegation check is answered authoritatively.
        qctx.response.flags |= dns.flags.AA

        if self._ent_response(qctx):
            return

        if self._nxdomain_response(qctx):
            return

        if self._nodata_response(qctx):
            return

        self._noerror_response(qctx)

    def _refused_response(self, qctx: QueryContext) -> bool:
        """
        Look up the closest zone for QNAME; answer REFUSED if there is none.
        """
        qctx.zone = self._zone_tree.find_best_zone(qctx.qname)
        if qctx.zone:
            return False

        qctx.response.set_rcode(dns.rcode.REFUSED)
        return True

    def _delegation_response(self, qctx: QueryContext) -> bool:
        """
        Walk from QNAME up towards (but excluding) the zone apex; the first
        name owning an NS RRset is returned as a referral with glue.
        """
        assert qctx.zone

        name = qctx.qname
        delegation = None

        while name != qctx.zone.origin:
            node = qctx.zone.get_node(name)
            if node:
                delegation = node.get_rdataset(qctx.qclass, dns.rdatatype.NS)
                if delegation:
                    break
            name = name.parent()

        if not delegation:
            return False

        delegation_rrset = dns.rrset.RRset(name, qctx.qclass, dns.rdatatype.NS)
        delegation_rrset.update(delegation)

        qctx.response.set_rcode(dns.rcode.NOERROR)
        qctx.response.authority.append(delegation_rrset)

        self._delegation_response_additional(qctx)

        return True

    def _delegation_response_additional(self, qctx: QueryContext) -> None:
        """
        Add A/AAAA glue for NS targets located below the delegation point.
        """
        assert qctx.zone
        assert qctx.response.authority[0]

        for nameserver in qctx.response.authority[0]:
            if not nameserver.target.is_subdomain(qctx.response.authority[0].name):
                continue
            glue_a = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.A)
            if glue_a:
                qctx.response.additional.append(glue_a)
            glue_aaaa = qctx.zone.get_rrset(nameserver.target, dns.rdatatype.AAAA)
            if glue_aaaa:
                qctx.response.additional.append(glue_aaaa)

    def _ent_response(self, qctx: QueryContext) -> bool:
        """
        Record the apex SOA and QNAME's node in qctx; if QNAME has no node
        but some zone name lies below it (an empty non-terminal), answer
        NOERROR with the SOA in the authority section.
        """
        assert qctx.zone
        assert qctx.zone.origin

        qctx.soa = qctx.zone.find_rrset(qctx.zone.origin, dns.rdatatype.SOA)
        assert qctx.soa

        qctx.node = qctx.zone.get_node(qctx.qname)
        if qctx.node or not any(
            n.is_subdomain(qctx.qname) for n in qctx.zone.nodes
        ):
            return False

        qctx.response.set_rcode(dns.rcode.NOERROR)
        qctx.response.authority.append(qctx.soa)
        return True

    def _nxdomain_response(self, qctx: QueryContext) -> bool:
        """
        Answer NXDOMAIN (with the apex SOA) when QNAME has no node.
        """
        assert qctx.soa

        if qctx.node:
            return False

        qctx.response.set_rcode(dns.rcode.NXDOMAIN)
        qctx.response.authority.append(qctx.soa)
        return True

    def _nodata_response(self, qctx: QueryContext) -> bool:
        """
        Look up the QCLASS/QTYPE rdataset at QNAME; if absent, answer
        NOERROR with the apex SOA and an empty answer section.
        """
        assert qctx.node
        assert qctx.soa

        qctx.answer = qctx.node.get_rdataset(qctx.qclass, qctx.qtype)
        if qctx.answer:
            return False

        qctx.response.set_rcode(dns.rcode.NOERROR)
        qctx.response.authority.append(qctx.soa)
        return True

    def _noerror_response(self, qctx: QueryContext) -> None:
        """
        Put the matching rdataset into the answer section.
        """
        assert qctx.answer

        answer_rrset = dns.rrset.RRset(qctx.qname, qctx.qclass, qctx.qtype)
        answer_rrset.update(qctx.answer)

        qctx.response.set_rcode(dns.rcode.NOERROR)
        qctx.response.answer.append(answer_rrset)

    async def _run_response_handlers(
        self, qctx: QueryContext
    ) -> AsyncGenerator[ResponseAction, None]:
        """
        Yield response(s) to the query from a matching query handler.

        Only the first handler whose match() returns True is consulted.
        """
        for handler in self._response_handlers:
            if handler.match(qctx):
                async for response in handler.get_responses(qctx):
                    yield response
                return