diff --git a/src/bin/pymsgq/msgq.py b/src/bin/pymsgq/msgq.py index 15d4d297ce..c4d0fb0689 100644 --- a/src/bin/pymsgq/msgq.py +++ b/src/bin/pymsgq/msgq.py @@ -87,6 +87,7 @@ class MsgQ: self.connection_counter = random.random() self.hostname = socket.gethostname() self.subs = SubscriptionManager() + self.lnames = {} def setup_poller(self): """Set up the poll thing. Internal function.""" @@ -117,6 +118,8 @@ class MsgQ: newsocket, ipaddr = self.listen_socket.accept() sys.stderr.write("Connection\n") self.sockets[newsocket.fileno()] = newsocket + lname = self.newlname() + self.lnames[lname] = newsocket self.poller.register(newsocket, select.POLLIN) def process_socket(self, fd): @@ -132,6 +135,8 @@ class MsgQ: """Fully close down the socket.""" self.poller.unregister(sock) self.subs.unsubscribe_all(sock) + lname = [ k for k, v in self.lnames.items() if v == sock ][0] + del self.lnames[lname] sock.close() self.sockets[fd] = None sys.stderr.write("Closing socket fd %d\n" % fd) @@ -232,16 +237,20 @@ class MsgQ: return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname) def process_command_getlname(self, sock, routing, data): - env = { "type" : "getlname" } - reply = { "lname" : self.newlname() } - self.sendmsg(sock, env, reply) + lname = [ k for k, v in self.lnames.items() if v == sock ][0] + self.sendmsg(sock, { "type" : "getlname" }, { "lname" : lname }) def process_command_send(self, sock, routing, data): group = routing["group"] instance = routing["instance"] + to = routing["to"] if group == None or instance == None: return # ignore invalid packets entirely - sockets = self.subs.find(group, instance) + + if to == "*": + sockets = self.subs.find(group, instance) + else: + sockets = [ self.lnames[to] ] msg = self.preparemsg(routing, data) @@ -253,8 +262,7 @@ class MsgQ: def process_command_subscribe(self, sock, routing, data): group = routing["group"] instance = routing["instance"] - subtype = routing["subtype"] - if group == None or instance == None or subtype == None: + if group == None or instance == None: return # ignore invalid packets entirely self.subs.subscribe(group, instance, sock) diff --git a/src/lib/cc/python/ISC/CC/session.py b/src/lib/cc/python/ISC/CC/session.py index 58e0cff2ba..81a4b5c5e1 100644 --- a/src/lib/cc/python/ISC/CC/session.py +++ b/src/lib/cc/python/ISC/CC/session.py @@ -31,6 +31,7 @@ class Session: self._recvlength = None self._sendbuffer = bytearray() self._sequence = 1 + self._closed = False try: self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -48,7 +49,14 @@ class Session: def lname(self): return self._lname + def close(self): + self._socket.close() + self._lname = None + self._closed = True + def sendmsg(self, env, msg = None): + if self._closed: + raise SessionError("Session has been closed.") if type(env) == dict: env = Message.to_wire(env) if type(msg) == dict: @@ -64,6 +72,8 @@ class Session: self._socket.send(msg) def recvmsg(self, nonblock = True): + if self._closed: + raise SessionError("Session has been closed.") data = self._receive_full_buffer(nonblock) if data and len(data) > 2: header_length = struct.unpack('>H', data[0:2])[0] diff --git a/src/lib/cc/python/test_session.py b/src/lib/cc/python/test_session.py index 0f060657e0..c9daf6a7b7 100644 --- a/src/lib/cc/python/test_session.py +++ b/src/lib/cc/python/test_session.py @@ -15,6 +15,10 @@ class TestCCWireEncoding(unittest.TestCase): self.s1 = ISC.CC.Session() self.s2 = ISC.CC.Session() + def tearDown(self): + self.s1.close() + self.s2.close() + def test_lname(self): self.assertTrue(self.s1.lname) self.assertTrue(self.s2.lname) @@ -40,5 +44,17 @@ class TestCCWireEncoding(unittest.TestCase): msg, env = self.s2.group_recvmsg() self.assertFalse(env) + def test_directed_recipient(self): + self.s1.group_subscribe("g1", "i1") + time.sleep(0.5) + outmsg = { "data" : "foo" } + self.s1.group_sendmsg(outmsg, "g4", "i4", self.s2.lname) + time.sleep(0.5) + msg, env = self.s2.group_recvmsg() + self.assertEqual(env["from"], self.s1.lname) + self.assertEqual(env["to"], self.s2.lname) + self.assertEqual(env["group"], "g4") + self.assertEqual(env["instance"], "i4") + if __name__ == '__main__': unittest.main()