mirror of
https://gitlab.isc.org/isc-projects/bind9
synced 2025-08-30 05:57:52 +00:00
Simplify peer address formatting
Add a helper class, Peer, which holds the <host, port> tuple of a connection endpoint and gets pretty-printed when formatted as a string. This enables passing instances of this new class directly to logging functions, eliminating the need for the AsyncDnsServer._format_peer() helper method.
This commit is contained in:
parent
a2042e603e
commit
5764a9d660
@ -224,6 +224,20 @@ class DnsProtocol(enum.Enum):
|
||||
TCP = enum.auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Peer:
|
||||
"""
|
||||
Pretty-printed connection endpoint.
|
||||
"""
|
||||
|
||||
host: str
|
||||
port: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
host = f"[{self.host}]" if ":" in self.host else self.host
|
||||
return f"{host}:{self.port}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryContext:
|
||||
"""
|
||||
@ -232,7 +246,7 @@ class QueryContext:
|
||||
|
||||
query: dns.message.Message
|
||||
response: dns.message.Message
|
||||
peer: Tuple[str, int]
|
||||
peer: Peer
|
||||
protocol: DnsProtocol
|
||||
zone: Optional[dns.zone.Zone] = None
|
||||
soa: Optional[dns.rrset.RRset] = None
|
||||
@ -513,16 +527,20 @@ class AsyncDnsServer(AsyncServer):
|
||||
self._zone_tree.add(zone)
|
||||
|
||||
async def _handle_udp(
|
||||
self, wire: bytes, peer: Tuple[str, int], transport: asyncio.DatagramTransport
|
||||
self, wire: bytes, addr: Tuple[str, int], transport: asyncio.DatagramTransport
|
||||
) -> None:
|
||||
logging.debug("Received UDP message: %s", wire.hex())
|
||||
peer = Peer(addr[0], addr[1])
|
||||
responses = self._handle_query(wire, peer, DnsProtocol.UDP)
|
||||
async for response in responses:
|
||||
transport.sendto(response, peer)
|
||||
transport.sendto(response, addr)
|
||||
|
||||
async def _handle_tcp(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
peer_info = writer.get_extra_info("peername")
|
||||
peer = Peer(peer_info[0], peer_info[1])
|
||||
|
||||
wire_length_bytes = await reader.read(2)
|
||||
(wire_length,) = struct.unpack("!H", wire_length_bytes)
|
||||
logging.debug("Receiving TCP message (%d octets)...", wire_length)
|
||||
@ -531,38 +549,26 @@ class AsyncDnsServer(AsyncServer):
|
||||
full_message = wire_length_bytes + wire
|
||||
logging.debug("Received complete TCP message: %s", full_message.hex())
|
||||
|
||||
peer = writer.get_extra_info("peername")
|
||||
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
|
||||
async for response in responses:
|
||||
writer.write(response)
|
||||
try:
|
||||
await writer.drain()
|
||||
except ConnectionResetError:
|
||||
logging.error(
|
||||
"TCP connection from %s reset by peer", self._format_peer(peer)
|
||||
)
|
||||
logging.error("TCP connection from %s reset by peer", peer)
|
||||
return
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
def _format_peer(self, peer: Tuple[str, int]) -> str:
|
||||
host = peer[0]
|
||||
port = peer[1]
|
||||
if "::" in host:
|
||||
host = f"[{host}]"
|
||||
return f"{host}:{port}"
|
||||
|
||||
def _log_query(
|
||||
self, qctx: QueryContext, peer: Tuple[str, int], protocol: DnsProtocol
|
||||
) -> None:
|
||||
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
|
||||
logging.info(
|
||||
"Received %s/%s/%s (ID=%d) query from %s (%s)",
|
||||
qctx.qname.to_text(omit_final_dot=True),
|
||||
dns.rdataclass.to_text(qctx.qclass),
|
||||
dns.rdatatype.to_text(qctx.qtype),
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
logging.debug(
|
||||
@ -573,14 +579,14 @@ class AsyncDnsServer(AsyncServer):
|
||||
self,
|
||||
qctx: QueryContext,
|
||||
response: Optional[Union[dns.message.Message, bytes]],
|
||||
peer: Tuple[str, int],
|
||||
peer: Peer,
|
||||
protocol: DnsProtocol,
|
||||
) -> None:
|
||||
if not response:
|
||||
logging.info(
|
||||
"Not sending a response to query (ID=%d) from %s (%s)",
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
return
|
||||
@ -606,7 +612,7 @@ class AsyncDnsServer(AsyncServer):
|
||||
len(response.authority),
|
||||
len(response.additional),
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
logging.debug(
|
||||
@ -618,13 +624,13 @@ class AsyncDnsServer(AsyncServer):
|
||||
"Sending response (%d bytes) to a query (ID=%d) from %s (%s)",
|
||||
len(response),
|
||||
qctx.query.id,
|
||||
self._format_peer(peer),
|
||||
peer,
|
||||
protocol.name,
|
||||
)
|
||||
logging.debug("[OUT] %s", response.hex())
|
||||
|
||||
async def _handle_query(
|
||||
self, wire: bytes, peer: Tuple[str, int], protocol: DnsProtocol
|
||||
self, wire: bytes, peer: Peer, protocol: DnsProtocol
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
Yield wire data to send as a response over the established transport.
|
||||
|
Loading…
x
Reference in New Issue
Block a user