diff --git a/src/bin/xfrout/tests/xfrout_test.py b/src/bin/xfrout/tests/xfrout_test.py index ba9c258e9f..fce9196023 100644 --- a/src/bin/xfrout/tests/xfrout_test.py +++ b/src/bin/xfrout/tests/xfrout_test.py @@ -88,20 +88,11 @@ class TestXfroutSession(unittest.TestCase): request = MySocket(socket.AF_INET,socket.SOCK_STREAM) self.log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False ) self.xfrsess = MyXfroutSession(request, None, None, self.log) - self.write_sock, self.read_sock = socket.socketpair() self.xfrsess.server = Dbserver() self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01') self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM) self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200') - def test_receive_query_message(self): - send_msg = b"\xd6=\x00\x00\x00\x01\x00" - msg_len = struct.pack('H', socket.htons(len(send_msg))) - self.write_sock.send(msg_len) - self.write_sock.send(send_msg) - recv_msg = self.xfrsess._receive_query_message(self.read_sock) - self.assertEqual(recv_msg, send_msg) - def test_parse_query_message(self): [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata) self.assertEqual(get_rcode.to_text(), "NOERROR") @@ -321,8 +312,17 @@ class MyUnixSockServer(UnixSockServer): class TestUnixSockServer(unittest.TestCase): def setUp(self): + self.write_sock, self.read_sock = socket.socketpair() self.unix = MyUnixSockServer() + def test_receive_query_message(self): + send_msg = b"\xd6=\x00\x00\x00\x01\x00" + msg_len = struct.pack('H', socket.htons(len(send_msg))) + self.write_sock.send(msg_len) + self.write_sock.send(send_msg) + recv_msg = self.unix._receive_query_message(self.read_sock) + self.assertEqual(recv_msg, send_msg) + def test_updata_config_data(self): self.unix.update_config_data({'transfers_out':10 }) self.assertEqual(self.unix._max_transfers_out, 10) diff --git a/src/bin/xfrout/xfrout.py.in b/src/bin/xfrout/xfrout.py.in index d0ce3a822e..c6eb356bea 100644 --- a/src/bin/xfrout/xfrout.py.in +++ b/src/bin/xfrout/xfrout.py.in @@ -73,57 +73,25 @@ def get_rrset_len(rrset): return len(bytes) -class XfroutSession(BaseRequestHandler): - def __init__(self, request, client_address, server, log): +class XfroutSession(): + def __init__(self, sock_fd, request_data, server, log): # The initializer for the superclass may call functions # that need _log to be set, so we set it first + self._sock_fd = sock_fd + self._request_data = request_data + self._server = server self._log = log - BaseRequestHandler.__init__(self, request, client_address, server) + self.handle() def handle(self): - ''' Handle a xfrout query. First, xfrout server receive - socket fd and query message from auth. Then, send xfrout - response via the socket fd.''' - sock_fd = recv_fd(self.request.fileno()) - if sock_fd < 0: - # This may happen when one xfrout process try to connect to - # xfrout unix socket server, to check whether there is another - # xfrout running. - if sock_fd == XFR_FD_RECEIVE_FAIL: - self._log.log_message("error", "Failed to receive the file descriptor for XFR connection") - return - - # receive query msg - msgdata = self._receive_query_message(self.request) - if not msgdata: - return - + ''' Handle a xfrout query, send xfrout response ''' try: - self.dns_xfrout_start(sock_fd, msgdata) + self.dns_xfrout_start(self._sock_fd, self._request_data) #TODO, avoid catching all exceptions except Exception as e: self._log.log_message("error", str(e)) - os.close(sock_fd) - - def _receive_query_message(self, sock): - ''' receive query message from sock''' - # receive data length - data_len = sock.recv(2) - if not data_len: - return None - msg_len = struct.unpack('!H', data_len)[0] - # receive data - recv_size = 0 - msgdata = b'' - while recv_size < msg_len: - data = sock.recv(msg_len - recv_size) - if not data: - return None - recv_size += len(data) - msgdata += data - - return msgdata + os.close(self._sock_fd) def _parse_query_message(self, mdata): ''' parse query message to [socket,message]''' @@ -176,7 +144,7 @@ class XfroutSession(BaseRequestHandler): def _zone_is_empty(self, zone): - if sqlite3_ds.get_zone_soa(zone, self.server.get_db_file()): + if sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()): return False return True @@ -184,7 +152,7 @@ class XfroutSession(BaseRequestHandler): def _zone_exist(self, zonename): # Find zone in datasource, should this works? maybe should ask # config manager. - soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file()) + soa = sqlite3_ds.get_zone_soa(zonename, self._server.get_db_file()) if soa: return True return False @@ -202,7 +170,7 @@ class XfroutSession(BaseRequestHandler): return Rcode.SERVFAIL() #TODO, check allow_transfer - if not self.server.increase_transfers_counter(): + if not self._server.increase_transfers_counter(): return Rcode.REFUSED() return Rcode.NOERROR() @@ -228,7 +196,7 @@ class XfroutSession(BaseRequestHandler): except Exception as err: self._log.log_message("error", str(err)) - self.server.decrease_transfers_counter() + self._server.decrease_transfers_counter() return @@ -275,14 +243,14 @@ class XfroutSession(BaseRequestHandler): #TODO, there should be a better way to insert rrset. msg.make_response() msg.set_header_flag(Message.HEADERFLAG_AA) - soa_record = sqlite3_ds.get_zone_soa(zone_name, self.server.get_db_file()) + soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file()) rrset_soa = self._create_rrset_from_db_record(soa_record) msg.add_rrset(Message.SECTION_ANSWER, rrset_soa) message_upper_len = get_rrset_len(rrset_soa) - for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()): - if self.server._shutdown_event.is_set(): # Check if xfrout is shutdown + for rr_data in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()): + if self._server._shutdown_event.is_set(): # Check if xfrout is shutdown self._log.log_message("info", "xfrout process is being shutdown") return @@ -324,7 +292,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer): self.update_config_data(config_data) self._cc = cc - def _handle_request_noblock(self): + def handle_request(self): '''Rewrite _handle_request_noblock() from parent class ThreadingUnixStreamServer, enable server handle a request until shutdown or xfrout client is closed.''' try: @@ -359,19 +327,52 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer): self.close_request(request) break - def process_request_thread(self, request, client_address): - ''' Rewrite process_request_thread() from parent class ThreadingUnixStreamServer, - server won't close the connection after handling a xfrout query, the connection - should be kept for handling upcoming xfrout queries.''' - try: - self.finish_request(request, client_address) - except Exception as e: - self.handle_error(request, client_address) - self.close_request(request) + def _receive_query_message(self, sock): + ''' receive request message from sock''' + # receive data length + data_len = sock.recv(2) + if not data_len: + return None + msg_len = struct.unpack('!H', data_len)[0] + # receive data + recv_size = 0 + msgdata = b'' + while recv_size < msg_len: + data = sock.recv(msg_len - recv_size) + if not data: + return None + recv_size += len(data) + msgdata += data - def finish_request(self, request, client_address): + return msgdata + + def process_request(self, request, client_address): + """Receive socket fd and query message from auth, then + start a new thread to process the request.""" + sock_fd = recv_fd(request.fileno()) + if sock_fd < 0: + # This may happen when one xfrout process try to connect to + # xfrout unix socket server, to check whether there is another + # xfrout running. + if sock_fd == XFR_FD_RECEIVE_FAIL: + self._log.log_message("error", "Failed to receive the file descriptor for XFR connection") + return + + # receive request msg + request_data = self._receive_query_message(request) + if not request_data: + return + + t = threading.Thread(target = self.finish_request, + args = (sock_fd, request_data, client_address)) + if self.daemon_threads: + t.daemon = True + t.start() + + + def finish_request(self, sock_fd, request_data, client_address): '''Finish one request by instantiating RequestHandlerClass.''' - self.RequestHandlerClass(request, client_address, self, self._log) + self.RequestHandlerClass(sock_fd, request_data, self, self._log) def _remove_unused_sock_file(self, sock_file): '''Try to remove the socket file. If the file is being used diff --git a/src/lib/python/isc/util/socketserver_mixin.py b/src/lib/python/isc/util/socketserver_mixin.py index b9c1aafcfe..e9852c473c 100644 --- a/src/lib/python/isc/util/socketserver_mixin.py +++ b/src/lib/python/isc/util/socketserver_mixin.py @@ -79,7 +79,7 @@ class NoPollMixIn: break else: # Create a new thread to handle requests for each auth - threading.Thread(target=self._handle_request_noblock).start() + threading.Thread(target=self.handle_request).start() self._is_shut_down.set()