2
0
mirror of https://gitlab.isc.org/isc-projects/bind9 synced 2025-08-31 22:45:39 +00:00

Replace dns.query module with isctest.query

This commit is contained in:
Michal Nowak
2024-09-27 13:35:56 +02:00
parent a2d2d9c0d3
commit dfec69b4a2
7 changed files with 50 additions and 102 deletions

View File

@@ -24,7 +24,6 @@ pytest.importorskip("dns", minversion="2.0.0")
import dns.exception import dns.exception
import dns.message import dns.message
import dns.name import dns.name
import dns.query
import dns.rcode import dns.rcode
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
@@ -61,16 +60,9 @@ def has_signed_apex_nsec(zone, response):
def do_query(server, qname, qtype, tcp=False): def do_query(server, qname, qtype, tcp=False):
query = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True) msg = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True)
try: query_func = isctest.query.tcp if tcp else isctest.query.udp
if tcp: response = query_func(msg, server.ip, expected_rcode=dns.rcode.NOERROR)
response = dns.query.tcp(query, server.ip, timeout=3, port=server.ports.dns)
else:
response = dns.query.udp(query, server.ip, timeout=3, port=server.ports.dns)
except dns.exception.Timeout:
print(f"error: query timeout for query {qname} {qtype} to {server.ip}")
return None
return response return response
@@ -97,38 +89,26 @@ def verify_zone(zone, transfer):
def read_statefile(server, zone): def read_statefile(server, zone):
addr = server.ip
count = 0 count = 0
keyid = 0 keyid = 0
state = {} state = {}
response = do_query(server, zone, "DS", tcp=True) response = do_query(server, zone, "DS", tcp=True)
if not isinstance(response, dns.message.Message): # fetch key id from response.
print(f"error: no response for {zone} DS from {addr}") for rr in response.answer:
return {} if rr.match(
dns.name.from_text(zone),
dns.rdataclass.IN,
dns.rdatatype.DS,
dns.rdatatype.NONE,
):
if count == 0:
keyid = list(dict(rr.items).items())[0][0].key_tag
count += 1
if response.rcode() == dns.rcode.NOERROR: assert (
# fetch key id from response. count == 1
for rr in response.answer: ), f"expected a single DS in response for {zone} from {server.ip}, got {count}"
if rr.match(
dns.name.from_text(zone),
dns.rdataclass.IN,
dns.rdatatype.DS,
dns.rdatatype.NONE,
):
if count == 0:
keyid = list(dict(rr.items).items())[0][0].key_tag
count += 1
if count != 1:
print(
f"error: expected a single DS in response for {zone} from {addr}, got {count}"
)
return {}
else:
rcode = dns.rcode.to_text(response.rcode())
print(f"error: {rcode} response for {zone} DNSKEY from {addr}")
return {}
filename = f"ns9/K{zone}+013+{keyid:05d}.state" filename = f"ns9/K{zone}+013+{keyid:05d}.state"
print(f"read state file {filename}") print(f"read state file {filename}")
@@ -140,7 +120,6 @@ def read_statefile(server, zone):
continue continue
key, val = line.strip().split(":", 1) key, val = line.strip().split(":", 1)
state[key.strip()] = val.strip() state[key.strip()] = val.strip()
except FileNotFoundError: except FileNotFoundError:
# file may not be written just yet. # file may not be written just yet.
return {} return {}
@@ -149,40 +128,15 @@ def read_statefile(server, zone):
def zone_check(server, zone): def zone_check(server, zone):
addr = server.ip
fqdn = f"{zone}." fqdn = f"{zone}."
# wait until zone is fully signed. # check zone is fully signed.
signed = False response = do_query(server, fqdn, "NSEC")
for _ in range(10): assert has_signed_apex_nsec(fqdn, response)
response = do_query(server, fqdn, "NSEC")
if not isinstance(response, dns.message.Message):
print(f"error: no response for {fqdn} NSEC from {addr}")
elif response.rcode() == dns.rcode.NOERROR:
signed = has_signed_apex_nsec(fqdn, response)
else:
rcode = dns.rcode.to_text(response.rcode())
print(f"error: {rcode} response for {fqdn} NSEC from {addr}")
if signed:
break
time.sleep(1)
assert signed
# check if zone if DNSSEC valid. # check if zone if DNSSEC valid.
verified = False
transfer = do_query(server, fqdn, "AXFR", tcp=True) transfer = do_query(server, fqdn, "AXFR", tcp=True)
if not isinstance(transfer, dns.message.Message): assert verify_zone(fqdn, transfer)
print(f"error: no response for {fqdn} AXFR from {addr}")
elif transfer.rcode() == dns.rcode.NOERROR:
verified = verify_zone(fqdn, transfer)
else:
rcode = dns.rcode.to_text(transfer.rcode())
print(f"error: {rcode} response for {fqdn} AXFR from {addr}")
assert verified
def keystate_check(server, zone, key): def keystate_check(server, zone, key):

View File

@@ -12,16 +12,15 @@
# information regarding copyright ownership. # information regarding copyright ownership.
import pytest import pytest
import isctest
pytest.importorskip("dns") pytest.importorskip("dns")
import dns.message import dns.message
import dns.query
import dns.rcode
def test_connreset(named_port): def test_connreset():
msg = dns.message.make_query( msg = dns.message.make_query(
"sub.example.", "A", want_dnssec=True, use_edns=0, payload=1232 "sub.example.", "A", want_dnssec=True, use_edns=0, payload=1232
) )
ans = dns.query.udp(msg, "10.53.0.2", timeout=10, port=named_port) res = isctest.query.udp(msg, "10.53.0.2")
assert ans.rcode() == dns.rcode.SERVFAIL isctest.check.servfail(res)

View File

@@ -10,18 +10,14 @@
# information regarding copyright ownership. # information regarding copyright ownership.
import pytest import pytest
import isctest
pytest.importorskip("dns") pytest.importorskip("dns")
import dns.message import dns.message
import dns.query
import dns.rcode
def test_async_hook(named_port): def test_async_hook():
msg = dns.message.make_query( msg = dns.message.make_query("example.com.", "A")
"example.com.", res = isctest.query.udp(msg, "10.53.0.1")
"A",
)
ans = dns.query.udp(msg, "10.53.0.1", timeout=10, port=named_port)
# the test-async plugin changes the status of any positive answer to NOTIMP # the test-async plugin changes the status of any positive answer to NOTIMP
assert ans.rcode() == dns.rcode.NOTIMP isctest.check.notimp(res)

View File

@@ -15,14 +15,13 @@ import socket
import time import time
import pytest import pytest
import isctest
pytest.importorskip("dns") pytest.importorskip("dns")
import dns.message import dns.message
import dns.query
import dns.rcode
def test_cve_2023_3341(named_port, control_port): def test_cve_2023_3341(control_port):
depth = 4500 depth = 4500
# Should not be more than isccc_ccmsg_setmaxsize(&conn->ccmsg, 32768) # Should not be more than isccc_ccmsg_setmaxsize(&conn->ccmsg, 32768)
total_len = 10 + (depth * 7) - 6 total_len = 10 + (depth * 7) - 6
@@ -52,6 +51,7 @@ def test_cve_2023_3341(named_port, control_port):
# Wait for named to (possibly) crash # Wait for named to (possibly) crash
time.sleep(10) time.sleep(10)
msg = dns.message.make_query("version.bind", "TXT", "CH") msg = dns.message.make_query("version.bind", "TXT", "CH")
ans = dns.query.udp(msg, "10.53.0.2", timeout=10, port=named_port) res = isctest.query.udp(msg, "10.53.0.2")
assert ans.rcode() == dns.rcode.NOERROR isctest.check.noerror(res)

View File

@@ -13,11 +13,10 @@ import concurrent.futures
import os import os
import time import time
import dns.query
import dns.update
import isctest import isctest
import dns.update
def rndc_loop(test_state, server): def rndc_loop(test_state, server):
rndc = os.getenv("RNDC") rndc = os.getenv("RNDC")
@@ -39,7 +38,7 @@ def rndc_loop(test_state, server):
time.sleep(1) time.sleep(1)
def update_zone(test_state, zone, named_port): def update_zone(test_state, zone):
server = "10.53.0.2" server = "10.53.0.2"
for i in range(1000): for i in range(1000):
if test_state["finished"]: if test_state["finished"]:
@@ -47,7 +46,7 @@ def update_zone(test_state, zone, named_port):
update = dns.update.UpdateMessage(zone) update = dns.update.UpdateMessage(zone)
update.add(f"dynamic-{i}.{zone}", 300, "TXT", f"txt-{i}") update.add(f"dynamic-{i}.{zone}", 300, "TXT", f"txt-{i}")
try: try:
response = dns.query.udp(update, server, 10, named_port) response = isctest.query.udp(update, server)
assert response.rcode() == dns.rcode.NOERROR assert response.rcode() == dns.rcode.NOERROR
except dns.exception.Timeout: except dns.exception.Timeout:
isctest.log.info(f"error: query timeout for {zone}") isctest.log.info(f"error: query timeout for {zone}")
@@ -56,7 +55,7 @@ def update_zone(test_state, zone, named_port):
# If the test has run to completion without named crashing, it has succeeded. # If the test has run to completion without named crashing, it has succeeded.
def test_update_stress(named_port): def test_update_stress():
test_state = {"finished": False} test_state = {"finished": False}
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -65,7 +64,7 @@ def test_update_stress(named_port):
updaters = [] updaters = []
for i in range(5): for i in range(5):
zone = f"zone00000{i}.example." zone = f"zone00000{i}.example."
updaters.append(executor.submit(update_zone, test_state, zone, named_port)) updaters.append(executor.submit(update_zone, test_state, zone))
# All the update_zone() tasks are expected to complete within 5 # All the update_zone() tasks are expected to complete within 5
# minutes. If they do not, we cannot assert immediately as that will # minutes. If they do not, we cannot assert immediately as that will

View File

@@ -23,10 +23,11 @@ import time
import pytest import pytest
import isctest
pytest.importorskip("dns") pytest.importorskip("dns")
import dns.message import dns.message
import dns.name import dns.name
import dns.query
import dns.rdata import dns.rdata
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
@@ -177,13 +178,13 @@ def send_crafted_tkey_query(opts: argparse.Namespace) -> None:
print(query.to_text()) print(query.to_text())
print() print()
response = dns.query.tcp(query, opts.server_ip, timeout=2, port=opts.server_port) response = isctest.query.tcp(query, opts.server_ip, timeout=2)
print("# < " + str(datetime.datetime.now())) print("# < " + str(datetime.datetime.now()))
print(response.to_text()) print(response.to_text())
print() print()
def test_cve_2020_8625(named_port): def test_cve_2020_8625():
""" """
Reproducer for CVE-2020-8625. When run for an affected BIND 9 version, Reproducer for CVE-2020-8625. When run for an affected BIND 9 version,
send_crafted_tkey_query() will raise a network-related exception due to send_crafted_tkey_query() will raise a network-related exception due to
@@ -192,14 +193,13 @@ def test_cve_2020_8625(named_port):
for i in range(0, 50): for i in range(0, 50):
opts = argparse.Namespace( opts = argparse.Namespace(
server_ip="10.53.0.1", server_ip="10.53.0.1",
server_port=named_port,
real_oid_length=i, real_oid_length=i,
extra_oid_length=0, extra_oid_length=0,
) )
send_crafted_tkey_query(opts) send_crafted_tkey_query(opts)
def test_cve_2021_25216(named_port): def test_cve_2021_25216():
""" """
Reproducer for CVE-2021-25216. When run for an affected BIND 9 version, Reproducer for CVE-2021-25216. When run for an affected BIND 9 version,
send_crafted_tkey_query() will raise a network-related exception due to send_crafted_tkey_query() will raise a network-related exception due to
@@ -207,7 +207,6 @@ def test_cve_2021_25216(named_port):
""" """
opts = argparse.Namespace( opts = argparse.Namespace(
server_ip="10.53.0.1", server_ip="10.53.0.1",
server_port=named_port,
real_oid_length=1, real_oid_length=1,
extra_oid_length=1073741824, extra_oid_length=1073741824,
) )

View File

@@ -11,9 +11,10 @@
import pytest import pytest
import isctest
pytest.importorskip("dns") pytest.importorskip("dns")
import dns.message import dns.message
import dns.query
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -25,8 +26,8 @@ import dns.query
("max-example.", "MX", 60), ("max-example.", "MX", 60),
], ],
) )
def test_cache_ttl(qname, rdtype, expected_ttl, named_port): def test_cache_ttl(qname, rdtype, expected_ttl):
msg = dns.message.make_query(qname, rdtype) msg = dns.message.make_query(qname, rdtype)
response = dns.query.udp(msg, "10.53.0.2", timeout=10, port=named_port) response = isctest.query.udp(msg, "10.53.0.2")
for rr in response.answer + response.authority: for rr in response.answer + response.authority:
assert rr.ttl == expected_ttl assert rr.ttl == expected_ttl