2
0
mirror of https://gitlab.isc.org/isc-projects/bind9 synced 2025-08-31 22:45:39 +00:00

Simplify peer address formatting

Add a helper class, Peer, which holds the <host, port> tuple of a
connection endpoint and gets pretty-printed when formatted as a string.
This enables passing instances of this new class directly to logging
functions, eliminating the need for the AsyncDnsServer._format_peer()
helper method.
This commit is contained in:
Michał Kępień
2025-03-18 16:28:18 +01:00
parent a2042e603e
commit 5764a9d660

View File

@@ -224,6 +224,20 @@ class DnsProtocol(enum.Enum):
TCP = enum.auto() TCP = enum.auto()
@dataclass(frozen=True)
class Peer:
"""
Pretty-printed connection endpoint.
"""
host: str
port: int
def __str__(self) -> str:
host = f"[{self.host}]" if ":" in self.host else self.host
return f"{host}:{self.port}"
@dataclass @dataclass
class QueryContext: class QueryContext:
""" """
@@ -232,7 +246,7 @@ class QueryContext:
query: dns.message.Message query: dns.message.Message
response: dns.message.Message response: dns.message.Message
peer: Tuple[str, int] peer: Peer
protocol: DnsProtocol protocol: DnsProtocol
zone: Optional[dns.zone.Zone] = None zone: Optional[dns.zone.Zone] = None
soa: Optional[dns.rrset.RRset] = None soa: Optional[dns.rrset.RRset] = None
@@ -513,16 +527,20 @@ class AsyncDnsServer(AsyncServer):
self._zone_tree.add(zone) self._zone_tree.add(zone)
async def _handle_udp( async def _handle_udp(
self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
) -> None: ) -> None:
logging.debug("Received UDP message: %s", wire.hex()) logging.debug("Received UDP message: %s", wire.hex())
peer = Peer(addr[0], addr[1])
responses = self._handle_query(wire, peer, DnsProtocol.UDP) responses = self._handle_query(wire, peer, DnsProtocol.UDP)
async for response in responses: async for response in responses:
transport.sendto(response, peer) transport.sendto(response, addr)
async def _handle_tcp( async def _handle_tcp(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None: ) -> None:
peer_info = writer.get_extra_info("peername")
peer = Peer(peer_info[0], peer_info[1])
wire_length_bytes = await reader.read(2) wire_length_bytes = await reader.read(2)
(wire_length,) = struct.unpack("!H", wire_length_bytes) (wire_length,) = struct.unpack("!H", wire_length_bytes)
logging.debug("Receiving TCP message (%d octets)...", wire_length) logging.debug("Receiving TCP message (%d octets)...", wire_length)
@@ -531,38 +549,26 @@ class AsyncDnsServer(AsyncServer):
full_message = wire_length_bytes + wire full_message = wire_length_bytes + wire
logging.debug("Received complete TCP message: %s", full_message.hex()) logging.debug("Received complete TCP message: %s", full_message.hex())
peer = writer.get_extra_info("peername")
responses = self._handle_query(wire, peer, DnsProtocol.TCP) responses = self._handle_query(wire, peer, DnsProtocol.TCP)
async for response in responses: async for response in responses:
writer.write(response) writer.write(response)
try: try:
await writer.drain() await writer.drain()
except ConnectionResetError: except ConnectionResetError:
logging.error( logging.error("TCP connection from %s reset by peer", peer)
"TCP connection from %s reset by peer", self._format_peer(peer)
)
return return
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
def _format_peer(self, peer: Tuple[str, int]) -> str: def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
host = peer[0]
port = peer[1]
if "::" in host:
host = f"[{host}]"
return f"{host}:{port}"
def _log_query(
self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol
) -> None:
logging.info( logging.info(
"Received %s/%s/%s (ID=%d) query from %s (%s)", "Received %s/%s/%s (ID=%d) query from %s (%s)",
qctx.qname.to_text(omit_final_dot=True), qctx.qname.to_text(omit_final_dot=True),
dns.rdataclass.to_text(qctx.qclass), dns.rdataclass.to_text(qctx.qclass),
dns.rdatatype.to_text(qctx.qtype), dns.rdatatype.to_text(qctx.qtype),
qctx.query.id, qctx.query.id,
self._format_peer(peer), peer,
protocol.name, protocol.name,
) )
logging.debug( logging.debug(
@@ -573,14 +579,14 @@ class AsyncDnsServer(AsyncServer):
self, self,
qctx: QueryContext, qctx: QueryContext,
response: Optional[Union[dns.message.Message, bytes]], response: Optional[Union[dns.message.Message, bytes]],
peer: Tuple[str, int], peer: Peer,
protocol: DnsProtocol, protocol: DnsProtocol,
) -> None: ) -> None:
if not response: if not response:
logging.info( logging.info(
"Not sending a response to query (ID=%d) from %s (%s)", "Not sending a response to query (ID=%d) from %s (%s)",
qctx.query.id, qctx.query.id,
self._format_peer(peer), peer,
protocol.name, protocol.name,
) )
return return
@@ -606,7 +612,7 @@ class AsyncDnsServer(AsyncServer):
len(response.authority), len(response.authority),
len(response.additional), len(response.additional),
qctx.query.id, qctx.query.id,
self._format_peer(peer), peer,
protocol.name, protocol.name,
) )
logging.debug( logging.debug(
@@ -618,13 +624,13 @@ class AsyncDnsServer(AsyncServer):
"Sending response (%d bytes) to a query (ID=%d) from %s (%s)", "Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
len(response), len(response),
qctx.query.id, qctx.query.id,
self._format_peer(peer), peer,
protocol.name, protocol.name,
) )
logging.debug("[OUT] %s", response.hex()) logging.debug("[OUT] %s", response.hex())
async def _handle_query( async def _handle_query(
self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol self, wire: bytes, peer: Peer, protocol: DnsProtocol
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
""" """
Yield wire data to send as a response over the established transport. Yield wire data to send as a response over the established transport.