2
0
mirror of https://gitlab.isc.org/isc-projects/kea synced 2025-09-01 06:25:34 +00:00

[trac419] update request handling logic

This commit is contained in:
chenzhengzhang
2011-02-24 20:28:19 +08:00
parent 9abd2de988
commit 40f74edaaf
3 changed files with 71 additions and 70 deletions

View File

@@ -88,20 +88,11 @@ class TestXfroutSession(unittest.TestCase):
request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
self.log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
self.xfrsess = MyXfroutSession(request, None, None, self.log)
self.write_sock, self.read_sock = socket.socketpair()
self.xfrsess.server = Dbserver()
self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
def test_receive_query_message(self):
send_msg = b"\xd6=\x00\x00\x00\x01\x00"
msg_len = struct.pack('H', socket.htons(len(send_msg)))
self.write_sock.send(msg_len)
self.write_sock.send(send_msg)
recv_msg = self.xfrsess._receive_query_message(self.read_sock)
self.assertEqual(recv_msg, send_msg)
def test_parse_query_message(self):
[get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
self.assertEqual(get_rcode.to_text(), "NOERROR")
@@ -321,8 +312,17 @@ class MyUnixSockServer(UnixSockServer):
class TestUnixSockServer(unittest.TestCase):
def setUp(self):
self.write_sock, self.read_sock = socket.socketpair()
self.unix = MyUnixSockServer()
def test_receive_query_message(self):
send_msg = b"\xd6=\x00\x00\x00\x01\x00"
msg_len = struct.pack('H', socket.htons(len(send_msg)))
self.write_sock.send(msg_len)
self.write_sock.send(send_msg)
recv_msg = self.unix._receive_query_message(self.read_sock)
self.assertEqual(recv_msg, send_msg)
def test_updata_config_data(self):
self.unix.update_config_data({'transfers_out':10 })
self.assertEqual(self.unix._max_transfers_out, 10)

View File

@@ -73,57 +73,25 @@ def get_rrset_len(rrset):
return len(bytes)
class XfroutSession(BaseRequestHandler):
def __init__(self, request, client_address, server, log):
class XfroutSession():
def __init__(self, sock_fd, request_data, server, log):
# The initializer for the superclass may call functions
# that need _log to be set, so we set it first
self._sock_fd = sock_fd
self._request_data = request_data
self._server = server
self._log = log
BaseRequestHandler.__init__(self, request, client_address, server)
self.handle()
def handle(self):
''' Handle a xfrout query. First, xfrout server receive
socket fd and query message from auth. Then, send xfrout
response via the socket fd.'''
sock_fd = recv_fd(self.request.fileno())
if sock_fd < 0:
# This may happen when one xfrout process try to connect to
# xfrout unix socket server, to check whether there is another
# xfrout running.
if sock_fd == XFR_FD_RECEIVE_FAIL:
self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
return
# receive query msg
msgdata = self._receive_query_message(self.request)
if not msgdata:
return
''' Handle a xfrout query, send xfrout response '''
try:
self.dns_xfrout_start(sock_fd, msgdata)
self.dns_xfrout_start(self._sock_fd, self._request_data)
#TODO, avoid catching all exceptions
except Exception as e:
self._log.log_message("error", str(e))
os.close(sock_fd)
def _receive_query_message(self, sock):
''' receive query message from sock'''
# receive data length
data_len = sock.recv(2)
if not data_len:
return None
msg_len = struct.unpack('!H', data_len)[0]
# receive data
recv_size = 0
msgdata = b''
while recv_size < msg_len:
data = sock.recv(msg_len - recv_size)
if not data:
return None
recv_size += len(data)
msgdata += data
return msgdata
os.close(self._sock_fd)
def _parse_query_message(self, mdata):
''' parse query message to [socket,message]'''
@@ -176,7 +144,7 @@ class XfroutSession(BaseRequestHandler):
def _zone_is_empty(self, zone):
if sqlite3_ds.get_zone_soa(zone, self.server.get_db_file()):
if sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()):
return False
return True
@@ -184,7 +152,7 @@ class XfroutSession(BaseRequestHandler):
def _zone_exist(self, zonename):
# Find zone in datasource, should this works? maybe should ask
# config manager.
soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file())
soa = sqlite3_ds.get_zone_soa(zonename, self._server.get_db_file())
if soa:
return True
return False
@@ -202,7 +170,7 @@ class XfroutSession(BaseRequestHandler):
return Rcode.SERVFAIL()
#TODO, check allow_transfer
if not self.server.increase_transfers_counter():
if not self._server.increase_transfers_counter():
return Rcode.REFUSED()
return Rcode.NOERROR()
@@ -228,7 +196,7 @@ class XfroutSession(BaseRequestHandler):
except Exception as err:
self._log.log_message("error", str(err))
self.server.decrease_transfers_counter()
self._server.decrease_transfers_counter()
return
@@ -275,14 +243,14 @@ class XfroutSession(BaseRequestHandler):
#TODO, there should be a better way to insert rrset.
msg.make_response()
msg.set_header_flag(Message.HEADERFLAG_AA)
soa_record = sqlite3_ds.get_zone_soa(zone_name, self.server.get_db_file())
soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file())
rrset_soa = self._create_rrset_from_db_record(soa_record)
msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
message_upper_len = get_rrset_len(rrset_soa)
for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()):
if self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
for rr_data in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()):
if self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
self._log.log_message("info", "xfrout process is being shutdown")
return
@@ -324,7 +292,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
self.update_config_data(config_data)
self._cc = cc
def _handle_request_noblock(self):
def handle_request(self):
'''Rewrite _handle_request_noblock() from parent class ThreadingUnixStreamServer,
enable server handle a request until shutdown or xfrout client is closed.'''
try:
@@ -359,19 +327,52 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
self.close_request(request)
break
def process_request_thread(self, request, client_address):
''' Rewrite process_request_thread() from parent class ThreadingUnixStreamServer,
server won't close the connection after handling a xfrout query, the connection
should be kept for handling upcoming xfrout queries.'''
try:
self.finish_request(request, client_address)
except Exception as e:
self.handle_error(request, client_address)
self.close_request(request)
def _receive_query_message(self, sock):
''' receive request message from sock'''
# receive data length
data_len = sock.recv(2)
if not data_len:
return None
msg_len = struct.unpack('!H', data_len)[0]
# receive data
recv_size = 0
msgdata = b''
while recv_size < msg_len:
data = sock.recv(msg_len - recv_size)
if not data:
return None
recv_size += len(data)
msgdata += data
def finish_request(self, request, client_address):
return msgdata
def process_request(self, request, client_address):
"""Receive socket fd and query message from auth, then
start a new thread to process the request."""
sock_fd = recv_fd(request.fileno())
if sock_fd < 0:
# This may happen when one xfrout process try to connect to
# xfrout unix socket server, to check whether there is another
# xfrout running.
if sock_fd == XFR_FD_RECEIVE_FAIL:
self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
return
# receive request msg
request_data = self._receive_query_message(request)
if not request_data:
return
t = threading.Thread(target = self.finish_request,
args = (sock_fd, request_data, client_address))
if self.daemon_threads:
t.daemon = True
t.start()
def finish_request(self, sock_fd, request_data, client_address):
'''Finish one request by instantiating RequestHandlerClass.'''
self.RequestHandlerClass(request, client_address, self, self._log)
self.RequestHandlerClass(sock_fd, request_data, self, self._log)
def _remove_unused_sock_file(self, sock_file):
'''Try to remove the socket file. If the file is being used

View File

@@ -79,7 +79,7 @@ class NoPollMixIn:
break
else:
# Create a new thread to handle requests for each auth
threading.Thread(target=self._handle_request_noblock).start()
threading.Thread(target=self.handle_request).start()
self._is_shut_down.set()