diff --git a/src/lib/util/io/Makefile.am b/src/lib/util/io/Makefile.am index cbcd54d92f..698327d8e6 100644 --- a/src/lib/util/io/Makefile.am +++ b/src/lib/util/io/Makefile.am @@ -1,7 +1,11 @@ AM_CXXFLAGS = $(B10_CXXFLAGS) +AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib +AM_CPPFLAGS += $(BOOST_INCLUDES) + lib_LTLIBRARIES = libutil_io.la libutil_io_la_SOURCES = fd.h fd.cc fd_share.h fd_share.cc +libutil_io_la_SOURCES += socketsession.h socketsession.cc sockaddr_util.h libutil_io_la_CXXFLAGS = $(AM_CXXFLAGS) -fno-strict-aliasing CLEANFILES = *.gcno *.gcda diff --git a/src/lib/util/io/sockaddr_util.h b/src/lib/util/io/sockaddr_util.h new file mode 100644 index 0000000000..92ebe34901 --- /dev/null +++ b/src/lib/util/io/sockaddr_util.h @@ -0,0 +1,66 @@ +// Copyright (C) 2011 Internet Systems Consortium, Inc. ("ISC") +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +#ifndef __SOCKADDR_UTIL_H_ +#define __SOCKADDR_UTIL_H_ 1 + +#include + +// This definitions in this file are for the convenience of internal +// implementation and test code, and are not intended to be used publicly. +// The namespace "internal" indicates the intent. + +namespace isc { +namespace util { +namespace io { +namespace internal { + +inline socklen_t +getSALength(const struct sockaddr& sa) { + if (sa.sa_family == AF_INET) { + return (sizeof(struct sockaddr_in)); + } else { + assert(sa.sa_family == AF_INET6); + return (sizeof(struct sockaddr_in6)); + } +} + +// Lower level C-APIs require conversion between various variants of +// sockaddr's, which is not friendly with C++. The following templates +// are a shortcut of common workaround conversion in such cases. + +template +const struct sockaddr* +convertSockAddr(const SA_TYPE* sa) { + const void* p = sa; + return (static_cast(p)); +} + +template +struct sockaddr* +convertSockAddr(SA_TYPE* sa) { + void* p = sa; + return (static_cast(p)); +} + +} +} +} +} + +#endif // __SOCKADDR_UTIL_H_ + +// Local Variables: +// mode: c++ +// End: diff --git a/src/lib/util/io/socketsession.cc b/src/lib/util/io/socketsession.cc new file mode 100644 index 0000000000..ab79fd0645 --- /dev/null +++ b/src/lib/util/io/socketsession.cc @@ -0,0 +1,233 @@ +// Copyright (C) 2011 Internet Systems Consortium, Inc. ("ISC") +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +#include + +#include + +#include "fd_share.h" +#include "socketsession.h" +#include "sockaddr_util.h" + +using namespace std; + +namespace isc { +namespace util { +namespace io { + +using namespace internal; + +struct SocketSessionForwarder::ForwarderImpl { + ForwarderImpl() : buf_(512) {} + struct sockaddr_un sock_un_; + socklen_t sock_un_len_; + int fd_; + OutputBuffer buf_; +}; + +SocketSessionForwarder::SocketSessionForwarder(const std::string& unix_file) : + impl_(NULL) +{ + ForwarderImpl impl; + if (sizeof(impl.sock_un_.sun_path) - 1 < unix_file.length()) { + isc_throw(SocketSessionError, + "File name for a UNIX domain socket is too long: " << + unix_file); + } + impl.sock_un_.sun_family = AF_UNIX; + strncpy(impl.sock_un_.sun_path, unix_file.c_str(), + sizeof(impl.sock_un_.sun_path)); + assert(impl.sock_un_.sun_path[sizeof(impl.sock_un_.sun_path) - 1] == '\0'); + impl.sock_un_len_ = 2 + unix_file.length(); +#ifdef HAVE_SA_LEN + impl.sock_un_.sun_len = sock_un_len_; +#endif + impl.fd_ = -1; + + impl_ = new ForwarderImpl; + *impl_ = impl; +} + +SocketSessionForwarder::~SocketSessionForwarder() { + if (impl_->fd_ != -1) { + close(); + } + delete impl_; +} + +void +SocketSessionForwarder::connectToReceptor() { + if (impl_->fd_ != -1) { + isc_throw(SocketSessionError, "Duplicate connect to UNIX domain " + "endpoint " << impl_->sock_un_.sun_path); + } + + impl_->fd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (impl_->fd_ == -1) { + isc_throw(SocketSessionError, "Failed to create a UNIX domain socket: " + << strerror(errno)); + } + if (connect(impl_->fd_, convertSockAddr(&impl_->sock_un_), + impl_->sock_un_len_) == -1) { + close(); // note: this is the internal method, not ::close() + isc_throw(SocketSessionError, "Failed to connect to UNIX domain " + "endpoint " << impl_->sock_un_.sun_path << ": " << + strerror(errno)); + } + int bufsize = 65536 * 2; + if (setsockopt(impl_->fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, + sizeof(bufsize)) == -1) { + isc_throw(SocketSessionError, "failed to enlarge receive buffer size"); + } +} + +void +SocketSessionForwarder::close() { + if (impl_->fd_ == -1) { + isc_throw(SocketSessionError, "Attempt of close before connect"); + } + ::close(impl_->fd_); + impl_->fd_ = -1; +} + +void +SocketSessionForwarder::push(int sock, int family, int sock_type, int protocol, + const struct sockaddr& local_end, + const struct sockaddr& remote_end, + const void* data, size_t data_len) +{ + // check state (fd must be valid) + // family must be AF_INET or AF_INET6 + // sa_family should match + + send_fd(impl_->fd_, sock); + // TODO: error check + + impl_->buf_.clear(); + // Leave the space for the header length + impl_->buf_.skip(sizeof(uint16_t)); + // Socket properties: family, type, protocol + impl_->buf_.writeUint32(static_cast(family)); + impl_->buf_.writeUint32(static_cast(sock_type)); + impl_->buf_.writeUint32(static_cast(protocol)); + // Local endpoint + impl_->buf_.writeUint32(static_cast(getSALength(local_end))); + impl_->buf_.writeData(&local_end, getSALength(local_end)); + // Remote endpoint + impl_->buf_.writeUint32(static_cast(getSALength(remote_end))); + impl_->buf_.writeData(&remote_end, getSALength(remote_end)); + // Data length + impl_->buf_.writeUint32(static_cast(data_len)); + // Write the resulting header length at the beginning of the buffer + impl_->buf_.writeUint16At(impl_->buf_.getLength() - sizeof(uint16_t), 0); + + const int cc = write(impl_->fd_, impl_->buf_.getData(), + impl_->buf_.getLength()); + assert(cc == impl_->buf_.getLength()); + + const int cc_data = write(impl_->fd_, data, data_len); + assert(cc_data == data_len); +} + +SocketSession::SocketSession(int sock, int family, int type, int protocol, + const sockaddr* local_end, + const sockaddr* remote_end, + size_t data_len, const void* data) : + sock_(sock), family_(family), type_(type), protocol_(protocol), + local_end_(local_end), remote_end_(remote_end), + data_len_(data_len), data_(data) +{ + // TODO: local_end and remote_end must not be NULL; check it +} + +const size_t DEFAULT_HEADER_BUFLEN = sizeof(struct sockaddr_storage) * 2 + + sizeof(uint32_t) * 6; + +struct SocketSessionReceptor::ReceptorImpl { + ReceptorImpl(int fd) : fd_(fd), + sa_local_(convertSockAddr(&ss_local_)), + sa_remote_(convertSockAddr(&ss_remote_)), + header_buf_(DEFAULT_HEADER_BUFLEN), data_buf_(512) + {} + + const int fd_; + struct sockaddr_storage ss_local_; // placeholder + struct sockaddr* const sa_local_; + struct sockaddr_storage ss_remote_; // placeholder + struct sockaddr* const sa_remote_; + + vector header_buf_; + vector data_buf_; +}; + +SocketSessionReceptor::SocketSessionReceptor(int fd) : + impl_(new ReceptorImpl(fd)) +{ +} + +SocketSessionReceptor::~SocketSessionReceptor() { + delete impl_; +} + +SocketSession +SocketSessionReceptor::pop() { + const int passed_fd = recv_fd(impl_->fd_); + // TODO: error check + + uint16_t header_len; + const int cc = read(impl_->fd_, &header_len, sizeof(header_len)); + assert(cc == sizeof(header_len)); // XXX + header_len = InputBuffer(&header_len, sizeof(header_len)).readUint16(); + impl_->header_buf_.clear(); + impl_->header_buf_.resize(header_len); + read(impl_->fd_, &impl_->header_buf_[0], header_len); + + InputBuffer ibuffer(&impl_->header_buf_[0], header_len); + const int family = static_cast(ibuffer.readUint32()); + const int type = static_cast(ibuffer.readUint32()); + const int protocol = static_cast(ibuffer.readUint32()); + const socklen_t local_end_len = ibuffer.readUint32(); + assert(local_end_len <= sizeof(impl_->ss_local_)); // XXX + ibuffer.readData(&impl_->ss_local_, local_end_len); + const socklen_t remote_end_len = ibuffer.readUint32(); + assert(remote_end_len <= sizeof(impl_->ss_remote_)); // XXX + ibuffer.readData(&impl_->ss_remote_, remote_end_len); + const size_t data_len = ibuffer.readUint32(); + + impl_->data_buf_.clear(); + impl_->data_buf_.resize(data_len); + read(impl_->fd_, &impl_->data_buf_[0], data_len); + + return (SocketSession(passed_fd, family, type, protocol, + impl_->sa_local_, impl_->sa_remote_, data_len, + &impl_->data_buf_[0])); +} + +} +} +} diff --git a/src/lib/util/io/socketsession.h b/src/lib/util/io/socketsession.h new file mode 100644 index 0000000000..3d8de740a5 --- /dev/null +++ b/src/lib/util/io/socketsession.h @@ -0,0 +1,97 @@ +// Copyright (C) 2011 Internet Systems Consortium, Inc. ("ISC") +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +#ifndef __SOCKETSESSION_H_ +#define __SOCKETSESSION_H_ 1 + +#include + +#include + +#include + +namespace isc { +namespace util { +namespace io { + +class SocketSessionError: public Exception { +public: + SocketSessionError(const char *file, size_t line, const char *what): + isc::Exception(file, line, what) {} +}; + +class SocketSessionForwarder : boost::noncopyable { +public: + explicit SocketSessionForwarder(const std::string& unix_file); + ~SocketSessionForwarder(); + + void connectToReceptor(); + + void close(); + + void push(int sock, int family, int sock_type, int protocol, + const struct sockaddr& local_end, + const struct sockaddr& remote_end, + const void* data, size_t data_len); + +private: + struct ForwarderImpl; + ForwarderImpl* impl_; +}; + +class SocketSession { +public: + SocketSession(int sock, int family, int type, int protocol, + const sockaddr* local_end, const sockaddr* remote_end, + size_t data_len, const void* data); + int getSocket() const { return (sock_); } + int getFamily() const { return (family_); } + int getType() const { return (type_); } + int getProtocol() const { return (protocol_); } + const sockaddr& getLocalEndpoint() const { return (*local_end_); } + const sockaddr& getRemoteEndpoint() const { return (*remote_end_); } + const void* getData() const { return (data_); } + size_t getDataLength() const { return (data_len_); } + +private: + const int sock_; + const int family_; + const int type_; + const int protocol_; + const sockaddr* local_end_; + const sockaddr* remote_end_; + const size_t data_len_; + const void* const data_; +}; + +class SocketSessionReceptor : boost::noncopyable { +public: + explicit SocketSessionReceptor(int fd); + ~SocketSessionReceptor(); + SocketSession pop(); + +private: + struct ReceptorImpl; + ReceptorImpl* impl_; +}; + +} +} +} + +#endif // __SOCKETSESSION_H_ + +// Local Variables: +// mode: c++ +// End: diff --git a/src/lib/util/tests/Makefile.am b/src/lib/util/tests/Makefile.am index 47243f8273..98d90d013a 100644 --- a/src/lib/util/tests/Makefile.am +++ b/src/lib/util/tests/Makefile.am @@ -2,6 +2,7 @@ SUBDIRS = . AM_CPPFLAGS = -I$(top_builddir)/src/lib -I$(top_srcdir)/src/lib AM_CPPFLAGS += $(BOOST_INCLUDES) +AM_CPPFLAGS += -DTEST_DATA_BUILDDIR=\"$(abs_builddir)\" AM_CXXFLAGS = $(B10_CXXFLAGS) if USE_STATIC_LINK @@ -26,6 +27,7 @@ run_unittests_SOURCES += lru_list_unittest.cc run_unittests_SOURCES += qid_gen_unittest.cc run_unittests_SOURCES += random_number_generator_unittest.cc run_unittests_SOURCES += sha1_unittest.cc +run_unittests_SOURCES += socketsession_unittest.cc run_unittests_SOURCES += strutil_unittest.cc run_unittests_SOURCES += time_utilities_unittest.cc diff --git a/src/lib/util/tests/socketsession_unittest.cc b/src/lib/util/tests/socketsession_unittest.cc new file mode 100644 index 0000000000..851e274696 --- /dev/null +++ b/src/lib/util/tests/socketsession_unittest.cc @@ -0,0 +1,434 @@ +// Copyright (C) 2011 Internet Systems Consortium, Inc. ("ISC") +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +// PERFORMANCE OF THIS SOFTWARE. + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +#include + +#include + +#include + +#include +#include + +using namespace std; +using namespace isc::util::io; +using namespace isc::util::io::internal; + +namespace { + +const char* const TEST_UNIX_FILE = TEST_DATA_BUILDDIR "/test.unix"; +const char* const TEST_PORT = "53535"; +const char TEST_DATA[] = "BIND10 test"; + +// A simple helper structure to automatically close test sockets on return +// or exception in a RAII manner. non copyable to prevent duplicate close. +struct ScopedSocket : boost::noncopyable { + ScopedSocket() : fd(-1) {} + ScopedSocket(int sock) : fd(sock) {} + ~ScopedSocket() { + closeSocket(); + } + void reset(int sock) { + closeSocket(); + fd = sock; + } + int fd; +private: + void closeSocket() { + if (fd >= 0) { + close(fd); + } + } +}; + +// A helper function that makes a test socket non block so that a certain +// kind of test failure (such as missing send) won't cause hangup. +void +setNonBlock(int s, bool on) { + int fcntl_flags = fcntl(s, F_GETFL, 0); + if (on) { + fcntl_flags |= O_NONBLOCK; + } else { + fcntl_flags &= ~O_NONBLOCK; + } + if (fcntl(s, F_SETFL, fcntl_flags) == -1) { + isc_throw(isc::Unexpected, "fcntl(O_NONBLOCK) failed: " << + strerror(errno)); + } +} + +// A helper to impose some reasonable amount of wait on recv(from) +// if possible. It returns an option flag to be set for the system call +// (when necessary). +int +setRecvDelay(int s) { + const struct timeval timeo = { 10, 0 }; + if (setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &timeo, sizeof(timeo)) == -1) { + if (errno == ENOPROTOOPT) { + // Workaround for Solaris: see recursive_query_unittest + return (MSG_DONTWAIT); + } else { + isc_throw(isc::Unexpected, "set RCVTIMEO failed: " << + strerror(errno)); + } + } + return (0); +} + +class ForwarderTest : public ::testing::Test { +protected: + ForwarderTest() : listen_fd_(-1), forwarder_(TEST_UNIX_FILE), + large_text_(65535, 'a'), + test_un_len_(2 + strlen(TEST_UNIX_FILE)) + { + test_un_.sun_family = AF_UNIX; + strncpy(test_un_.sun_path, TEST_UNIX_FILE, sizeof(test_un_.sun_path)); +#ifdef HAVE_SA_LEN + test_un_.sun_len = test_un_len_; +#endif + } + + ~ForwarderTest() { + if (listen_fd_ != -1) { + close(listen_fd_); + } + unlink(TEST_UNIX_FILE); + + vector::const_iterator it; + for (it = addrinfo_list_.begin(); it != addrinfo_list_.end(); ++it) { + freeaddrinfo(*it); + } + } + + // Start an internal "socket session server". + void startListen() { + if (listen_fd_ != -1) { + isc_throw(isc::Unexpected, "duplicate call to startListen()"); + } + listen_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (listen_fd_ == -1) { + isc_throw(isc::Unexpected, "failed to create UNIX domain socket" << + strerror(errno)); + } + if (bind(listen_fd_, convertSockAddr(&test_un_), test_un_len_) == -1) { + isc_throw(isc::Unexpected, "failed to bind UNIX domain socket" << + strerror(errno)); + } + // 10 is an arbitrary choice, should be sufficient for a single test + if (listen(listen_fd_, 10) == -1) { + isc_throw(isc::Unexpected, "failed to listen on UNIX domain socket" + << strerror(errno)); + } + } + + // Accept a new connection from a SocketSessionForwarder and return + // the socket FD of the new connection. This assumes startListen() + // has been called. + int acceptForwarder() { + setNonBlock(listen_fd_, true); // prevent the test from hanging up + struct sockaddr_un from; + socklen_t from_len = sizeof(from); + const int s = accept(listen_fd_, convertSockAddr(&from), &from_len); + if (s == -1) { + isc_throw(isc::Unexpected, "accept failed: " << strerror(errno)); + } + return (s); + } + + typedef pair SockAddrInfo; + SockAddrInfo getSockAddr(const string& addr_str, const string& port_str) { + struct addrinfo hints, *res; + memset(&hints, 0, sizeof(hints)); + hints.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV; + EXPECT_EQ(0, getaddrinfo(addr_str.c_str(), port_str.c_str(), NULL, + &res)); + addrinfo_list_.push_back(res); + return (SockAddrInfo(*res->ai_addr, res->ai_addrlen)); + } + + // A helper method that creates a specified type of socket that is + // supposed to be passed via a SocketSessionForwarder. It will bound + // to the specified address and port in sainfo. If do_listen is true + // and it's a TCP socket, it will also start listening to new connection + // requests. + int createSocket(int family, int type, int protocol, + const SockAddrInfo& sainfo, bool do_listen) + { + int s = socket(family, type, protocol); + if (s < 0) { + isc_throw(isc::Unexpected, "socket(2) failed: " << + strerror(errno)); + } + const int on = 1; + if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) { + isc_throw(isc::Unexpected, "setsockopt(SO_REUSEADDR) failed: " << + strerror(errno)); + } + if (bind(s, &sainfo.first, sainfo.second) < 0) { + close(s); + isc_throw(isc::Unexpected, "bind(2) failed: " << + strerror(errno)); + } + if (do_listen && protocol == IPPROTO_TCP) { + if (listen(s, 1) == -1) { + isc_throw(isc::Unexpected, "listen(2) failed: " << + strerror(errno)); + } + } + return (s); + } + + void checkPushAndPop(int family, int type, int protocoal, + const SockAddrInfo& local, + const SockAddrInfo& remote, const char* const data, + size_t data_len, bool new_connection); + +protected: + int listen_fd_; + SocketSessionForwarder forwarder_; + ScopedSocket accept_sock_; + const string large_text_; + +private: + struct sockaddr_un test_un_; + const socklen_t test_un_len_; + vector addrinfo_list_; +}; + +TEST_F(ForwarderTest, construct) { + // On construction the existence of the file doesn't matter. + SocketSessionForwarder("some_file"); + + // But too long a path should be rejected + struct sockaddr_un s; // can't be const; some compiler complains + EXPECT_THROW(SocketSessionForwarder(string(sizeof(s.sun_path), 'x')), + SocketSessionError); + // If it's one byte shorter it should be okay + SocketSessionForwarder(string(sizeof(s.sun_path) - 1, 'x')); +} + +TEST_F(ForwarderTest, connect) { + // File doesn't exist (we assume the file "no_such_file" doesn't exist) + SocketSessionForwarder forwarder("no_such_file"); + EXPECT_THROW(forwarder.connectToReceptor(), SocketSessionError); + // The socket should be closed internally, so close() should result in + // error. + EXPECT_THROW(forwarder.close(), SocketSessionError); + + // Set up the receptor and connect. It should succeed. + SocketSessionForwarder forwarder2(TEST_UNIX_FILE); + startListen(); + forwarder2.connectToReceptor(); + // And it can be closed successfully. + forwarder2.close(); + // Duplicate close should fail + EXPECT_THROW(forwarder2.close(), SocketSessionError); + // Once closed, reconnect is okay. + forwarder2.connectToReceptor(); + forwarder2.close(); + + // Duplicate connect should be rejected + forwarder2.connectToReceptor(); + EXPECT_THROW(forwarder2.connectToReceptor(), SocketSessionError); + + // Connect then destroy. Should be internally closed, but unfortunately + // it's not easy to test it directly. We only check no disruption happens. + SocketSessionForwarder* forwarderp = + new SocketSessionForwarder(TEST_UNIX_FILE); + forwarderp->connectToReceptor(); + delete forwarderp; +} + +TEST_F(ForwarderTest, close) { + // can't close before connect + EXPECT_THROW(SocketSessionForwarder(TEST_UNIX_FILE).close(), + SocketSessionError); +} + +void +checkSockAddrs(const sockaddr& expected, const sockaddr& actual) { + char hbuf_expected[NI_MAXHOST], sbuf_expected[NI_MAXSERV], + hbuf_actual[NI_MAXHOST], sbuf_actual[NI_MAXSERV]; + EXPECT_EQ(0, getnameinfo(&expected, getSALength(expected), + hbuf_expected, sizeof(hbuf_expected), + sbuf_expected, sizeof(sbuf_expected), + NI_NUMERICHOST | NI_NUMERICSERV)); + EXPECT_EQ(0, getnameinfo(&actual, getSALength(actual), + hbuf_actual, sizeof(hbuf_actual), + sbuf_actual, sizeof(sbuf_actual), + NI_NUMERICHOST | NI_NUMERICSERV)); + EXPECT_EQ(string(hbuf_expected), string(hbuf_actual)); + EXPECT_EQ(string(sbuf_expected), string(sbuf_actual)); +} + +void +ForwarderTest::checkPushAndPop(int family, int type, int protocol, + const SockAddrInfo& local, + const SockAddrInfo& remote, + const char* const data, + size_t data_len, bool new_connection) +{ + // Create an original socket to be passed + const ScopedSocket sock(createSocket(family, type, protocol, local, true)); + int fwd_fd = sock.fd; // default FD to be forwarded + ScopedSocket client_sock; // for TCP test we need a separate "client".. + ScopedSocket server_sock; // ..and a separate socket for the connection + if (protocol == IPPROTO_TCP) { + // Use unspecified port for the "client" to avoid bind(2) failure + const SockAddrInfo client_addr = getSockAddr(family == AF_INET6 ? + "::1" : "127.0.0.1", "0"); + client_sock.reset(createSocket(family, type, protocol, client_addr, + false)); + setNonBlock(client_sock.fd, true); + // This connect would "fail" due to EINPROGRESS. Ignore it for now. + connect(client_sock.fd, &local.first, local.second); + sockaddr_storage ss; + socklen_t salen = sizeof(ss); + server_sock.reset(accept(sock.fd, convertSockAddr(&ss), &salen)); + if (server_sock.fd == -1) { + isc_throw(isc::Unexpected, "internal accept failed: " << + strerror(errno)); + } + fwd_fd = server_sock.fd; + } + + // If a new connection is required, start the "server", have the + // internal forwarder connect to it, and then internally accept it. + if (new_connection) { + startListen(); + forwarder_.connectToReceptor(); + accept_sock_.reset(acceptForwarder()); + setNonBlock(accept_sock_.fd, true); + } + + // Then push one socket session via the forwarder. + forwarder_.push(fwd_fd, family, type, protocol, local.first, remote.first, + data, data_len); + + // Pop the socket session we just pushed from a local receptor, and + // check the content + SocketSessionReceptor receptor(accept_sock_.fd); + SocketSession sock_session = receptor.pop(); + const ScopedSocket passed_sock(sock_session.getSocket()); + EXPECT_LE(0, passed_sock.fd); + // The passed FD should be different from the original FD + EXPECT_NE(fwd_fd, passed_sock.fd); + EXPECT_EQ(family, sock_session.getFamily()); + EXPECT_EQ(type, sock_session.getType()); + EXPECT_EQ(protocol, sock_session.getProtocol()); + checkSockAddrs(local.first, sock_session.getLocalEndpoint()); + checkSockAddrs(remote.first, sock_session.getRemoteEndpoint()); + ASSERT_EQ(data_len, sock_session.getDataLength()); + EXPECT_EQ(0, memcmp(data, sock_session.getData(), data_len)); + + // Check if the passed FD is usable by sending some data from it + setNonBlock(passed_sock.fd, false); + if (protocol == IPPROTO_UDP) { + EXPECT_EQ(sizeof(TEST_DATA), + sendto(passed_sock.fd, TEST_DATA, sizeof(TEST_DATA), 0, + convertSockAddr(&local.first), local.second)); + } else { + server_sock.reset(-1); + EXPECT_EQ(sizeof(TEST_DATA), + send(passed_sock.fd, TEST_DATA, sizeof(TEST_DATA), 0)); + } + char recvbuf[sizeof(TEST_DATA)]; + sockaddr_storage ss; + socklen_t sa_len = sizeof(ss); + if (protocol == IPPROTO_UDP) { + EXPECT_EQ(sizeof(recvbuf), + recvfrom(fwd_fd, recvbuf, sizeof(recvbuf), + setRecvDelay(fwd_fd), convertSockAddr(&ss), + &sa_len)); + } else { + setNonBlock(client_sock.fd, false); + EXPECT_EQ(sizeof(recvbuf), + recv(client_sock.fd, recvbuf, sizeof(recvbuf), + setRecvDelay(client_sock.fd))); + } + EXPECT_EQ(string(TEST_DATA), string(recvbuf)); +} + +TEST_F(ForwarderTest, pushAndPopUDP) { + // Pass a UDP/IPv6 session. + const SockAddrInfo sai_local6(getSockAddr("::1", TEST_PORT)); + const SockAddrInfo sai_remote6(getSockAddr("2001:db8::1", "5300")); + { + SCOPED_TRACE("Passing UDP/IPv6 session"); + checkPushAndPop(AF_INET6, SOCK_DGRAM, IPPROTO_UDP, sai_local6, + sai_remote6, TEST_DATA, sizeof(TEST_DATA), true); + } + { + SCOPED_TRACE("Passing TCP/IPv6 session"); + checkPushAndPop(AF_INET6, SOCK_STREAM, IPPROTO_TCP, sai_local6, + sai_remote6, TEST_DATA, sizeof(TEST_DATA), false); + } + + // Pass a UDP/IPv4 session. This reuses the same pair of forwarder and + // acceptor, which should be usable for multiple attempts of passing, + // regardless of family of the passed session + const SockAddrInfo sai_local4(getSockAddr("127.0.0.1", TEST_PORT)); + const SockAddrInfo sai_remote4(getSockAddr("192.0.2.2", "5300")); + { + SCOPED_TRACE("Passing UDP/IPv4 session"); + checkPushAndPop(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai_local4, + sai_remote4, TEST_DATA, sizeof(TEST_DATA), false); + } + { + SCOPED_TRACE("Passing TCP/IPv4 session"); + checkPushAndPop(AF_INET, SOCK_STREAM, IPPROTO_TCP, sai_local4, + sai_remote4, TEST_DATA, sizeof(TEST_DATA), false); + } + + // Also try large data + { + SCOPED_TRACE("Passing UDP/IPv6 session with large data"); + checkPushAndPop(AF_INET6, SOCK_DGRAM, IPPROTO_UDP, sai_local6, + sai_remote6, large_text_.c_str(), large_text_.length(), + false); + } + { + SCOPED_TRACE("Passing TCP/IPv6 session with large data"); + checkPushAndPop(AF_INET6, SOCK_STREAM, IPPROTO_TCP, sai_local6, + sai_remote6, large_text_.c_str(), large_text_.length(), + false); + } + { + SCOPED_TRACE("Passing UDP/IPv4 session with large data"); + checkPushAndPop(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai_local4, + sai_remote4, large_text_.c_str(), large_text_.length(), + false); + } + { + SCOPED_TRACE("Passing TCP/IPv4 session with large data"); + checkPushAndPop(AF_INET, SOCK_STREAM, IPPROTO_TCP, sai_local4, + sai_remote4, large_text_.c_str(), large_text_.length(), + false); + } +} + +}