mirror of
https://gitlab.isc.org/isc-projects/bind9
synced 2025-08-29 13:38:26 +00:00
Refactor AsyncDnsServer._handle_tcp()
Split up AsyncDnsServer._handle_tcp() into a set of smaller methods to improve code readability.
This commit is contained in:
parent
e4c3186a7c
commit
a956947fba
@ -541,18 +541,49 @@ class AsyncDnsServer(AsyncServer):
|
|||||||
peer_info = writer.get_extra_info("peername")
|
peer_info = writer.get_extra_info("peername")
|
||||||
peer = Peer(peer_info[0], peer_info[1])
|
peer = Peer(peer_info[0], peer_info[1])
|
||||||
|
|
||||||
|
for _ in range(0, 1):
|
||||||
|
wire = await self._read_tcp_query(reader)
|
||||||
|
if not wire:
|
||||||
|
break
|
||||||
|
await self._send_tcp_response(writer, peer, wire)
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
|
||||||
|
async def _read_tcp_query(self, reader: asyncio.StreamReader) -> Optional[bytes]:
|
||||||
|
wire_length = await self._read_tcp_query_wire_length(reader)
|
||||||
|
if not wire_length:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self._read_tcp_query_wire(reader, wire_length)
|
||||||
|
|
||||||
|
async def _read_tcp_query_wire_length(
|
||||||
|
self, reader: asyncio.StreamReader
|
||||||
|
) -> Optional[int]:
|
||||||
wire_length_bytes = await reader.read(2)
|
wire_length_bytes = await reader.read(2)
|
||||||
if len(wire_length_bytes) < 2:
|
if len(wire_length_bytes) < 2:
|
||||||
return
|
return None
|
||||||
|
|
||||||
(wire_length,) = struct.unpack("!H", wire_length_bytes)
|
(wire_length,) = struct.unpack("!H", wire_length_bytes)
|
||||||
|
|
||||||
|
return wire_length
|
||||||
|
|
||||||
|
async def _read_tcp_query_wire(
|
||||||
|
self, reader: asyncio.StreamReader, wire_length: int
|
||||||
|
) -> Optional[bytes]:
|
||||||
logging.debug("Receiving TCP message (%d octets)...", wire_length)
|
logging.debug("Receiving TCP message (%d octets)...", wire_length)
|
||||||
|
|
||||||
wire = await reader.read(wire_length)
|
wire = await reader.read(wire_length)
|
||||||
if len(wire) < wire_length:
|
if len(wire) < wire_length:
|
||||||
return
|
return None
|
||||||
full_message = wire_length_bytes + wire
|
|
||||||
logging.debug("Received complete TCP message: %s", full_message.hex())
|
|
||||||
|
|
||||||
|
logging.debug("Received complete TCP message: %s", wire.hex())
|
||||||
|
|
||||||
|
return wire
|
||||||
|
|
||||||
|
async def _send_tcp_response(
|
||||||
|
self, writer: asyncio.StreamWriter, peer: Peer, wire: bytes
|
||||||
|
) -> None:
|
||||||
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
|
responses = self._handle_query(wire, peer, DnsProtocol.TCP)
|
||||||
async for response in responses:
|
async for response in responses:
|
||||||
writer.write(response)
|
writer.write(response)
|
||||||
@ -562,9 +593,6 @@ class AsyncDnsServer(AsyncServer):
|
|||||||
logging.error("TCP connection from %s reset by peer", peer)
|
logging.error("TCP connection from %s reset by peer", peer)
|
||||||
return
|
return
|
||||||
|
|
||||||
writer.close()
|
|
||||||
await writer.wait_closed()
|
|
||||||
|
|
||||||
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
|
def _log_query(self, qctx: QueryContext, peer: Peer, protocol: DnsProtocol) -> None:
|
||||||
logging.info(
|
logging.info(
|
||||||
"Received %s/%s/%s (ID=%d) query from %s (%s)",
|
"Received %s/%s/%s (ID=%d) query from %s (%s)",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user