2
0
mirror of https://gitlab.isc.org/isc-projects/bind9 synced 2025-08-22 18:19:42 +00:00

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

568 lines
18 KiB
Python
Raw Normal View History

# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from functools import total_ordering
import os
from pathlib import Path
import re
import time
from typing import Optional, Union
from datetime import datetime
from datetime import timedelta
from datetime import timezone

import dns

import isctest.log
# TTL expected on the test zone records: check_apex and check_subdomain
# look for it in the text of the apex SOA and the "a.<zone>." A record.
DEFAULT_TTL = 300
def _save_response(response, fname):
with open(fname, "w", encoding="utf-8") as file:
file.write(response.to_text())
def _query(server, qname, qtype, outfile=None):
    """
    Send a query for *qname*/*qtype* to *server* over TCP.

    The query is built with EDNS and want_dnssec enabled, and uses a 3 second
    timeout. Returns the response message, or None if the query timed out.
    When *outfile* is not None, the response text is also saved to it.
    """
    request = dns.message.make_query(qname, qtype, use_edns=True, want_dnssec=True)
    try:
        reply = dns.query.tcp(request, server.ip, port=server.ports.dns, timeout=3)
    except dns.exception.Timeout:
        isctest.log.debug(f"query timeout for query {qname} {qtype} to {server.ip}")
        return None
    else:
        if outfile is not None:
            _save_response(reply, outfile)
        return reply
@total_ordering
class KeyTimingMetadata:
    """
    Represent a single timing information for a key.

    These objects can be easily compared, support addition and subtraction of
    timedelta objects or integers (value in seconds), including in-place
    (``+=`` / ``-=``, which mutate and return the same object). A lack of
    timing metadata in the key (value 0) should be represented with None
    rather than an instance of this object.

    Values are naive datetimes. now() returns the current UTC time, also
    naive, so that it compares correctly with parsed key-file timestamps
    regardless of the local timezone.
    NOTE(review): assumes key-file timestamps are written in UTC, as BIND
    does — confirm if new metadata sources are added.
    """

    FORMAT = "%Y%m%d%H%M%S"

    def __init__(self, timestamp: str):
        # Non-positive values (notably "0") mean "not set" in the key file.
        if int(timestamp) <= 0:
            raise ValueError(f'invalid timing metadata value: "{timestamp}"')
        self.value = datetime.strptime(timestamp, self.FORMAT)

    @staticmethod
    def _from_datetime(value: datetime) -> "KeyTimingMetadata":
        # Build an instance directly from a datetime, bypassing __init__
        # (which only accepts the textual key-file format).
        result = KeyTimingMetadata.__new__(KeyTimingMetadata)
        result.value = value
        return result

    @staticmethod
    def _to_timedelta(other: Union[timedelta, int]) -> timedelta:
        # Integers are interpreted as a number of seconds.
        if isinstance(other, int):
            return timedelta(seconds=other)
        return other

    def __repr__(self):
        return self.value.strftime(self.FORMAT)

    def __str__(self) -> str:
        return self.value.strftime(self.FORMAT)

    def __add__(self, other: Union[timedelta, int]):
        return KeyTimingMetadata._from_datetime(self.value + self._to_timedelta(other))

    def __sub__(self, other: Union[timedelta, int]):
        return KeyTimingMetadata._from_datetime(self.value - self._to_timedelta(other))

    def __iadd__(self, other: Union[timedelta, int]):
        self.value += self._to_timedelta(other)
        # In-place operators must return the result; otherwise `x += n`
        # rebinds x to None.
        return self

    def __isub__(self, other: Union[timedelta, int]):
        self.value -= self._to_timedelta(other)
        return self

    def __lt__(self, other: "KeyTimingMetadata"):
        return self.value < other.value

    def __eq__(self, other: object):
        return isinstance(other, KeyTimingMetadata) and self.value == other.value

    @staticmethod
    def now() -> "KeyTimingMetadata":
        """Return the current time (UTC, naive) as timing metadata."""
        return KeyTimingMetadata._from_datetime(
            datetime.now(timezone.utc).replace(tzinfo=None)
        )
@total_ordering
class Key:
    """
    Represent a key from a keyfile.

    This object keeps track of its origin (keydir + name), can be used to
    retrieve metadata from the underlying files (<path>.key and <path>.state)
    and supports convenience operations for KASP tests. Keys are equal (and
    hash equally) when their paths match, and are ordered by name.
    """

    def __init__(self, name: str, keydir: Optional[Union[str, Path]] = None):
        self.name = name
        self.keydir = Path() if keydir is None else Path(keydir)
        self.path = str(self.keydir / name)
        self.keyfile = f"{self.path}.key"
        self.statefile = f"{self.path}.state"
        # NOTE(review): assumes the name ends in a five-digit (zero-padded)
        # key tag; int() raises ValueError otherwise.
        self.tag = int(self.name[-5:])

    def get_timing(
        self, metadata: str, must_exist: bool = True
    ) -> Optional["KeyTimingMetadata"]:
        """
        Return the timing *metadata* (e.g. "Publish") from the key file.

        The value is read from the first "; <metadata>: <digits>" line. If
        that value is not a valid timestamp (e.g. 0), or no such line exists,
        ValueError is raised when *must_exist* is True; otherwise None is
        returned.
        """
        regex = rf";\s+{metadata}:\s+(\d+).*"
        with open(self.keyfile, "r", encoding="utf-8") as file:
            for line in file:
                match = re.match(regex, line)
                if match is not None:
                    try:
                        return KeyTimingMetadata(match.group(1))
                    except ValueError:
                        # Value 0 means "not set"; treat it like a missing line.
                        break
        if must_exist:
            raise ValueError(
                f'timing metadata "{metadata}" for key "{self.name}" invalid'
            )
        return None

    def get_metadata(self, metadata: str, must_exist=True) -> str:
        """
        Return the value of the first "<metadata>: <value>" line in the
        state file.

        If there is no such line, ValueError is raised when *must_exist* is
        True; otherwise the string "undefined" is returned.
        """
        value = "undefined"
        regex = rf"{metadata}:\s+(.*)"
        with open(self.statefile, "r", encoding="utf-8") as file:
            for line in file:
                match = re.match(regex, line)
                if match is not None:
                    value = match.group(1)
                    break
        if must_exist and value == "undefined":
            raise ValueError(
                f'state metadata "{metadata}" for key "{self.name}" undefined'
            )
        return value

    def is_ksk(self) -> bool:
        """Return True if the state file marks this key as a KSK."""
        return self.get_metadata("KSK") == "yes"

    def is_zsk(self) -> bool:
        """Return True if the state file marks this key as a ZSK."""
        return self.get_metadata("ZSK") == "yes"

    def dnskey_equals(self, value, cdnskey=False):
        """
        Check whether the record *value* (text form) matches this key's
        record in the key file.

        With *cdnskey*, *value* must be a CDNSKEY record; it is compared as if
        it were a DNSKEY. The first seven whitespace-separated fields are used
        to locate the record line in the key file (the last matching line
        wins). The remaining fields, joined without whitespace, are then
        compared.
        """
        dnskey = value.split()
        if cdnskey:
            # fourth element is the rrtype
            assert dnskey[3] == "CDNSKEY"
            dnskey[3] = "DNSKEY"
        rdata = " ".join(dnskey[:7])
        dnskey_fromfile = []
        with open(self.keyfile, "r", encoding="utf-8") as file:
            for line in file:
                if rdata in line:
                    dnskey_fromfile = line.split()
        pubkey_fromfile = "".join(dnskey_fromfile[7:])
        pubkey_fromwire = "".join(dnskey[7:])
        return pubkey_fromfile == pubkey_fromwire

    def cds_equals(self, value, alg):
        """
        Check whether the CDS record *value* (text form) matches this key.

        The expected record is produced by running the command named in the
        DSFROMKEY environment variable with digest algorithm *alg* on the key
        file. The first seven fields are compared as-is, and the remaining
        fields (the digest) are compared case-insensitively with whitespace
        removed. Mismatches are logged.
        """
        cds = value.split()
        dsfromkey_command = [
            os.environ.get("DSFROMKEY"),
            "-T",
            "3600",
            "-a",
            alg,
            "-C",
            "-w",
            self.keyfile,
        ]
        out = isctest.run.cmd(dsfromkey_command, log_stdout=True)
        dsfromkey = out.stdout.decode("utf-8").split()

        rdata_fromfile = " ".join(dsfromkey[:7])
        rdata_fromwire = " ".join(cds[:7])
        if rdata_fromfile != rdata_fromwire:
            isctest.log.debug(
                f"CDS RDATA MISMATCH: {rdata_fromfile} - {rdata_fromwire}"
            )
            return False

        digest_fromfile = "".join(dsfromkey[7:]).lower()
        digest_fromwire = "".join(cds[7:]).lower()
        if digest_fromfile != digest_fromwire:
            isctest.log.debug(
                f"CDS DIGEST MISMATCH: {digest_fromfile} - {digest_fromwire}"
            )
            return False

        return True

    def __lt__(self, other: "Key"):
        return self.name < other.name

    def __eq__(self, other: object):
        return isinstance(other, Key) and self.path == other.path

    def __hash__(self):
        # Consistent with __eq__, so keys can be used in sets/dicts.
        return hash(self.path)

    def __repr__(self):
        return self.path
def check_zone_is_signed(server, zone):
    """
    Wait until *server* answers with a signed NSEC record at the apex of *zone*.

    The apex NSEC is queried up to 10 times, sleeping one second after each
    unsuccessful attempt. An attempt succeeds when the answer holds both the
    apex NSEC RRset and an RRSIG RRset covering NSEC. If no attempt succeeds,
    the final assertion fails.
    """
    addr = server.ip
    fqdn = f"{zone}."

    def _answer_contains(answer, rdtype, covers):
        # True if some RRset in the answer is <fqdn>/IN/<rdtype> covering <covers>.
        return any(
            rrset.match(dns.name.from_text(fqdn), dns.rdataclass.IN, rdtype, covers)
            for rrset in answer
        )

    signed = False
    for _attempt in range(10):
        response = _query(server, fqdn, dns.rdatatype.NSEC)
        if not isinstance(response, dns.message.Message):
            isctest.log.debug(f"no response for {fqdn} NSEC from {addr}")
        elif response.rcode() != dns.rcode.NOERROR:
            rcode = dns.rcode.to_text(response.rcode())
            isctest.log.debug(f"{rcode} response for {fqdn} NSEC from {addr}")
        else:
            nsec_found = _answer_contains(
                response.answer, dns.rdatatype.NSEC, dns.rdatatype.NONE
            )
            rrsig_found = _answer_contains(
                response.answer, dns.rdatatype.RRSIG, dns.rdatatype.NSEC
            )
            if not nsec_found:
                isctest.log.debug(
                    f"missing apex {fqdn} NSEC record in response from {addr}"
                )
            if not rrsig_found:
                isctest.log.debug(
                    f"missing {fqdn} NSEC signature in response from {addr}"
                )
            signed = nsec_found and rrsig_found
        if signed:
            break
        time.sleep(1)

    assert signed
def check_dnssec_verify(server, zone):
    """
    Check that *zone* on *server* is DNSSEC valid, using dnssec-verify.

    The zone is fetched with AXFR, and the answer RRsets are written one per
    line to "<zone>.axfr" in the current directory. That file is then passed,
    with "-z -o <zone>", to the command in the VERIFY environment variable.

    A failed transfer (no response, or a non-NOERROR rcode) is logged and
    raises AssertionError. Previously it let the check pass without
    verifying anything.
    """
    fqdn = f"{zone}."
    transfer = _query(server, fqdn, dns.rdatatype.AXFR)

    if not isinstance(transfer, dns.message.Message):
        failure = f"no response for {fqdn} AXFR from {server.ip}"
    elif transfer.rcode() != dns.rcode.NOERROR:
        rcode = dns.rcode.to_text(transfer.rcode())
        failure = f"{rcode} response for {fqdn} AXFR from {server.ip}"
    else:
        failure = None
    if failure is not None:
        isctest.log.debug(failure)
        raise AssertionError(failure)

    zonefile = f"{zone}.axfr"
    with open(zonefile, "w", encoding="utf-8") as file:
        for rr in transfer.answer:
            file.write(rr.to_text())
            file.write("\n")

    verify_command = [*os.environ.get("VERIFY").split(), "-z", "-o", zone, zonefile]
    isctest.run.cmd(verify_command)
def check_dnssecstatus(server, zone, keys, policy=None, view=None):
    """
    Loosely verify the output of "rndc dnssec -status" for *zone*.

    The command is sent to *server* (for the given *view*, if any). With no
    *policy*, the output must state that the zone has no dnssec-policy.
    Otherwise, the output must name *policy* and list the tag of every key
    in *keys*.
    """
    command = f"dnssec -status {zone}"
    if view is not None:
        command = f"{command} in {view}"
    response = server.rndc(command, log=False)

    if policy is None:
        assert "Zone does not have dnssec-policy" in response
        return

    assert f"dnssec-policy: {policy}" in response
    assert all(f"key: {key.tag}" in response for key in keys)
def _check_signatures(signatures, covers, fqdn, keys):
    """
    Check which of *keys* signed the RRSIG records in *signatures*.

    *signatures* are RRSIG records in text form, *covers* is the type they
    cover and *fqdn* the signer name. DNSKEY, CDNSKEY and CDS RRsets must be
    signed by KSKs, and all other types by ZSKs. A key that is currently
    signing (Activate reached and Inactive, if set, not yet reached) and has
    the required role must have a signature. Any other key must not.

    Returns the number of keys that were required to sign (and did).
    """
    now = KeyTimingMetadata.now()
    ksk_signed = covers in (
        dns.rdatatype.DNSKEY,
        dns.rdatatype.CDNSKEY,
        dns.rdatatype.CDS,
    )

    numsigs = 0
    for key in keys:
        activate = key.get_timing("Activate")
        inactive = key.get_timing("Inactive", must_exist=False)
        active = now >= activate
        retired = inactive is not None and inactive <= now

        # In the RRSIG text the key tag is preceded by the inception time and
        # the signer name is followed by the signature, each separated by a
        # space. Requiring spaces on both sides stops e.g. tag 1234 from
        # matching an RRSIG made by key 11234.
        needle = f" {key.tag} {fqdn} "
        has_rrsig = any(needle in rrsig for rrsig in signatures)

        if not active or retired:
            assert not has_rrsig, f"RRSIG by non-signing key {key.tag} for {fqdn}"
            continue

        # Only one role check applies, depending on the covered type.
        required = key.is_ksk() if ksk_signed else key.is_zsk()
        if required:
            assert has_rrsig, f"missing RRSIG by key {key.tag} for {fqdn}"
            numsigs += 1
        else:
            assert not has_rrsig, f"unexpected RRSIG by key {key.tag} for {fqdn}"

    return numsigs
def check_signatures(rrset, covers, fqdn, ksks, zsks):
    """
    Check that the RRSIG RRsets in *rrset*, covering type *covers*, were made
    by exactly the keys in *ksks* and *zsks* that should sign them.

    Each RRSIG is rendered as "<owner> <ttl> <class> <type> <rdata>". KSKs are
    checked before ZSKs. The total number of keys that had to sign must equal
    the number of signatures.
    """
    signatures = [
        f"{rr.name} {rr.ttl} {dns.rdataclass.to_text(rr.rdclass)} "
        f"{dns.rdatatype.to_text(rr.rdtype)} {rdata}"
        for rr in rrset
        for rdata in rr
    ]
    total = sum(
        _check_signatures(signatures, covers, fqdn, group) for group in (ksks, zsks)
    )
    assert total == len(signatures)
def _check_dnskeys(dnskeys, keys, cdnskey=False):
    """
    Check *keys* against the DNSKEY records in *dnskeys* (text form), or
    against CDNSKEY records when *cdnskey* is set.

    A key must have a matching record exactly when its Publish (SyncPublish)
    time has been reached and its Delete (SyncDelete) time, if set, has not.
    Returns the number of keys that were expected, and found, in *dnskeys*.
    """
    now = KeyTimingMetadata.now()
    prefix = "Sync" if cdnskey else ""
    publish_md, delete_md = f"{prefix}Publish", f"{prefix}Delete"

    expected_count = 0
    for key in keys:
        publish = key.get_timing(publish_md)
        delete = key.get_timing(delete_md, must_exist=False)
        published = now >= publish
        removed = delete is not None and delete <= now
        present = any(key.dnskey_equals(record, cdnskey=cdnskey) for record in dnskeys)
        if published and not removed:
            assert present
            expected_count += 1
        else:
            assert not present
    return expected_count
def check_dnskeys(rrset, ksks, zsks, cdnskey=False):
    """
    Check that the correct DNSKEY records are published in *rrset*.

    A key must have a record published while the current time lies between
    its 'publish' and 'delete' timing metadata. With *cdnskey*, the records
    are CDNSKEYs, checked against the Sync* timing of the KSKs only. Every
    record must be accounted for by an expected key.
    """
    records = []
    for rr in rrset:
        rdclass = dns.rdataclass.to_text(rr.rdclass)
        rdtype = dns.rdatatype.to_text(rr.rdtype)
        records.extend(f"{rr.name} {rr.ttl} {rdclass} {rdtype} {rdata}" for rdata in rr)

    expected = _check_dnskeys(records, ksks, cdnskey=cdnskey)
    if not cdnskey:
        expected += _check_dnskeys(records, zsks)
    assert expected == len(records)
# pylint: disable=too-many-locals
def check_cds(rrset, keys):
    """
    Check that the correct CDS records are published.

    *rrset* holds the CDS RRsets from an answer section, and *keys* the KSKs
    to check (each must be a KSK). A key whose SyncPublish time has been
    reached, and whose SyncDelete time (if set) has not, must have a matching
    CDS record, as compared by Key.cds_equals with "SHA-256". Any other key
    must not. The number of keys that require a CDS must equal the number of
    CDS records.
    """
    now = KeyTimingMetadata.now()
    cdss = [
        f"{rr.name} {rr.ttl} {dns.rdataclass.to_text(rr.rdclass)} "
        f"{dns.rdatatype.to_text(rr.rdtype)} {rdata}"
        for rr in rrset
        for rdata in rr
    ]

    numcds = 0
    for key in keys:
        assert key.is_ksk(), f"key {key.tag} is not a KSK"
        publish = key.get_timing("SyncPublish")
        delete = key.get_timing("SyncDelete", must_exist=False)
        published = now >= publish
        removed = delete is not None and delete <= now
        has_cds = any(key.cds_equals(cds, "SHA-256") for cds in cdss)
        if published and not removed:
            assert has_cds, f"missing CDS for key {key.tag}"
            numcds += 1
        else:
            assert not has_cds, f"unexpected CDS for key {key.tag}"

    assert numcds == len(cdss)
def _query_rrset(server, fqdn, qtype):
    """
    Query *server* for *fqdn*/*qtype* and split the answer section.

    Returns a tuple (rrs, rrsigs). *rrs* holds the *qtype* RRsets at *fqdn*,
    and *rrsigs* the RRSIG RRsets at *fqdn* covering *qtype*. Raises
    AssertionError if there is no response (timeout), the rcode is not
    NOERROR, or the answer holds any other RRset.
    """
    response = _query(server, fqdn, qtype)
    assert response is not None, (
        f"no response for {fqdn} {dns.rdatatype.to_text(qtype)} from {server.ip}"
    )
    assert response.rcode() == dns.rcode.NOERROR

    owner = dns.name.from_text(fqdn)
    rrs = []
    rrsigs = []
    for rrset in response.answer:
        if rrset.match(owner, dns.rdataclass.IN, dns.rdatatype.RRSIG, qtype):
            rrsigs.append(rrset)
        elif rrset.match(owner, dns.rdataclass.IN, qtype, dns.rdatatype.NONE):
            rrs.append(rrset)
        else:
            raise AssertionError(f"unexpected RRset in answer: {rrset}")
    return rrs, rrsigs
def check_apex(server, zone, ksks, zsks):
    """
    Check the signed apex of *zone* on *server*.

    DNSKEY, SOA, CDNSKEY and CDS are queried in that order. Each answer must
    be non-empty and agree with the keys' timing metadata; the SOA answer must
    be a single RRset with the default TTL. Each must also be covered by at
    least one RRSIG made by the appropriate keys.
    """
    fqdn = f"{zone}."

    def _dnskey(records):
        assert len(records) > 0
        check_dnskeys(records, ksks, zsks)

    def _soa(records):
        assert len(records) == 1
        assert f"{zone}. {DEFAULT_TTL} IN SOA" in records[0].to_text()

    def _cdnskey(records):
        assert len(records) > 0
        check_dnskeys(records, ksks, zsks, cdnskey=True)

    def _cds(records):
        assert len(records) > 0
        check_cds(records, ksks)

    checks = (
        (dns.rdatatype.DNSKEY, _dnskey),
        (dns.rdatatype.SOA, _soa),
        (dns.rdatatype.CDNSKEY, _cdnskey),
        (dns.rdatatype.CDS, _cds),
    )
    for rdtype, check_records in checks:
        records, signatures = _query_rrset(server, fqdn, rdtype)
        check_records(records)
        assert len(signatures) > 0
        check_signatures(signatures, rdtype, fqdn, ksks, zsks)
def check_subdomain(server, zone, ksks, zsks):
    """
    Check that an RRset below the apex of *zone* is signed correctly.

    Queries "a.<zone>." for type A and expects a NOERROR response. Every
    non-RRSIG RRset in the answer must contain
    "a.<zone>. <DEFAULT_TTL> IN A 10.0.0.1", and at least one RRSIG RRset
    covering A must be present. Those RRSIGs are then checked with
    check_signatures. A missing response (timeout) raises AssertionError.
    """
    fqdn = f"{zone}."
    qname = f"a.{zone}."
    qtype = dns.rdatatype.A
    response = _query(server, qname, qtype)
    assert response is not None, f"no response for {qname} A from {server.ip}"
    assert response.rcode() == dns.rcode.NOERROR

    expected = f"{qname} {DEFAULT_TTL} IN A 10.0.0.1"
    owner = dns.name.from_text(qname)
    rrsigs = []
    for rrset in response.answer:
        if rrset.match(owner, dns.rdataclass.IN, dns.rdatatype.RRSIG, qtype):
            rrsigs.append(rrset)
        else:
            assert expected in rrset.to_text()

    assert len(rrsigs) > 0
    check_signatures(rrsigs, qtype, fqdn, ksks, zsks)