2
0
mirror of https://gitlab.isc.org/isc-projects/bind9 synced 2025-08-29 21:47:59 +00:00

add helper functions to isctest

added some helper functions in isctest to reduce code repetition
in dnssec-related tests:

- isctest.check.adflag() - checks that a response contains AD=1
- isctest.check.noadflag() - checks that a response contains AD=0

- isctest.check.rdflag() - checks that a response contains RD=1
- isctest.check.nordflag() - checks that a response contains RD=0

- isctest.check.answer_count_eq() - checks the answer count is correct
- isctest.check.additional_count_eq() - same for authority count
- isctest.check.authority_count_eq() - same for additional count

- isctest.check.same_data() - check that two message have the
                              same rcode and data
- isctest.check.same_answer() - check that two message have the same
                                rcode and answer

- isctest.dnssec.msg() - a wrapper for dns.message.make_query() that
                         creates a query message similar to dig +dnssec:
                         use_edns=True, want_dnssec=True,
                         and flags are set to (RD|AD) by default, but
                         options exist to disable AD or enable CD.
                         (to generate non-DNSSEC queries, use
                         message.make_query() directly.)

(cherry picked from commit b69097f139154ca0d2177f35632400200d220bdc)
This commit is contained in:
Evan Hunt 2025-06-26 15:19:45 -07:00
parent 39e82071f4
commit 3a3bcd5aa1
4 changed files with 100 additions and 2 deletions

View File

@ -10,6 +10,7 @@
# information regarding copyright ownership. # information regarding copyright ownership.
from . import check from . import check
from . import dnssec
from . import instance from . import instance
from . import query from . import query
from . import kasp from . import kasp

View File

@ -13,6 +13,7 @@ import difflib
import shutil import shutil
from typing import Optional from typing import Optional
import dns.flags
import dns.rcode import dns.rcode
import dns.message import dns.message
import dns.zone import dns.zone
@ -41,6 +42,53 @@ def servfail(message: dns.message.Message) -> None:
rcode(message, dns_rcode.SERVFAIL) rcode(message, dns_rcode.SERVFAIL)
def adflag(message: dns.message.Message) -> None:
assert (message.flags & dns.flags.AD) != 0, str(message)
def noadflag(message: dns.message.Message) -> None:
assert (message.flags & dns.flags.AD) == 0, str(message)
def rdflag(message: dns.message.Message) -> None:
assert (message.flags & dns.flags.RD) != 0, str(message)
def nordflag(message: dns.message.Message) -> None:
assert (message.flags & dns.flags.RD) == 0, str(message)
def section_equal(sec1: list, sec2: list) -> None:
# convert an RRset to a normalized string (lower case, TTL=0)
# so it can be used as a set member.
def normalized(rrset):
ttl = rrset.ttl
rrset.ttl = 0
s = str(rrset).lower()
rrset.ttl = ttl
return s
# convert the section contents to sets before comparison,
# in case they aren't in the same sort order.
set1 = {normalized(item) for item in sec1}
set2 = {normalized(item) for item in sec2}
assert set1 == set2
def same_data(res1: dns.message.Message, res2: dns.message.Message):
assert res1.question == res2.question
section_equal(res1.answer, res2.answer)
section_equal(res1.authority, res2.authority)
section_equal(res1.additional, res2.additional)
assert res1.rcode() == res2.rcode()
def same_answer(res1: dns.message.Message, res2: dns.message.Message):
assert res1.question == res2.question
section_equal(res1.answer, res2.answer)
assert res1.rcode() == res2.rcode()
def rrsets_equal( def rrsets_equal(
first_rrset: dns.rrset.RRset, first_rrset: dns.rrset.RRset,
second_rrset: dns.rrset.RRset, second_rrset: dns.rrset.RRset,
@ -125,6 +173,30 @@ def empty_answer(message: dns.message.Message) -> None:
assert not message.answer, str(message) assert not message.answer, str(message)
def answer_count_eq(m: dns.message.Message, expected: int):
count = sum(max(1, len(rrs)) for rrs in m.answer)
assert count == expected, str(m)
def authority_count_eq(m: dns.message.Message, expected: int):
count = sum(max(1, len(rrs)) for rrs in m.authority)
assert count == expected, str(m)
def additional_count_eq(m: dns.message.Message, expected: int):
count = sum(max(1, len(rrs)) for rrs in m.additional)
# add one for the OPT?
opt = bool(m.opt) if hasattr(m, "opt") else bool(m.edns >= 0)
count += 1 if opt else 0
# add one for the TSIG?
tsig = bool(m.tsig) if hasattr(m, "tsig") else m.had_tsig
count += 1 if tsig else 0
assert count == expected, str(m)
def is_response_to(response: dns.message.Message, query: dns.message.Message) -> None: def is_response_to(response: dns.message.Message, query: dns.message.Message) -> None:
single_question(response) single_question(response)
single_question(query) single_question(query)

View File

@ -0,0 +1,25 @@
# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.
from dns import flags, message
def msg(qname: str, qtype: str, **kwargs):
headerflags = flags.RD
# "ad" is on by default
if "ad" not in kwargs or not kwargs["ad"]:
headerflags |= flags.AD
# "cd" is off by default
if "cd" in kwargs and kwargs["cd"]:
headerflags |= flags.CD
return message.make_query(
qname, qtype, use_edns=True, want_dnssec=True, flags=headerflags
)

View File

@ -194,13 +194,13 @@ class NamedInstance:
""" """
return WatchLogFromHere(self.log.path, timeout) return WatchLogFromHere(self.log.path, timeout)
def reconfigure(self) -> None: def reconfigure(self, **kwargs) -> None:
""" """
Reconfigure this named `instance` and wait until reconfiguration is Reconfigure this named `instance` and wait until reconfiguration is
finished. Raise an `RNDCException` if reconfiguration fails. finished. Raise an `RNDCException` if reconfiguration fails.
""" """
with self.watch_log_from_here() as watcher: with self.watch_log_from_here() as watcher:
self.rndc("reconfig") self.rndc("reconfig", **kwargs)
watcher.wait_for_line("any newly configured zones are now loaded") watcher.wait_for_line("any newly configured zones are now loaded")
def _rndc_log(self, command: str, response: str) -> None: def _rndc_log(self, command: str, response: str) -> None: