diff --git a/bin/tests/system/dnssec/tests_nsec3.py b/bin/tests/system/dnssec/tests_nsec3.py index 1a7dbe71ee..83e63a7059 100755 --- a/bin/tests/system/dnssec/tests_nsec3.py +++ b/bin/tests/system/dnssec/tests_nsec3.py @@ -161,17 +161,10 @@ def test_dnssec_nsec3_nxdomain(server, name: dns.name.Name, named_port: int) -> noqname_test(server, name, named_port) -@strategies.composite -def generate_subdomain_of_existing_name(draw): - existing = draw(strategies.sampled_from(sorted(KNOWN_NAMES))) - subdomain = draw(isctest.hypothesis.strategies.dns_names(suffix=existing)) - return subdomain - - @pytest.mark.parametrize( "server", [pytest.param(AUTH, id="ns3"), pytest.param(RESOLVER, id="ns4")] ) -@given(name=generate_subdomain_of_existing_name()) +@given(name=dns_names(suffix=KNOWN_NAMES)) def test_dnssec_nsec3_subdomain_nxdomain( server, name: dns.name.Name, named_port: int ) -> None: diff --git a/bin/tests/system/isctest/hypothesis/strategies.py b/bin/tests/system/isctest/hypothesis/strategies.py index 0828496360..a3f9eac2b2 100644 --- a/bin/tests/system/isctest/hypothesis/strategies.py +++ b/bin/tests/system/isctest/hypothesis/strategies.py @@ -11,7 +11,8 @@ # See the COPYRIGHT file distributed with this work for additional # information regarding copyright ownership. -from typing import List +import collections.abc +from typing import List, Union from warnings import warn from hypothesis.strategies import ( @@ -22,6 +23,7 @@ from hypothesis.strategies import ( just, nothing, permutations, + sampled_from, ) import dns.name @@ -37,7 +39,9 @@ def dns_names( draw, *, prefix: dns.name.Name = dns.name.empty, - suffix: dns.name.Name = dns.name.root, + suffix: Union[ + dns.name.Name, collections.abc.Iterable[dns.name.Name] + ] = dns.name.root, min_labels: int = 1, max_labels: int = 128, ) -> dns.name.Name: @@ -71,6 +75,14 @@ def dns_names( """ prefix = prefix.relativize(dns.name.root) + # Python str is iterable, but that's most probably not what user actually wanted + if isinstance(suffix, str): + raise NotImplementedError( + "ambiguous API use, convert suffix to Name or list to express intent" + ) + if isinstance(suffix, collections.abc.Iterable): + suffix = draw(sampled_from(sorted(suffix))) + assert isinstance(suffix, dns.name.Name) suffix = suffix.derelativize(dns.name.root) try: