mirror of
https://gitlab.isc.org/isc-projects/kea
synced 2025-08-31 22:15:23 +00:00
Implement directed messages, and a test for it
git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@334 e5f2f494-b856-4b98-b285-d166d9295462
This commit is contained in:
@@ -87,6 +87,7 @@ class MsgQ:
|
|||||||
self.connection_counter = random.random()
|
self.connection_counter = random.random()
|
||||||
self.hostname = socket.gethostname()
|
self.hostname = socket.gethostname()
|
||||||
self.subs = SubscriptionManager()
|
self.subs = SubscriptionManager()
|
||||||
|
self.lnames = {}
|
||||||
|
|
||||||
def setup_poller(self):
|
def setup_poller(self):
|
||||||
"""Set up the poll thing. Internal function."""
|
"""Set up the poll thing. Internal function."""
|
||||||
@@ -117,6 +118,8 @@ class MsgQ:
|
|||||||
newsocket, ipaddr = self.listen_socket.accept()
|
newsocket, ipaddr = self.listen_socket.accept()
|
||||||
sys.stderr.write("Connection\n")
|
sys.stderr.write("Connection\n")
|
||||||
self.sockets[newsocket.fileno()] = newsocket
|
self.sockets[newsocket.fileno()] = newsocket
|
||||||
|
lname = self.newlname()
|
||||||
|
self.lnames[lname] = newsocket
|
||||||
self.poller.register(newsocket, select.POLLIN)
|
self.poller.register(newsocket, select.POLLIN)
|
||||||
|
|
||||||
def process_socket(self, fd):
|
def process_socket(self, fd):
|
||||||
@@ -132,6 +135,8 @@ class MsgQ:
|
|||||||
"""Fully close down the socket."""
|
"""Fully close down the socket."""
|
||||||
self.poller.unregister(sock)
|
self.poller.unregister(sock)
|
||||||
self.subs.unsubscribe_all(sock)
|
self.subs.unsubscribe_all(sock)
|
||||||
|
lname = [ k for k, v in self.lnames.items() if v == sock ][0]
|
||||||
|
del self.lnames[lname]
|
||||||
sock.close()
|
sock.close()
|
||||||
self.sockets[fd] = None
|
self.sockets[fd] = None
|
||||||
sys.stderr.write("Closing socket fd %d\n" % fd)
|
sys.stderr.write("Closing socket fd %d\n" % fd)
|
||||||
@@ -232,16 +237,20 @@ class MsgQ:
|
|||||||
return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname)
|
return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname)
|
||||||
|
|
||||||
def process_command_getlname(self, sock, routing, data):
|
def process_command_getlname(self, sock, routing, data):
|
||||||
env = { "type" : "getlname" }
|
lname = [ k for k, v in self.lnames.items() if v == sock ][0]
|
||||||
reply = { "lname" : self.newlname() }
|
self.sendmsg(sock, { "type" : "getlname" }, { "lname" : lname })
|
||||||
self.sendmsg(sock, env, reply)
|
|
||||||
|
|
||||||
def process_command_send(self, sock, routing, data):
|
def process_command_send(self, sock, routing, data):
|
||||||
group = routing["group"]
|
group = routing["group"]
|
||||||
instance = routing["instance"]
|
instance = routing["instance"]
|
||||||
|
to = routing["to"]
|
||||||
if group == None or instance == None:
|
if group == None or instance == None:
|
||||||
return # ignore invalid packets entirely
|
return # ignore invalid packets entirely
|
||||||
|
|
||||||
|
if to == "*":
|
||||||
sockets = self.subs.find(group, instance)
|
sockets = self.subs.find(group, instance)
|
||||||
|
else:
|
||||||
|
sockets = [ self.lnames[to] ]
|
||||||
|
|
||||||
msg = self.preparemsg(routing, data)
|
msg = self.preparemsg(routing, data)
|
||||||
|
|
||||||
@@ -253,8 +262,7 @@ class MsgQ:
|
|||||||
def process_command_subscribe(self, sock, routing, data):
|
def process_command_subscribe(self, sock, routing, data):
|
||||||
group = routing["group"]
|
group = routing["group"]
|
||||||
instance = routing["instance"]
|
instance = routing["instance"]
|
||||||
subtype = routing["subtype"]
|
if group == None or instance == None:
|
||||||
if group == None or instance == None or subtype == None:
|
|
||||||
return # ignore invalid packets entirely
|
return # ignore invalid packets entirely
|
||||||
self.subs.subscribe(group, instance, sock)
|
self.subs.subscribe(group, instance, sock)
|
||||||
|
|
||||||
|
@@ -31,6 +31,7 @@ class Session:
|
|||||||
self._recvlength = None
|
self._recvlength = None
|
||||||
self._sendbuffer = bytearray()
|
self._sendbuffer = bytearray()
|
||||||
self._sequence = 1
|
self._sequence = 1
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
@@ -48,7 +49,14 @@ class Session:
|
|||||||
def lname(self):
|
def lname(self):
|
||||||
return self._lname
|
return self._lname
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._socket.close()
|
||||||
|
self._lname = None
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
def sendmsg(self, env, msg = None):
|
def sendmsg(self, env, msg = None):
|
||||||
|
if self._closed:
|
||||||
|
raise SessionError("Session has been closed.")
|
||||||
if type(env) == dict:
|
if type(env) == dict:
|
||||||
env = Message.to_wire(env)
|
env = Message.to_wire(env)
|
||||||
if type(msg) == dict:
|
if type(msg) == dict:
|
||||||
@@ -64,6 +72,8 @@ class Session:
|
|||||||
self._socket.send(msg)
|
self._socket.send(msg)
|
||||||
|
|
||||||
def recvmsg(self, nonblock = True):
|
def recvmsg(self, nonblock = True):
|
||||||
|
if self._closed:
|
||||||
|
raise SessionError("Session has been closed.")
|
||||||
data = self._receive_full_buffer(nonblock)
|
data = self._receive_full_buffer(nonblock)
|
||||||
if data and len(data) > 2:
|
if data and len(data) > 2:
|
||||||
header_length = struct.unpack('>H', data[0:2])[0]
|
header_length = struct.unpack('>H', data[0:2])[0]
|
||||||
|
@@ -15,6 +15,10 @@ class TestCCWireEncoding(unittest.TestCase):
|
|||||||
self.s1 = ISC.CC.Session()
|
self.s1 = ISC.CC.Session()
|
||||||
self.s2 = ISC.CC.Session()
|
self.s2 = ISC.CC.Session()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.s1.close()
|
||||||
|
self.s2.close()
|
||||||
|
|
||||||
def test_lname(self):
|
def test_lname(self):
|
||||||
self.assertTrue(self.s1.lname)
|
self.assertTrue(self.s1.lname)
|
||||||
self.assertTrue(self.s2.lname)
|
self.assertTrue(self.s2.lname)
|
||||||
@@ -40,5 +44,17 @@ class TestCCWireEncoding(unittest.TestCase):
|
|||||||
msg, env = self.s2.group_recvmsg()
|
msg, env = self.s2.group_recvmsg()
|
||||||
self.assertFalse(env)
|
self.assertFalse(env)
|
||||||
|
|
||||||
|
def test_directed_recipient(self):
|
||||||
|
self.s1.group_subscribe("g1", "i1")
|
||||||
|
time.sleep(0.5)
|
||||||
|
outmsg = { "data" : "foo" }
|
||||||
|
self.s1.group_sendmsg(outmsg, "g4", "i4", self.s2.lname)
|
||||||
|
time.sleep(0.5)
|
||||||
|
msg, env = self.s2.group_recvmsg()
|
||||||
|
self.assertEqual(env["from"], self.s1.lname)
|
||||||
|
self.assertEqual(env["to"], self.s2.lname)
|
||||||
|
self.assertEqual(env["group"], "g4")
|
||||||
|
self.assertEqual(env["instance"], "i4")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Reference in New Issue
Block a user