diff --git a/src/bin/auth/auth_srv.cc b/src/bin/auth/auth_srv.cc index 2e1df09401..eaafaae873 100644 --- a/src/bin/auth/auth_srv.cc +++ b/src/bin/auth/auth_srv.cc @@ -78,6 +78,31 @@ using namespace isc::asiolink; using namespace isc::asiodns; using namespace isc::server_common::portconfig; +namespace { +// A helper class for cleaning up message renderer. +// +// A temporary object of this class is expected to be created before starting +// response message rendering. On construction, it (re)initialize the given +// message renderer with the given buffer. On destruction, it releases +// the previously set buffer and then release any internal resource in the +// renderer, no matter what happened during the rendering, especially even +// when it resulted in an exception. +class RendererHolder { +public: + RendererHolder(MessageRenderer& renderer, OutputBuffer* buffer) : + renderer_(renderer) + { + renderer.setBuffer(buffer); + } + ~RendererHolder() { + renderer_.setBuffer(NULL); + renderer_.clear(); + } +private: + MessageRenderer& renderer_; +}; +} + class AuthSrvImpl { private: // prohibit copy @@ -277,8 +302,8 @@ public: }; void -makeErrorMessage(Message& message, OutputBuffer& buffer, - const Rcode& rcode, +makeErrorMessage(MessageRenderer& renderer, Message& message, + OutputBuffer& buffer, const Rcode& rcode, std::auto_ptr tsig_context = std::auto_ptr()) { @@ -311,14 +336,12 @@ makeErrorMessage(Message& message, OutputBuffer& buffer, message.setRcode(rcode); - MessageRenderer renderer; - renderer.setBuffer(&buffer); + RendererHolder holder(renderer, &buffer); if (tsig_context.get() != NULL) { message.toWire(renderer, *tsig_context); } else { message.toWire(renderer); } - renderer.setBuffer(NULL); LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_ERROR_RESPONSE) .arg(renderer.getLength()).arg(message); } @@ -447,13 +470,13 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message, } catch (const DNSProtocolError& error) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_PACKET_PROTOCOL_ERROR) .arg(error.getRcode().toText()).arg(error.what()); - makeErrorMessage(message, buffer, error.getRcode()); + makeErrorMessage(impl_->renderer_, message, buffer, error.getRcode()); impl_->resumeServer(server, message, true); return; } catch (const Exception& ex) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_PACKET_PARSE_ERROR) .arg(ex.what()); - makeErrorMessage(message, buffer, Rcode::SERVFAIL()); + makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL()); impl_->resumeServer(server, message, true); return; } // other exceptions will be handled at a higher layer. @@ -480,7 +503,8 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message, } if (tsig_error != TSIGError::NOERROR()) { - makeErrorMessage(message, buffer, tsig_error.toRcode(), tsig_context); + makeErrorMessage(impl_->renderer_, message, buffer, + tsig_error.toRcode(), tsig_context); impl_->resumeServer(server, message, true); return; } @@ -497,9 +521,11 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message, } else if (message.getOpcode() != Opcode::QUERY()) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_UNSUPPORTED_OPCODE) .arg(message.getOpcode().toText()); - makeErrorMessage(message, buffer, Rcode::NOTIMP(), tsig_context); + makeErrorMessage(impl_->renderer_, message, buffer, + Rcode::NOTIMP(), tsig_context); } else if (message.getRRCount(Message::SECTION_QUESTION) != 1) { - makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context); + makeErrorMessage(impl_->renderer_, message, buffer, + Rcode::FORMERR(), tsig_context); } else { ConstQuestionPtr question = *message.beginQuestion(); const RRType &qtype = question->getType(); @@ -517,10 +543,10 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message, } catch (const std::exception& ex) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_RESPONSE_FAILURE) .arg(ex.what()); - makeErrorMessage(message, buffer, Rcode::SERVFAIL()); + makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL()); } catch (...) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_RESPONSE_FAILURE_UNKNOWN); - makeErrorMessage(message, buffer, Rcode::SERVFAIL()); + makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL()); } impl_->resumeServer(server, message, send_answer); } @@ -563,13 +589,11 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message, } } catch (const Exception& ex) { LOG_ERROR(auth_logger, AUTH_PROCESS_FAIL).arg(ex.what()); - makeErrorMessage(message, buffer, Rcode::SERVFAIL()); + makeErrorMessage(renderer_, message, buffer, Rcode::SERVFAIL()); return (true); } - renderer_.clear(); - renderer_.setBuffer(&buffer); - + RendererHolder holder(renderer_, &buffer); const bool udp_buffer = (io_message.getSocket().getProtocol() == IPPROTO_UDP); renderer_.setLengthLimit(udp_buffer ? remote_bufsize : 65535); @@ -578,7 +602,6 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message, } else { message.toWire(renderer_); } - renderer_.setBuffer(NULL); LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_NORMAL_RESPONSE) .arg(renderer_.getLength()).arg(message); return (true); @@ -594,7 +617,8 @@ AuthSrvImpl::processXfrQuery(const IOMessage& io_message, Message& message, if (io_message.getSocket().getProtocol() == IPPROTO_UDP) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_AXFR_UDP); - makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context); + makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(), + tsig_context); return (true); } @@ -619,7 +643,8 @@ AuthSrvImpl::processXfrQuery(const IOMessage& io_message, Message& message, LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_AXFR_ERROR) .arg(err.what()); - makeErrorMessage(message, buffer, Rcode::SERVFAIL(), tsig_context); + makeErrorMessage(renderer_, message, buffer, Rcode::SERVFAIL(), + tsig_context); return (true); } @@ -636,14 +661,16 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message, if (message.getRRCount(Message::SECTION_QUESTION) != 1) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_QUESTIONS) .arg(message.getRRCount(Message::SECTION_QUESTION)); - makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context); + makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(), + tsig_context); return (true); } ConstQuestionPtr question = *message.beginQuestion(); if (question->getType() != RRType::SOA()) { LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_RRTYPE) .arg(question->getType().toText()); - makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context); + makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(), + tsig_context); return (true); } @@ -698,14 +725,12 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message, message.setHeaderFlag(Message::HEADERFLAG_AA); message.setRcode(Rcode::NOERROR()); - renderer_.clear(); - renderer_.setBuffer(&buffer); + RendererHolder holder(renderer_, &buffer); if (tsig_context.get() != NULL) { message.toWire(renderer_, *tsig_context); } else { message.toWire(renderer_); } - renderer_.setBuffer(NULL); return (true); } diff --git a/src/bin/auth/tests/auth_srv_unittest.cc b/src/bin/auth/tests/auth_srv_unittest.cc index e13987a26d..70bc6e3cf7 100644 --- a/src/bin/auth/tests/auth_srv_unittest.cc +++ b/src/bin/auth/tests/auth_srv_unittest.cc @@ -1138,11 +1138,12 @@ checkThrow(ThrowWhen method, ThrowWhen throw_at, bool isc_exception) { class FakeZoneFinder : public isc::datasrc::ZoneFinder { public: FakeZoneFinder(isc::datasrc::ZoneFinderPtr zone_finder, - ThrowWhen throw_when, - bool isc_exception) : + ThrowWhen throw_when, bool isc_exception, + ConstRRsetPtr fake_rrset) : real_zone_finder_(zone_finder), throw_when_(throw_when), - isc_exception_(isc_exception) + isc_exception_(isc_exception), + fake_rrset_(fake_rrset) {} virtual isc::dns::Name @@ -1162,7 +1163,18 @@ public: const isc::dns::RRType& type, isc::datasrc::ZoneFinder::FindOptions options) { + using namespace isc::datasrc; checkThrow(THROW_AT_FIND, throw_when_, isc_exception_); + // If faked RRset was specified on construction and it matches the + // query, return it instead of searching the real data source. + if (fake_rrset_ && fake_rrset_->getName() == name && + fake_rrset_->getType() == type) + { + return (ZoneFinderContextPtr(new ZoneFinder::Context( + *this, options, + ResultContext(SUCCESS, + fake_rrset_)))); + } return (real_zone_finder_->find(name, type, options)); } @@ -1190,6 +1202,7 @@ private: isc::datasrc::ZoneFinderPtr real_zone_finder_; ThrowWhen throw_when_; bool isc_exception_; + ConstRRsetPtr fake_rrset_; }; /// \brief Proxy InMemoryClient that can throw exceptions at specified times @@ -1206,12 +1219,15 @@ public: /// class or the related FakeZoneFinder) /// \param isc_exception if true, throw isc::Exception, otherwise, /// throw std::exception + /// \param fake_rrset If non NULL, it will be used as an answer to + /// find() for that name and type. FakeInMemoryClient(AuthSrv::InMemoryClientPtr real_client, - ThrowWhen throw_when, - bool isc_exception) : + ThrowWhen throw_when, bool isc_exception, + ConstRRsetPtr fake_rrset = ConstRRsetPtr()) : real_client_(real_client), throw_when_(throw_when), - isc_exception_(isc_exception) + isc_exception_(isc_exception), + fake_rrset_(fake_rrset) {} /// \brief proxy call for findZone @@ -1226,14 +1242,16 @@ public: const FindResult result = real_client_->findZone(name); return (FindResult(result.code, isc::datasrc::ZoneFinderPtr( new FakeZoneFinder(result.zone_finder, - throw_when_, - isc_exception_)))); + throw_when_, + isc_exception_, + fake_rrset_)))); } private: AuthSrv::InMemoryClientPtr real_client_; ThrowWhen throw_when_; bool isc_exception_; + ConstRRsetPtr fake_rrset_; }; } // end anonymous namespace for throwing proxy classes @@ -1248,9 +1266,7 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxy) { AuthSrv::InMemoryClientPtr fake_client( new FakeInMemoryClient(server.getInMemoryClient(rrclass), - THROW_NEVER, - false)); - + THROW_NEVER, false)); ASSERT_NE(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass)); server.setInMemoryClient(rrclass, fake_client); @@ -1267,9 +1283,11 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxy) { // to throw in the given method // If isc_exception is true, it will throw isc::Exception, otherwise // it will throw std::exception +// If non null rrset is given, it will be passed to the proxy so it can +// return some faked response. void setupThrow(AuthSrv* server, const char *config, ThrowWhen throw_when, - bool isc_exception) + bool isc_exception, ConstRRsetPtr rrset = ConstRRsetPtr()) { // Set real inmem client to proxy updateConfig(server, config, true); @@ -1279,8 +1297,7 @@ setupThrow(AuthSrv* server, const char *config, ThrowWhen throw_when, AuthSrv::InMemoryClientPtr fake_client( new FakeInMemoryClient( server->getInMemoryClient(isc::dns::RRClass::IN()), - throw_when, - isc_exception)); + throw_when, isc_exception, rrset)); ASSERT_NE(AuthSrv::InMemoryClientPtr(), server->getInMemoryClient(isc::dns::RRClass::IN())); @@ -1324,4 +1341,45 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxyGetClass) { opcode.getCode(), QR_FLAG | AA_FLAG, 1, 1, 2, 1); } +TEST_F(AuthSrvTest, queryWithThrowingInToWire) { + // Set up a faked data source. It will return an empty RRset for the + // query. + ConstRRsetPtr empty_rrset(new RRset(Name("foo.example"), + RRClass::IN(), RRType::TXT(), + RRTTL(0))); + setupThrow(&server, CONFIG_INMEMORY_EXAMPLE, THROW_NEVER, true, + empty_rrset); + + // Repeat the query processing two times. Due to the faked RRset, + // toWire() should throw, and it should result in SERVFAIL. + OutputBufferPtr orig_buffer; + for (int i = 0; i < 2; ++i) { + UnitTestUtil::createDNSSECRequestMessage(request_message, opcode, + default_qid, + Name("foo.example."), + RRClass::IN(), RRType::TXT()); + createRequestPacket(request_message, IPPROTO_UDP); + server.processMessage(*io_message, *parse_message, *response_obuffer, + &dnsserv); + headerCheck(*parse_message, default_qid, Rcode::SERVFAIL(), + opcode.getCode(), QR_FLAG, 1, 0, 0, 0); + + // Make a backup of the original buffer for latest tests and replace + // it with a new one + if (!orig_buffer) { + orig_buffer = response_obuffer; + response_obuffer.reset(new OutputBuffer(0)); + } + request_message.clear(Message::RENDER); + parse_message->clear(Message::PARSE); + } + + // Now check if the original buffer is intact + parse_message->clear(Message::PARSE); + InputBuffer ibuffer(orig_buffer->getData(), orig_buffer->getLength()); + parse_message->fromWire(ibuffer); + headerCheck(*parse_message, default_qid, Rcode::SERVFAIL(), + opcode.getCode(), QR_FLAG, 1, 0, 0, 0); +} + }