2
0
mirror of https://gitlab.isc.org/isc-projects/bind9 synced 2025-08-29 13:38:26 +00:00

Refactor AsyncDnsServer._handle_tcp()

Split up AsyncDnsServer._handle_tcp() into a set of smaller methods to
improve code readability.
This commit is contained in:
Michał Kępień 2025-03-18 16:28:18 +01:00
parent e4c3186a7c
commit a956947fba
No known key found for this signature in database

View File

@ -541,18 +541,49 @@ class AsyncDnsServer(AsyncServer):
peer_info = writer.get_extra_info("peername") peer_info = writer.get_extra_info("peername")
peer = Peer(peer_info[0], peer_info[1]) peer = Peer(peer_info[0], peer_info[1])
for _ in range(0, 1):
wire = await self._read_tcp_query(reader)
if not wire:
break
await self._send_tcp_response(writer, peer, wire)
writer.close()
await writer.wait_closed()
async def _read_tcp_query(self, reader: asyncio.StreamReader) -> Optional[bytes]:
wire_length = await self._read_tcp_query_wire_length(reader)
if not wire_length:
return None
return await self._read_tcp_query_wire(reader, wire_length)
async def _read_tcp_query_wire_length(
self, reader: asyncio.StreamReader
) -> Optional[int]:
wire_length_bytes = await reader.read(2) wire_length_bytes = await reader.read(2)
if len(wire_length_bytes) < 2: if len(wire_length_bytes) < 2:
return return None
(wire_length,) = struct.unpack("!H", wire_length_bytes) (wire_length,) = struct.unpack("!H", wire_length_bytes)
return wire_length
async def _read_tcp_query_wire(
self, reader: asyncio.StreamReader, wire_length: int
) -> Optional[bytes]:
logging.debug("Receiving TCP message (%d octets)...", wire_length) logging.debug("Receiving TCP message (%d octets)...", wire_length)
wire = await reader.read(wire_length) wire = await reader.read(wire_length)
if len(wire) < wire_length: if len(wire) < wire_length:
return return None
full_message = wire_length_bytes + wire
logging.debug("Received complete TCP message: %s", full_message.hex())
logging.debug("Received complete TCP message: %s", wire.hex())
return wire
async def _send_tcp_response(
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
) -> None:
responses = self._handle_query(wire, peer, DnsProtocol.TCP) responses = self._handle_query(wire, peer, DnsProtocol.TCP)
async for response in responses: async for response in responses:
writer.write(response) writer.write(response)
@ -562,9 +593,6 @@ class AsyncDnsServer(AsyncServer):
logging.error("TCP connection from %s reset by peer", peer) logging.error("TCP connection from %s reset by peer", peer)
return return
writer.close()
await writer.wait_closed()
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None: def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
logging.info( logging.info(
"Received %s/%s/%s (ID=%d) query from %s (%s)", "Received %s/%s/%s (ID=%d) query from %s (%s)",