diff --git a/src/bin/xfrin/tests/xfrin_test.py b/src/bin/xfrin/tests/xfrin_test.py index 0ccbbb8139..c9e25babaa 100644 --- a/src/bin/xfrin/tests/xfrin_test.py +++ b/src/bin/xfrin/tests/xfrin_test.py @@ -534,6 +534,35 @@ class TestXfrin(unittest.TestCase): self.assertEqual(self.xfr.command_handler("retransfer", self.args)['result'][0], 0) + def test_command_handler_retransfer_short_command1(self): + # try it when only specifying the zone name (of unknown zone) + short_args = {} + short_args['zone_name'] = TEST_ZONE_NAME + self.assertEqual(self.xfr.command_handler("retransfer", + short_args)['result'][0], 0) + + def test_command_handler_retransfer_short_command2(self): + # try it when only specifying the zone name (of unknown zone) + short_args = {} + short_args['zone_name'] = TEST_ZONE_NAME + "." + self.assertEqual(self.xfr.command_handler("retransfer", + short_args)['result'][0], 0) + + def test_command_handler_retransfer_short_command3(self): + # try it when only specifying the zone name (of known zone) + short_args = {} + short_args['zone_name'] = TEST_ZONE_NAME + + zones = { 'zones': [ + { 'name': TEST_ZONE_NAME, + 'master_addr': TEST_MASTER_IPV4_ADDRESS, + 'master_port': TEST_MASTER_PORT + } + ]} + self.xfr.config_handler(zones) + self.assertEqual(self.xfr.command_handler("retransfer", + short_args)['result'][0], 0) + def test_command_handler_retransfer_badcommand(self): self.args['master'] = 'invalid' self.assertEqual(self.xfr.command_handler("retransfer", @@ -574,7 +603,21 @@ class TestXfrin(unittest.TestCase): def test_command_handler_notify(self): # at this level, refresh is no different than retransfer. self.args['master'] = TEST_MASTER_IPV6_ADDRESS - # ...but right now we disable the feature due to security concerns. + # ...but the zone is unknown so this would return an error + self.assertEqual(self.xfr.command_handler("notify", + self.args)['result'][0], 1) + + def test_command_handler_notify_known_zone(self): + # try it with a known zone + self.args['master'] = TEST_MASTER_IPV6_ADDRESS + + zones = { 'zones': [ + { 'name': TEST_ZONE_NAME, + 'master_addr': TEST_MASTER_IPV4_ADDRESS, + 'master_port': TEST_MASTER_PORT + } + ]} + self.xfr.config_handler(zones) self.assertEqual(self.xfr.command_handler("notify", self.args)['result'][0], 0) @@ -586,20 +629,37 @@ class TestXfrin(unittest.TestCase): self.assertEqual(self.xfr.config_handler({'transfers_in': 3})['result'][0], 0) self.assertEqual(self.xfr._max_transfers_in, 3) - def test_command_handler_masters(self): - master_info = {'master_addr': '1.1.1.1', 'master_port':53} - self.assertEqual(self.xfr.config_handler(master_info)['result'][0], 0) + def test_command_handler_zones(self): + zones = { 'zones': [ + { 'name': 'test.com.', + 'master_addr': '1.1.1.1', + 'master_port': 53 + } + ]} + self.assertEqual(self.xfr.config_handler(zones)['result'][0], 0) - master_info = {'master_addr': '1111.1.1.1', 'master_port':53 } - self.assertEqual(self.xfr.config_handler(master_info)['result'][0], 1) + zones = { 'zones': [ + { 'master_addr': '1.1.1.1', + 'master_port': 53 + } + ]} + self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1) - master_info = {'master_addr': '2.2.2.2', 'master_port':530000 } - self.assertEqual(self.xfr.config_handler(master_info)['result'][0], 1) + zones = { 'zones': [ + { 'name': 'test.com', + 'master_addr': 'badaddress', + 'master_port': 53 + } + ]} + self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1) - master_info = {'master_addr': '2.2.2.2', 'master_port':53 } - self.xfr.config_handler(master_info) - self.assertEqual(self.xfr._master_addr, '2.2.2.2') - self.assertEqual(self.xfr._master_port, 53) + zones = { 'zones': [ + { 'name': 'test.com', + 'master_addr': '1.1.1.1', + 'master_port': 'bad_port' + } + ]} + self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1) def raise_interrupt(): diff --git a/src/bin/xfrin/xfrin.py.in b/src/bin/xfrin/xfrin.py.in index 1bf46c1b5c..0d112f5c88 100755 --- a/src/bin/xfrin/xfrin.py.in +++ b/src/bin/xfrin/xfrin.py.in @@ -70,6 +70,9 @@ def log_error(msg): class XfrinException(Exception): pass +class XfrinConfigException(Exception): + pass + class XfrinConnection(asyncore.dispatcher): '''Do xfrin in this class. ''' @@ -378,12 +381,41 @@ class XfrinRecorder: self._lock.release() return ret +class ZoneInfo: + def __init__(self, config_data): + """Creates a zone_info with the config data element as + specified by the 'zones' list in xfrin.spec""" + self.name = config_data.get('name') + self.class_str = config_data.get('class') or 'IN' + + if self.name is None: + raise XfrinConfigException("Configuration zones list " + "element does not contain " + "'name' attribute") + + # add the root dot if the user forgot + if len(self.name) > 0 and self.name[-1] != '.': + self.name += '.' + self.master_addr_str = config_data.get('master_addr') or DEFAULT_MASTER + self.master_port_str = config_data.get('master_port') or DEFAULT_MASTER_PORT + try: + self.master_addr = isc.net.parse.addr_parse(self.master_addr_str) + self.master_port = isc.net.parse.port_parse(self.master_port_str) + except ValueError: + errmsg = "bad format for zone's master: " + str(config_data) + log_error(errmsg) + raise XfrinConfigException(errmsg) + + self.tsig_key_str = config_data.get('tsig_key') or None + + def get_master_addr_info(self): + return (self.master_addr.family, socket.SOCK_STREAM, + (self.master_addr_str, self.master_port)) + class Xfrin: def __init__(self, verbose = False): self._max_transfers_in = 10 - #TODO, this is the temp way to set the zone's master. - self._master_addr = DEFAULT_MASTER - self._master_port = DEFAULT_MASTER_PORT + self._zones = {} self._cc_setup() self.recorder = XfrinRecorder() self._shutdown_event = threading.Event() @@ -402,10 +434,7 @@ class Xfrin: self.command_handler) self._module_cc.start() config_data = self._module_cc.get_full_config() - self._max_transfers_in = config_data.get("transfers_in") - self._master_addr = config_data.get('master_addr') or self._master_addr - self._master_port = config_data.get('master_port') or self._master_port - self._tsig_key_str = config_data.get('tsig_key') or None + self.config_handler(config_data) def _cc_check_command(self): '''This is a straightforward wrapper for cc.check_command, @@ -413,22 +442,34 @@ class Xfrin: of unit tests.''' self._module_cc.check_command(False) + def _get_zone_info(self, name, class_str = "IN"): + """Returns the ZoneInfo object containing the configured data + for the given zone name. If the zone name did not have any + data, returns None""" + # add the root dot if the user forgot + if len(name) > 0 and name[-1] != '.': + name += '.' + if (name, class_str) in self._zones: + return self._zones[(name, class_str)] + else: + return None + + def _clear_zone_info(self): + self._zones = {} + + def _add_zone_info(self, zone_info): + self._zones[(zone_info.name, zone_info.class_str)] = zone_info + def config_handler(self, new_config): self._max_transfers_in = new_config.get("transfers_in") or self._max_transfers_in - self._tsig_key_str = new_config.get('tsig_key') or None - if ('master_addr' in new_config) or ('master_port' in new_config): - # User should change the port and address together. - try: - addr = new_config.get('master_addr') or self._master_addr - port = new_config.get('master_port') or self._master_port - isc.net.parse.addr_parse(addr) - isc.net.parse.port_parse(port) - self._master_addr = addr - self._master_port = port - except ValueError: - errmsg = "bad format for zone's master: " + str(new_config) - log_error(errmsg) - return create_answer(1, errmsg) + if 'zones' in new_config: + self._clear_zone_info() + for zone_config in new_config.get('zones'): + try: + zone_info = ZoneInfo(zone_config) + self._add_zone_info(zone_info) + except XfrinConfigException as xce: + return create_answer(1, str(xce)) return create_answer(0) @@ -454,14 +495,21 @@ class Xfrin: # specify the notifyfrom address and port, according the RFC1996, zone # transfer should starts first from the notifyfrom, but now, let 'TODO' it. (zone_name, rrclass) = self._parse_zone_name_and_class(args) - (master_addr) = build_addr_info(self._master_addr, self._master_port) - ret = self.xfrin_start(zone_name, - rrclass, - self._get_db_file(), - master_addr, - self._tsig_key_str, - True) - answer = create_answer(ret[0], ret[1]) + zone_info = self._get_zone_info(zone_name) + if zone_info is None: + # TODO what to do? no info known about zone. defaults? + errmsg = "Got notification to retransfer unknown zone " + zone_name + log_error(errmsg) + answer = create_answer(1, errmsg) + else: + master_addr = zone_info.get_master_addr_info() + ret = self.xfrin_start(zone_name, + rrclass, + self._get_db_file(), + master_addr, + zone_info.tsig_key_str, + True) + answer = create_answer(ret[0], ret[1]) elif command == 'retransfer' or command == 'refresh': # Xfrin receives the retransfer/refresh from cmdctl(sent by bindctl). @@ -469,12 +517,16 @@ class Xfrin: # master address, or else do transfer from the configured masters. (zone_name, rrclass) = self._parse_zone_name_and_class(args) master_addr = self._parse_master_and_port(args) + zone_info = self._get_zone_info(zone_name) + tsig_key_str = None + if zone_info: + tsig_key_str = zone_info.tsig_key_str db_file = args.get('db_file') or self._get_db_file() ret = self.xfrin_start(zone_name, rrclass, db_file, master_addr, - self._tsig_key_str, + tsig_key_str, (False if command == 'retransfer' else True)) answer = create_answer(ret[0], ret[1]) @@ -502,8 +554,24 @@ class Xfrin: return zone_name, rrclass def _parse_master_and_port(self, args): - port = args.get('port') or self._master_port - master = args.get('master') or self._master_addr + # check if we have configured info about this zone, in case + # port or master are not specified + zone_info = self._get_zone_info(args.get('zone_name')) + + port = args.get('port') + if port is None: + if zone_info is not None: + port = zone_info.master_port_str + else: + port = DEFAULT_MASTER_PORT + + master = args.get('master') + if master is None: + if zone_info is not None: + master = zone_info.master_addr_str + else: + master = DEFAULT_MASTER + return build_addr_info(master, port) def _get_db_file(self): diff --git a/src/bin/xfrin/xfrin.spec b/src/bin/xfrin/xfrin.spec index 46bad69f48..a3e62cefc4 100644 --- a/src/bin/xfrin/xfrin.spec +++ b/src/bin/xfrin/xfrin.spec @@ -9,21 +9,43 @@ "item_optional": false, "item_default": 10 }, - { - "item_name": "master_addr", - "item_type": "string", + { "item_name": "zones", + "item_type": "list", "item_optional": false, - "item_default": "" - }, - { "item_name": "master_port", - "item_type": "integer", - "item_optional": false, - "item_default": 53 - }, - { "item_name": "tsig_key", - "item_type": "string", - "item_optional": true, - "item_default": "" + "item_default": [], + "list_item_spec": + { "item_type": "map", + "item_name": "zone_info", + "item_optional": false, + "item_default": {}, + "map_item_spec": [ + { "item_name": "name", + "item_type": "string", + "item_optional": false, + "item_default": "" + }, + { "item_name": "class", + "item_type": "string", + "item_optional": false, + "item_default": "IN" + }, + { + "item_name": "master_addr", + "item_type": "string", + "item_optional": false, + "item_default": "" + }, + { "item_name": "master_port", + "item_type": "integer", + "item_optional": false, + "item_default": 53 + }, + { "item_name": "tsig_key", + "item_type": "string", + "item_optional": true + } + ] + } } ], "commands": [ diff --git a/src/bin/xfrout/xfrout.py.in b/src/bin/xfrout/xfrout.py.in index 17ca3ebff9..c7887111a6 100755 --- a/src/bin/xfrout/xfrout.py.in +++ b/src/bin/xfrout/xfrout.py.in @@ -301,6 +301,11 @@ class XfroutSession(): self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len) +class ZoneInfo: + def __init__(self, zone_config): + self.name = zone_config.get('name') + self.tsig_key_str = zone_config.get('tsig_key') + class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer): '''The unix domain socket server which accept xfr query sent from auth server.''' @@ -450,6 +455,11 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer): self._lock.acquire() self._max_transfers_out = new_config.get('transfers_out') self._log.log_message('info', 'max transfer out : %d', self._max_transfers_out) + zones = new_config.get('zones') + if zones is not None: + for zone_config in zones: + zone_info = ZoneInfo(zone_config) + self.zones[zone_info.name] = zone_info self._lock.release() self._log.log_message('info', 'update config data complete.') diff --git a/src/bin/xfrout/xfrout.spec.pre.in b/src/bin/xfrout/xfrout.spec.pre.in index 941db72a93..96b9570b6a 100644 --- a/src/bin/xfrout/xfrout.spec.pre.in +++ b/src/bin/xfrout/xfrout.spec.pre.in @@ -37,6 +37,29 @@ "item_type": "integer", "item_optional": false, "item_default": 1048576 + }, + { + "item_name": "zones", + "item_type": "list", + "item_optional": false, + "item_default": [], + "list_item_spec": + { "item_name": "zone_info", + "item_type": "map", + "item_optional": false, + "item_default": {}, + "map_item_spec": [ + { "item_name": "name", + "item_type": "string", + "item_optional": false, + "item_default": "" + }, + { "item_name": "tsig_key", + "item_type": "string", + "item_optional": true + } + ] + } } ], "commands": [