2
0
mirror of https://github.com/VinylDNS/vinyldns synced 2025-08-28 21:07:46 +00:00

Add backend provider (#980)

Introduces the concept of a `Backend` into VinylDNS.  This will allow support for any DNS backend in the future, including AwS Route53 for example.  This is consistent with other "provider" things for dynamic loading of classes (Notifier, Repository, Queue, etc.)

The initial implementation builds on what we have already, that is when creating a zone one can choose a `backendId` that is configured in the `application.conf`.  If no `backendId` is specified, we attempt to map like we do today, so the exact same functionality.

We expand that by allowing one to map a `backendId` to a different provider (like aws). 

After this PR:
1. If someone specifies a zone connection on a zone, it will work exactly like it does today, namely go through the `DnsBackend` to connect.
2. If someone specifies a `backendId` when setting up a zone, the naive mapping will take place to map that zone to the `Backend` implementation that is configured with that `backendId`.   For example, if you have configured a backend id `aws` that connects to Route53, and you specify `aws` when connecting the zone, it will connect to it in Route 53 **Note: we still do not support zone create, but that is much closer to reality with this PR, much much**
3. If someone specifies NEITHER, the `defaultBackendId` will be used, which could be on any one of the backend providers configured.

To start, there is a new `vinyldns.core.domain.backend` package that contains the main classes for the system.  In there you will find the following:

- `BackendProvider` - this is to be implemented by each provider.  Adds a means of pre-loading zones, and providing connections to zones. 
- `Backend` - provides connectivity to a particular backend instance.  For example, a particular DNS Authoritative server.  This is where the real work happens of interacting with whatever backend.  For example, `DnsConnection` implements this to send DDNS messages to the DNS system.  Consider this the "main" thing to implement, where the rubber meets the road, the meat and potatoes
- `BackendProviderLoader` - to be implemented by each provider, knows how to load it's single instance `BackendProvider`, as well as possibly pre-loading configured `Backends` or anything else it needs to do to get ready.  It provides a dynamic hook via the `def load` method that is called by the `BackendLoader` to load a specific `Backend`
- `BackendResolver` - the main, default, BackendResolver.  It holds all `BackendProvider` instances loaded via the `BackendLoader` and provides right now a naive lookup mechanism to find `Backend`s.  Really, this is more of a `Router` or `Resolver`, as in the future it could use more advanced techniques to finding connections than right now
- `BackendConfigs` - used by the `BackendRegistry` as the entrypoint into configuration for all backends
- `BackendProviderConfig` - a single backend provider configuration, specifies a `className` that should be the `BackendProviderLoader` implementation to be loaded, and a `settings` that is passed into the `BackendProvider` to load itself.  This is consistent with other providers.
- `BackendResponse` - uniform responses across all providers to the rest of the VinylDNS System

**Workflow**
During initialization of the system:

1. The `BackendResolver` loads the `BackendConfigs` from the application configuration.  This contains configuration for ALL backends
2. The `BackendResolver` utilizes the `BackendLoader` to dynamically load each backend individually.  If any backend cannot be loaded, it will fail.
3. The `BackendLoader` creates a new instance of each `className` for each `BackendConfig`, this points to the `BackendProviderLoader` implementation which takes care of loading the specific `BackendProvider` provided the configuration
4. The `BackendProviderLoader` does any initialization necessary to ensure it is ready.  In the case of `Route53`, it will pre-load and cache all hosted zones that are available for the AWS account that is configured.  For Route53, a single `Route53Backend` is setup right now.  For `DnsBackend`, a connection (server, port, tsig key) is setup for each DNS Authoritative system to integrate with.

During runtime of the system:

1. When anything is needed, the `BackendResolver` is consulted that will determine how to lookup the `Backend` that is needed.  This is done right now by naively scanning all `BackendProvider` instances it has to say "can anyone connect to this zone".  More intelligent discovery rules can be added in the future
2. Once a `Backend` is obtained, any operation can be performed:
    1. `ZoneConnectionValidator` uses `zoneExists` and `loadZone` to validate a zone is usable by VinylDNS
    2. `RecordSetChangeHandler` uses `resolve` and `applyChange` to apply changes to the DNS backend
    3. `ZoneSyncHandler` and `DnsZoneViewLoader` use `loadZone` in order to load records into VinylDNS

**What else is here**

- Provided an implementation of a backend provider for DNS via `Backend`
- Updated all of VinylDNS to use `Backends` instead of hard coded to DNS
- Provided an implementation of a backend provider for AWS Route 53 as an example to follow for other providers


**Example configuration**

```
vinyldns {
  backend {
    default-backend-id = "r53"

    backend-providers = [
      {
        class-name = "vinyldns.route53.backend.Route53BackendProviderLoader"
        settings = {
          backends = [
            {
              id = "test"
              access-key = "vinyldnsTest"
              secret-key = "notNeededForSnsLocal"
              service-endpoint = "http://127.0.0.1:19009"
              signing-region = "us-east-1"
            }
          ]
        }
      }
    ]
  }
}
```
This commit is contained in:
Paul Cleary 2020-09-30 09:17:32 -04:00 committed by GitHub
parent 20a3708c42
commit a988bcd9a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 2191 additions and 852 deletions

View File

@ -301,6 +301,21 @@ lazy val sqs = (project in file("modules/sqs"))
).dependsOn(core % "compile->compile;test->test") ).dependsOn(core % "compile->compile;test->test")
.settings(name := "sqs") .settings(name := "sqs")
lazy val r53 = (project in file("modules/r53"))
.enablePlugins(AutomateHeaderPlugin)
.configs(IntegrationTest)
.settings(sharedSettings)
.settings(headerSettings(IntegrationTest))
.settings(inConfig(IntegrationTest)(scalafmtConfigSettings))
.settings(corePublishSettings)
.settings(testSettings)
.settings(Defaults.itSettings)
.settings(libraryDependencies ++= r53Dependencies ++ commonTestDependencies.map(_ % "test, it"))
.settings(
organization := "io.vinyldns",
).dependsOn(core % "compile->compile;test->test")
.settings(name := "r53")
val preparePortal = TaskKey[Unit]("preparePortal", "Runs NPM to prepare portal for start") val preparePortal = TaskKey[Unit]("preparePortal", "Runs NPM to prepare portal for start")
val checkJsHeaders = TaskKey[Unit]("checkJsHeaders", "Runs script to check for APL 2.0 license headers") val checkJsHeaders = TaskKey[Unit]("checkJsHeaders", "Runs script to check for APL 2.0 license headers")
val createJsHeaders = TaskKey[Unit]("createJsHeaders", "Runs script to prepend APL 2.0 license headers to files") val createJsHeaders = TaskKey[Unit]("createJsHeaders", "Runs script to prepend APL 2.0 license headers to files")
@ -446,6 +461,7 @@ addCommandAlias("validate", "; root/clean; " +
"api/headerCheck api/test:headerCheck api/it:headerCheck " + "api/headerCheck api/test:headerCheck api/it:headerCheck " +
"dynamodb/headerCheck dynamodb/test:headerCheck dynamodb/it:headerCheck " + "dynamodb/headerCheck dynamodb/test:headerCheck dynamodb/it:headerCheck " +
"mysql/headerCheck mysql/test:headerCheck mysql/it:headerCheck " + "mysql/headerCheck mysql/test:headerCheck mysql/it:headerCheck " +
"r53/headerCheck r53/test:headerCheck r53/it:headerCheck " +
"sqs/headerCheck sqs/test:headerCheck sqs/it:headerCheck " + "sqs/headerCheck sqs/test:headerCheck sqs/it:headerCheck " +
"portal/headerCheck portal/test:headerCheck; " + "portal/headerCheck portal/test:headerCheck; " +
"portal/createJsHeaders;portal/checkJsHeaders;" + "portal/createJsHeaders;portal/checkJsHeaders;" +

View File

@ -31,8 +31,9 @@ services:
ports: ports:
- "19006:19006" - "19006:19006"
- "19007:19007" - "19007:19007"
- "19009:19009"
environment: environment:
- SERVICES=sns:19006,sqs:19007 - SERVICES=sns:19006,sqs:19007,route53:19009
- START_WEB=0 - START_WEB=0
mail: mail:

View File

@ -3894,7 +3894,7 @@ def test_create_batch_delete_record_for_invalid_record_data_fails(shared_zone_te
assert_failed_change_in_error_response(errors[0], input_name="delete-non-existent-record.ok.", record_data="1.1.1.1", change_type="DeleteRecordSet", assert_failed_change_in_error_response(errors[0], input_name="delete-non-existent-record.ok.", record_data="1.1.1.1", change_type="DeleteRecordSet",
error_messages=['Record "delete-non-existent-record.ok." Does Not Exist: cannot delete a record that does not exist.']) error_messages=['Record "delete-non-existent-record.ok." Does Not Exist: cannot delete a record that does not exist.'])
assert_failed_change_in_error_response(errors[1], input_name=a_delete_fqdn, record_data="4.5.6.7", change_type="DeleteRecordSet", assert_failed_change_in_error_response(errors[1], input_name=a_delete_fqdn, record_data="4.5.6.7", change_type="DeleteRecordSet",
error_messages=['Record data AData(4.5.6.7) does not exist for "' + a_delete_fqdn + '".']) error_messages=['Record data 4.5.6.7 does not exist for "' + a_delete_fqdn + '".'])
finally: finally:
clear_recordset_list(to_delete, client) clear_recordset_list(to_delete, client)

View File

@ -1,98 +0,0 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.api.domain.dns
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import org.xbill.DNS
import vinyldns.api.domain.dns.DnsProtocol.{DnsResponse, NoError}
import vinyldns.api.domain.record.RecordSetChangeGenerator
import vinyldns.core.domain.zone.{Zone, ZoneConnection, ZoneStatus}
import vinyldns.api.ResultHelpers
import vinyldns.core.TestRecordSetData.{aaaa, ds}
import vinyldns.core.domain.record.{RecordSet, RecordType}
class DnsConversionsIntegrationSpec extends AnyWordSpec with Matchers with ResultHelpers {
private val zoneName = "example.com."
private val testZone = Zone(
zoneName,
"test@test.com",
ZoneStatus.Active,
connection =
Some(ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "127.0.0.1:19001")),
transferConnection =
Some(ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "127.0.0.1:19001"))
)
"Interacting with the DNS backend" should {
"remove the tsig key value during an update" in {
val testRecord = aaaa.copy(zoneId = testZone.id)
val conn = DnsConnection(testZone.connection.get)
val result: DnsResponse =
rightResultOf(conn.addRecord(RecordSetChangeGenerator.forAdd(testRecord, testZone)).value)
result shouldBe a[NoError]
val resultingMessage = result.asInstanceOf[NoError].message
resultingMessage.getSectionArray(DNS.Section.ADDITIONAL) shouldBe empty
val resultingMessageString = resultingMessage.toString
resultingMessageString should not contain "TSIG"
val queryResult: List[RecordSet] =
rightResultOf(conn.resolve(testRecord.name, testZone.name, RecordType.AAAA).value)
val recordOut = queryResult.head
recordOut.records should contain theSameElementsAs testRecord.records
recordOut.name shouldBe testRecord.name
recordOut.ttl shouldBe testRecord.ttl
recordOut.typ shouldBe testRecord.typ
}
"Successfully add and remove DS record type" in {
val testRecord = ds.copy(zoneId = testZone.id)
val conn = DnsConnection(testZone.connection.get)
val result: DnsResponse =
rightResultOf(conn.addRecord(RecordSetChangeGenerator.forAdd(testRecord, testZone)).value)
result shouldBe a[NoError]
val queryResult: List[RecordSet] =
rightResultOf(conn.resolve(testRecord.name, testZone.name, RecordType.DS).value)
val recordOut = queryResult.head
recordOut.records should contain theSameElementsAs testRecord.records
recordOut.name shouldBe testRecord.name
recordOut.ttl shouldBe testRecord.ttl
recordOut.typ shouldBe testRecord.typ
// deleting the record just added
val deleteResult: DnsResponse =
rightResultOf(
conn.deleteRecord(RecordSetChangeGenerator.forAdd(testRecord, testZone)).value
)
deleteResult shouldBe a[NoError]
val deleteQuery: List[RecordSet] =
rightResultOf(conn.resolve(testRecord.name, testZone.name, RecordType.DS).value)
deleteQuery shouldBe List.empty
}
}
}

View File

@ -24,16 +24,15 @@ import org.scalatest.matchers.should.Matchers
import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.concurrent.PatienceConfiguration
import org.scalatestplus.mockito.MockitoSugar import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.time.{Seconds, Span} import org.scalatest.time.{Seconds, Span}
import vinyldns.api.Interfaces._
import vinyldns.api._ import vinyldns.api._
import vinyldns.api.domain.access.AccessValidations import vinyldns.api.domain.access.AccessValidations
import vinyldns.api.domain.dns.DnsConnection
import vinyldns.api.domain.zone._ import vinyldns.api.domain.zone._
import vinyldns.api.engine.TestMessageQueue import vinyldns.api.engine.TestMessageQueue
import vinyldns.core.TestMembershipData._ import vinyldns.core.TestMembershipData._
import vinyldns.core.TestZoneData.testConnection import vinyldns.core.TestZoneData.testConnection
import vinyldns.core.domain.{Fqdn, HighValueDomainError} import vinyldns.core.domain.{Fqdn, HighValueDomainError}
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.backend.{Backend, BackendResolver}
import vinyldns.core.domain.membership.{Group, GroupRepository, User, UserRepository} import vinyldns.core.domain.membership.{Group, GroupRepository, User, UserRepository}
import vinyldns.core.domain.record.RecordType._ import vinyldns.core.domain.record.RecordType._
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
@ -224,13 +223,8 @@ class RecordSetServiceIntegrationSpec
ownerGroupId = Some(sharedGroup.id) ownerGroupId = Some(sharedGroup.id)
) )
private val zoneConnection = private val mockBackendResolver = mock[BackendResolver]
ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1") private val mockBackend = mock[Backend]
private val configuredConnections =
ConfiguredDnsConnections(zoneConnection, zoneConnection, List())
private val mockDnsConnection = mock[DnsConnection]
def setup(): Unit = { def setup(): Unit = {
recordSetRepo = recordSetRepo =
@ -266,8 +260,7 @@ class RecordSetServiceIntegrationSpec
mock[UserRepository], mock[UserRepository],
TestMessageQueue, TestMessageQueue,
new AccessValidations(), new AccessValidations(),
(_, _) => mockDnsConnection, mockBackendResolver,
configuredConnections,
false false
) )
} }
@ -383,8 +376,8 @@ class RecordSetServiceIntegrationSpec
"fail to add relative record if apex record with same name already exists" in { "fail to add relative record if apex record with same name already exists" in {
val newRecord = apexTestRecordNameConflict.copy(name = "zone-test-name-conflicts") val newRecord = apexTestRecordNameConflict.copy(name = "zone-test-name-conflicts")
doReturn(IO(List(newRecord)).toResult) doReturn(IO(List(newRecord)))
.when(mockDnsConnection) .when(mockBackend)
.resolve( .resolve(
zoneTestNameConflicts.name, zoneTestNameConflicts.name,
zoneTestNameConflicts.name, zoneTestNameConflicts.name,
@ -405,8 +398,8 @@ class RecordSetServiceIntegrationSpec
"fail to add apex record if relative record with same name already exists" in { "fail to add apex record if relative record with same name already exists" in {
val newRecord = subTestRecordNameConflict.copy(name = "relative-name-conflict.") val newRecord = subTestRecordNameConflict.copy(name = "relative-name-conflict.")
doReturn(IO(List(newRecord)).toResult) doReturn(IO(List(newRecord)))
.when(mockDnsConnection) .when(mockBackend)
.resolve(newRecord.name, zoneTestNameConflicts.name, newRecord.typ) .resolve(newRecord.name, zoneTestNameConflicts.name, newRecord.typ)
val result = val result =

View File

@ -1,50 +0,0 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.api.domain.zone
import cats.scalatest.EitherMatchers
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import vinyldns.api.VinylDNSConfig
import vinyldns.core.domain.zone.ConfiguredDnsConnections
import vinyldns.core.health.HealthCheck.HealthCheckError
class ZoneConnectionValidatorIntegrationSpec extends AnyWordSpec with Matchers with EitherMatchers {
"ZoneConnectionValidatorIntegrationSpec" should {
"have a valid health check if we can connect to DNS backend" in {
val check = new ZoneConnectionValidator(VinylDNSConfig.configuredDnsConnections)
.healthCheck(10000)
.unsafeRunSync()
check should beRight(())
}
"respond with a failure if health check fails" in {
val connections = VinylDNSConfig.configuredDnsConnections
val badConn = connections.defaultZoneConnection.copy(primaryServer = "localhost:1234")
val toTest = ConfiguredDnsConnections(badConn, badConn, List())
val result =
new ZoneConnectionValidator(toTest)
.healthCheck(10000)
.unsafeRunSync()
result should beLeft(
HealthCheckError(
"vinyldns.api.domain.zone.ZoneConnectionValidator health " +
"check failed with msg='Connection refused (Connection refused)'"
)
)
}
}
}

View File

@ -16,8 +16,10 @@
package vinyldns.api.domain.zone package vinyldns.api.domain.zone
import cats.data.NonEmptyList
import cats.effect._ import cats.effect._
import org.joda.time.DateTime import org.joda.time.DateTime
import org.mockito.Mockito.doReturn
import org.scalatest._ import org.scalatest._
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
@ -33,6 +35,7 @@ import vinyldns.core.TestMembershipData.{okAuth, okUser}
import vinyldns.core.TestZoneData.okZone import vinyldns.core.TestZoneData.okZone
import vinyldns.core.domain.Fqdn import vinyldns.core.domain.Fqdn
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.backend.BackendResolver
import vinyldns.core.domain.membership.{GroupRepository, UserRepository} import vinyldns.core.domain.membership.{GroupRepository, UserRepository}
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.domain.zone._ import vinyldns.core.domain.zone._
@ -93,6 +96,8 @@ class ZoneServiceIntegrationSpec
private val changeSetNS = ChangeSet(RecordSetChangeGenerator.forAdd(testRecordNS, okZone)) private val changeSetNS = ChangeSet(RecordSetChangeGenerator.forAdd(testRecordNS, okZone))
private val changeSetA = ChangeSet(RecordSetChangeGenerator.forAdd(testRecordA, okZone)) private val changeSetA = ChangeSet(RecordSetChangeGenerator.forAdd(testRecordA, okZone))
private val mockBackendResolver = mock[BackendResolver]
def clearRecordSetRepo(): Unit = def clearRecordSetRepo(): Unit =
DB.localTx { s => DB.localTx { s =>
s.executeUpdate("DELETE FROM recordset") s.executeUpdate("DELETE FROM recordset")
@ -113,6 +118,8 @@ class ZoneServiceIntegrationSpec
waitForSuccess(recordSetRepo.apply(changeSetNS)) waitForSuccess(recordSetRepo.apply(changeSetNS))
waitForSuccess(recordSetRepo.apply(changeSetA)) waitForSuccess(recordSetRepo.apply(changeSetA))
doReturn(NonEmptyList.one("func-test-backend")).when(mockBackendResolver).ids
testZoneService = new ZoneService( testZoneService = new ZoneService(
zoneRepo, zoneRepo,
mock[GroupRepository], mock[GroupRepository],
@ -121,7 +128,8 @@ class ZoneServiceIntegrationSpec
mock[ZoneConnectionValidator], mock[ZoneConnectionValidator],
TestMessageQueue, TestMessageQueue,
new ZoneValidations(1000), new ZoneValidations(1000),
new AccessValidations() new AccessValidations(),
mockBackendResolver
) )
} }

View File

@ -16,46 +16,84 @@
package vinyldns.api.domain.zone package vinyldns.api.domain.zone
import cats.effect.{ContextShift, IO}
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.xbill.DNS.ZoneTransferException import org.xbill.DNS.ZoneTransferException
import vinyldns.api.VinylDNSConfig
import vinyldns.api.backend.dns.DnsBackend
import vinyldns.core.domain.backend.{BackendConfigs, BackendResolver}
import vinyldns.core.domain.zone.{Zone, ZoneConnection} import vinyldns.core.domain.zone.{Zone, ZoneConnection}
import scala.concurrent.ExecutionContext
class ZoneViewLoaderIntegrationSpec extends AnyWordSpec with Matchers { class ZoneViewLoaderIntegrationSpec extends AnyWordSpec with Matchers {
private implicit val ec: ExecutionContext = scala.concurrent.ExecutionContext.global
private implicit val cs: ContextShift[IO] = IO.contextShift(ec)
private val backendResolver =
BackendResolver
.apply(BackendConfigs.load(VinylDNSConfig.apiBackend).unsafeRunSync())
.unsafeRunSync()
"ZoneViewLoader" should { "ZoneViewLoader" should {
"return a ZoneView upon success" in { "return a ZoneView upon success" in {
DnsZoneViewLoader(Zone("vinyldns.", "test@test.com")) val zone = Zone("vinyldns.", "test@test.com")
DnsZoneViewLoader(zone, backendResolver.resolve(zone))
.load() .load()
.unsafeRunSync() shouldBe a[ZoneView] .unsafeRunSync() shouldBe a[ZoneView]
} }
"return a failure if the transfer connection is bad" in { "return a failure if the transfer connection is bad" in {
assertThrows[IllegalArgumentException]( assertThrows[IllegalArgumentException] {
DnsZoneViewLoader( val zone = Zone(
Zone("vinyldns.", "bad@transfer.connection") "vinyldns.",
.copy( "bad@transfer.connection",
connection = Some(
ZoneConnection(
"vinyldns.",
"vinyldns.",
"nzisn+4G2ldMn0q1CV3vsg==",
"127.0.0.1:19001"
)
),
transferConnection = transferConnection =
Some(ZoneConnection("invalid-connection.", "bad-key", "invalid-key", "10.1.1.1")) Some(ZoneConnection("invalid-connection.", "bad-key", "invalid-key", "10.1.1.1"))
) )
).load() val backend = backendResolver.resolve(zone).asInstanceOf[DnsBackend]
println(s"${backend.id}, ${backend.xfrInfo}, ${backend.resolver.getAddress}")
DnsZoneViewLoader(zone, backendResolver.resolve(zone))
.load()
.unsafeRunSync() .unsafeRunSync()
) }
} }
"return a failure if the zone doesn't exist in the DNS backend" in { "return a failure if the zone doesn't exist in the DNS backend" in {
assertThrows[ZoneTransferException]( assertThrows[ZoneTransferException] {
DnsZoneViewLoader(Zone("non-existent-zone", "bad@zone.test")) val zone = Zone("non-existent-zone", "bad@zone.test")
DnsZoneViewLoader(zone, backendResolver.resolve(zone))
.load() .load()
.unsafeRunSync() .unsafeRunSync()
) }
} }
"return a failure if the zone is larger than the max zone size" in { "return a failure if the zone is larger than the max zone size" in {
assertThrows[ZoneTooLargeError]( assertThrows[ZoneTooLargeError] {
DnsZoneViewLoader(Zone("vinyldns.", "test@test.com"), DnsZoneViewLoader.dnsZoneTransfer, 1) val zone = Zone(
"vinyldns.",
"test@test.com",
connection = Some(
ZoneConnection(
"vinyldns.",
"vinyldns.",
"nzisn+4G2ldMn0q1CV3vsg==",
"127.0.0.1:19001"
)
)
)
DnsZoneViewLoader(zone, backendResolver.resolve(zone), 1)
.load() .load()
.unsafeRunSync() .unsafeRunSync()
) }
} }
} }
} }

View File

@ -5,6 +5,59 @@
################################################################################################################ ################################################################################################################
vinyldns { vinyldns {
# configured backend providers
backend {
# Use "default" when dns backend legacy = true
# otherwise, use the id of one of the connections in any of your backends
default-backend-id = "default"
# this is where we can save additional backends
backend-providers = [
{
class-name = "vinyldns.api.backend.dns.DnsBackendProviderLoader"
settings = {
legacy = true # set this to true to attempt to load legacy config YAML
backends = []
# if not legacy then this...
# legacy = false
# backends = [
# {
# id = "default"
# zone-connection = {
# name = "vinyldns."
# keyName = "vinyldns."
# key = "nzisn+4G2ldMn0q1CV3vsg=="
# primaryServer = "127.0.0.1:19001"
# }
# transfer-connection = {
# name = "vinyldns."
# keyName = "vinyldns."
# key = "nzisn+4G2ldMn0q1CV3vsg=="
# primaryServer = "127.0.0.1:19001"
# }
# },
# {
# id = "func-test-backend"
# zone-connection = {
# name = "vinyldns."
# keyName = "vinyldns."
# key = "nzisn+4G2ldMn0q1CV3vsg=="
# primaryServer = "127.0.0.1:19001"
# }
# transfer-connection = {
# name = "vinyldns."
# keyName = "vinyldns."
# key = "nzisn+4G2ldMn0q1CV3vsg=="
# primaryServer = "127.0.0.1:19001"
# }
# }
#]
}
}
]
}
# if we should start up polling for change requests, set this to false for the inactive cluster # if we should start up polling for change requests, set this to false for the inactive cluster
processing-disabled = false processing-disabled = false

View File

@ -30,7 +30,6 @@ import vinyldns.api.crypto.Crypto
import vinyldns.api.domain.access.AccessValidations import vinyldns.api.domain.access.AccessValidations
import vinyldns.api.domain.auth.MembershipAuthPrincipalProvider import vinyldns.api.domain.auth.MembershipAuthPrincipalProvider
import vinyldns.api.domain.batch.{BatchChangeConverter, BatchChangeService, BatchChangeValidations} import vinyldns.api.domain.batch.{BatchChangeConverter, BatchChangeService, BatchChangeValidations}
import vinyldns.api.domain.dns.DnsConnection
import vinyldns.api.domain.membership._ import vinyldns.api.domain.membership._
import vinyldns.api.domain.record.RecordSetService import vinyldns.api.domain.record.RecordSetService
import vinyldns.api.domain.zone._ import vinyldns.api.domain.zone._
@ -38,6 +37,7 @@ import vinyldns.api.metrics.APIMetrics
import vinyldns.api.repository.{ApiDataAccessor, ApiDataAccessorProvider, TestDataLoader} import vinyldns.api.repository.{ApiDataAccessor, ApiDataAccessorProvider, TestDataLoader}
import vinyldns.api.route.VinylDNSService import vinyldns.api.route.VinylDNSService
import vinyldns.core.VinylDNSMetrics import vinyldns.core.VinylDNSMetrics
import vinyldns.core.domain.backend.{BackendConfigs, BackendResolver}
import vinyldns.core.health.HealthService import vinyldns.core.health.HealthService
import vinyldns.core.queue.{MessageCount, MessageQueueLoader} import vinyldns.core.queue.{MessageCount, MessageQueueLoader}
import vinyldns.core.repository.DataStoreLoader import vinyldns.core.repository.DataStoreLoader
@ -74,7 +74,8 @@ object Boot extends App {
loaderResponse <- DataStoreLoader loaderResponse <- DataStoreLoader
.loadAll[ApiDataAccessor](repoConfigs, crypto, ApiDataAccessorProvider) .loadAll[ApiDataAccessor](repoConfigs, crypto, ApiDataAccessorProvider)
repositories = loaderResponse.accessor repositories = loaderResponse.accessor
connections = VinylDNSConfig.configuredDnsConnections backendConfigs <- BackendConfigs.load(VinylDNSConfig.apiBackend)
backendResolver <- BackendResolver.apply(backendConfigs)
_ <- TestDataLoader _ <- TestDataLoader
.loadTestData( .loadTestData(
repositories.userRepository, repositories.userRepository,
@ -109,7 +110,7 @@ object Boot extends App {
repositories.recordChangeRepository, repositories.recordChangeRepository,
repositories.batchChangeRepository, repositories.batchChangeRepository,
notifiers, notifiers,
connections backendResolver
) )
.start .start
} yield { } yield {
@ -123,15 +124,13 @@ object Boot extends App {
) )
val membershipService = MembershipService(repositories) val membershipService = MembershipService(repositories)
val connectionValidator = val connectionValidator =
new ZoneConnectionValidator(connections) new ZoneConnectionValidator(backendResolver)
val recordSetService = val recordSetService =
RecordSetService( RecordSetService(
repositories, repositories,
messageQueue, messageQueue,
recordAccessValidations, recordAccessValidations,
(zone, connections) => backendResolver,
DnsConnection(ZoneConnectionValidator.getZoneConnection(zone, connections)),
connections,
VinylDNSConfig.validateRecordLookupAgainstDnsBackend VinylDNSConfig.validateRecordLookupAgainstDnsBackend
) )
val zoneService = ZoneService( val zoneService = ZoneService(
@ -139,10 +138,11 @@ object Boot extends App {
connectionValidator, connectionValidator,
messageQueue, messageQueue,
zoneValidations, zoneValidations,
recordAccessValidations recordAccessValidations,
backendResolver
) )
val healthService = new HealthService( val healthService = new HealthService(
messageQueue.healthCheck :: connectionValidator.healthCheck(healthCheckTimeout) :: messageQueue.healthCheck :: backendResolver.healthCheck(healthCheckTimeout) ::
loaderResponse.healthChecks loaderResponse.healthChecks
) )
val batchChangeConverter = val batchChangeConverter =

View File

@ -35,7 +35,7 @@ import vinyldns.core.domain.record.RecordType
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.util.matching.Regex import scala.util.matching.Regex
import vinyldns.core.domain.zone.{ConfiguredDnsConnections, DnsBackend, ZoneConnection} import vinyldns.core.domain.zone.{ConfiguredDnsConnections, LegacyDnsBackend, ZoneConnection}
import vinyldns.core.queue.MessageQueueConfig import vinyldns.core.queue.MessageQueueConfig
import vinyldns.core.repository.DataStoreConfig import vinyldns.core.repository.DataStoreConfig
import vinyldns.core.notifier.NotifierConfig import vinyldns.core.notifier.NotifierConfig
@ -47,6 +47,9 @@ object VinylDNSConfig {
lazy val config: Config = ConfigFactory.load() lazy val config: Config = ConfigFactory.load()
lazy val vinyldnsConfig: Config = config.getConfig("vinyldns") lazy val vinyldnsConfig: Config = config.getConfig("vinyldns")
lazy val apiBackend: Config =
vinyldnsConfig.getConfig("backend")
lazy val dataStoreConfigs: IO[List[DataStoreConfig]] = lazy val dataStoreConfigs: IO[List[DataStoreConfig]] =
vinyldnsConfig vinyldnsConfig
.getStringList("data-stores") .getStringList("data-stores")
@ -119,7 +122,7 @@ object VinylDNSConfig {
.getConfigList("backends") .getConfigList("backends")
.asScala .asScala
.map { .map {
ConfigSource.fromConfig(_).loadOrThrow[DnsBackend] ConfigSource.fromConfig(_).loadOrThrow[LegacyDnsBackend]
} }
.toList .toList
.map(_.encrypted(Crypto.instance)) .map(_.encrypted(Crypto.instance))

View File

@ -20,14 +20,13 @@ import cats.effect.{ContextShift, IO, Timer}
import fs2._ import fs2._
import fs2.concurrent.SignallingRef import fs2.concurrent.SignallingRef
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import vinyldns.api.domain.dns.DnsConnection
import vinyldns.api.domain.zone.ZoneConnectionValidator
import vinyldns.api.engine.{ import vinyldns.api.engine.{
BatchChangeHandler, BatchChangeHandler,
RecordSetChangeHandler, RecordSetChangeHandler,
ZoneChangeHandler, ZoneChangeHandler,
ZoneSyncHandler ZoneSyncHandler
} }
import vinyldns.core.domain.backend.{Backend, BackendResolver}
import vinyldns.core.domain.batch.{BatchChange, BatchChangeCommand, BatchChangeRepository} import vinyldns.core.domain.batch.{BatchChange, BatchChangeCommand, BatchChangeRepository}
import vinyldns.core.domain.record.{RecordChangeRepository, RecordSetChange, RecordSetRepository} import vinyldns.core.domain.record.{RecordChangeRepository, RecordSetChange, RecordSetRepository}
import vinyldns.core.domain.zone._ import vinyldns.core.domain.zone._
@ -51,14 +50,14 @@ object CommandHandler {
def mainFlow( def mainFlow(
zoneChangeHandler: ZoneChange => IO[ZoneChange], zoneChangeHandler: ZoneChange => IO[ZoneChange],
recordChangeHandler: (DnsConnection, RecordSetChange) => IO[RecordSetChange], recordChangeHandler: (Backend, RecordSetChange) => IO[RecordSetChange],
zoneSyncHandler: ZoneChange => IO[ZoneChange], zoneSyncHandler: ZoneChange => IO[ZoneChange],
batchChangeHandler: BatchChangeCommand => IO[Option[BatchChange]], batchChangeHandler: BatchChangeCommand => IO[Option[BatchChange]],
mq: MessageQueue, mq: MessageQueue,
count: MessageCount, count: MessageCount,
pollingInterval: FiniteDuration, pollingInterval: FiniteDuration,
pauseSignal: SignallingRef[IO, Boolean], pauseSignal: SignallingRef[IO, Boolean],
connections: ConfiguredDnsConnections, backendResolver: BackendResolver,
maxOpen: Int = 4 maxOpen: Int = 4
)(implicit timer: Timer[IO]): Stream[IO, Unit] = { )(implicit timer: Timer[IO]): Stream[IO, Unit] = {
@ -74,7 +73,7 @@ object CommandHandler {
recordChangeHandler, recordChangeHandler,
zoneSyncHandler, zoneSyncHandler,
batchChangeHandler, batchChangeHandler,
connections backendResolver
) )
// Delete messages from message queue when complete // Delete messages from message queue when complete
@ -146,10 +145,10 @@ object CommandHandler {
/* Actually processes a change request */ /* Actually processes a change request */
def processChangeRequests( def processChangeRequests(
zoneChangeProcessor: ZoneChange => IO[ZoneChange], zoneChangeProcessor: ZoneChange => IO[ZoneChange],
recordChangeProcessor: (DnsConnection, RecordSetChange) => IO[RecordSetChange], recordChangeProcessor: (Backend, RecordSetChange) => IO[RecordSetChange],
zoneSyncProcessor: ZoneChange => IO[ZoneChange], zoneSyncProcessor: ZoneChange => IO[ZoneChange],
batchChangeProcessor: BatchChangeCommand => IO[Option[BatchChange]], batchChangeProcessor: BatchChangeCommand => IO[Option[BatchChange]],
connections: ConfiguredDnsConnections backendResolver: BackendResolver
): Pipe[IO, CommandMessage, MessageOutcome] = ): Pipe[IO, CommandMessage, MessageOutcome] =
_.evalMap[IO, MessageOutcome] { message => _.evalMap[IO, MessageOutcome] { message =>
message.command match { message.command match {
@ -161,9 +160,7 @@ object CommandHandler {
outcomeOf(message)(zoneChangeProcessor(zoneChange)) outcomeOf(message)(zoneChangeProcessor(zoneChange))
case rcr: RecordSetChange => case rcr: RecordSetChange =>
val dnsConn = outcomeOf(message)(recordChangeProcessor(backendResolver.resolve(rcr.zone), rcr))
DnsConnection(ZoneConnectionValidator.getZoneConnection(rcr.zone, connections))
outcomeOf(message)(recordChangeProcessor(dnsConn, rcr))
case bcc: BatchChangeCommand => case bcc: BatchChangeCommand =>
outcomeOf(message)(batchChangeProcessor(bcc)) outcomeOf(message)(batchChangeProcessor(bcc))
@ -207,7 +204,7 @@ object CommandHandler {
recordChangeRepo: RecordChangeRepository, recordChangeRepo: RecordChangeRepository,
batchChangeRepo: BatchChangeRepository, batchChangeRepo: BatchChangeRepository,
notifiers: AllNotifiers, notifiers: AllNotifiers,
connections: ConfiguredDnsConnections backendResolver: BackendResolver
)(implicit timer: Timer[IO]): IO[Unit] = { )(implicit timer: Timer[IO]): IO[Unit] = {
// Handlers for each type of change request // Handlers for each type of change request
val zoneChangeHandler = val zoneChangeHandler =
@ -215,7 +212,7 @@ object CommandHandler {
val recordChangeHandler = val recordChangeHandler =
RecordSetChangeHandler(recordSetRepo, recordChangeRepo, batchChangeRepo) RecordSetChangeHandler(recordSetRepo, recordChangeRepo, batchChangeRepo)
val zoneSyncHandler = val zoneSyncHandler =
ZoneSyncHandler(recordSetRepo, recordChangeRepo, zoneChangeRepo, zoneRepo) ZoneSyncHandler(recordSetRepo, recordChangeRepo, zoneChangeRepo, zoneRepo, backendResolver)
val batchChangeHandler = val batchChangeHandler =
BatchChangeHandler(batchChangeRepo, notifiers) BatchChangeHandler(batchChangeRepo, notifiers)
@ -229,7 +226,7 @@ object CommandHandler {
msgsPerPoll, msgsPerPoll,
pollingInterval, pollingInterval,
processingSignal, processingSignal,
connections backendResolver
) )
.compile .compile
.drain .drain

View File

@ -14,24 +14,28 @@
* limitations under the License. * limitations under the License.
*/ */
package vinyldns.api.domain.dns package vinyldns.api.backend.dns
import java.net.SocketAddress
import cats.effect._ import cats.effect._
import cats.syntax.all._ import cats.syntax.all._
import org.slf4j.{Logger, LoggerFactory} import org.slf4j.{Logger, LoggerFactory}
import org.xbill.DNS import org.xbill.DNS
import vinyldns.api.Interfaces.{result, _} import vinyldns.api.domain.zone.ZoneTooLargeError
import vinyldns.api.crypto.Crypto import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.backend.{Backend, BackendResponse}
import vinyldns.core.domain.record.RecordType.RecordType import vinyldns.core.domain.record.RecordType.RecordType
import vinyldns.core.domain.record.{RecordSet, RecordSetChange, RecordSetChangeType} import vinyldns.core.domain.record.{RecordSet, RecordSetChange, RecordSetChangeType, RecordType}
import vinyldns.core.domain.zone.{Zone, ZoneConnection} import vinyldns.core.domain.zone.{Zone, ZoneConnection}
import scala.collection.JavaConverters._
object DnsProtocol { object DnsProtocol {
sealed trait DnsRequest sealed trait DnsRequest
final case class Apply(change: RecordSetChange) extends DnsRequest final case class Apply(change: RecordSetChange) extends DnsRequest
// TODO: Remove origin once we change to using Zone Activation
case class Resolve(name: String, zone: Zone, typ: RecordType) case class Resolve(name: String, zone: Zone, typ: RecordType)
case class UpdateConnection(zoneConnection: ZoneConnection) case class UpdateConnection(zoneConnection: ZoneConnection)
@ -84,25 +88,70 @@ class DnsQuery(val lookup: DNS.Lookup, val zoneName: DNS.Name) {
def error: String = lookup.getErrorString def error: String = lookup.getErrorString
} }
class DnsConnection(val resolver: DNS.SimpleResolver) extends DnsConversions { final case class TransferInfo(address: SocketAddress, tsig: DNS.TSIG)
class DnsBackend(val id: String, val resolver: DNS.SimpleResolver, val xfrInfo: TransferInfo)
extends Backend
with DnsConversions {
import DnsProtocol._ import DnsProtocol._
val logger: Logger = LoggerFactory.getLogger(classOf[DnsConnection]) val logger: Logger = LoggerFactory.getLogger(classOf[DnsBackend])
def applyChange(change: RecordSetChange): Result[DnsResponse] = change.changeType match { def applyChange(change: RecordSetChange): IO[BackendResponse] = {
change.changeType match {
case RecordSetChangeType.Create => addRecord(change) case RecordSetChangeType.Create => addRecord(change)
case RecordSetChangeType.Update => updateRecord(change) case RecordSetChangeType.Update => updateRecord(change)
case RecordSetChangeType.Delete => deleteRecord(change) case RecordSetChangeType.Delete => deleteRecord(change)
} }
}.attempt.flatMap {
case Left(DnsProtocol.Refused(msg)) => IO(BackendResponse.Retry(msg))
case Right(DnsProtocol.NoError(msg)) => IO(BackendResponse.NoError(msg.toString))
case Left(otherFailure) => IO.raiseError(otherFailure)
}
def resolve(name: String, zoneName: String, typ: RecordType): Result[List[RecordSet]] = def resolve(name: String, zoneName: String, typ: RecordType): IO[List[RecordSet]] =
IO { IO.fromEither {
for { for {
query <- toQuery(name, zoneName, typ) query <- toQuery(name, zoneName, typ)
records <- runQuery(query) records <- runQuery(query)
} yield records } yield records
}.toResult }
def loadZone(zone: Zone, maxZoneSize: Int): IO[List[RecordSet]] = {
val dnsZoneName = zoneDnsName(zone.name)
val zti = DNS.ZoneTransferIn.newAXFR(dnsZoneName, xfrInfo.address, xfrInfo.tsig)
for {
zoneXfr <- IO {
zti.run()
zti.getAXFR.asScala.map(_.asInstanceOf[DNS.Record]).toList.distinct
}
rawDnsRecords = zoneXfr.filter(
record => fromDnsRecordType(record.getType) != RecordType.UNKNOWN
)
_ <- if (rawDnsRecords.length > maxZoneSize) {
IO.raiseError(
ZoneTooLargeError(
s"Zone too large ${zone.name}, ${rawDnsRecords.length} records exceeded max $maxZoneSize"
)
)
} else {
IO.pure(Unit)
}
dnsZoneName <- IO(zoneDnsName(zone.name))
recordSets <- IO(rawDnsRecords.map(toRecordSet(_, dnsZoneName, zone.id)))
} yield recordSets
}
/**
* Indicates if the zone is present in the backend
*
* @param zone The zone to check if exists
* @return true if it exists; false otherwise
*/
def zoneExists(zone: Zone): IO[Boolean] =
resolve(zone.name, zone.name, RecordType.SOA).map(_.nonEmpty)
private[dns] def toQuery( private[dns] def toQuery(
name: String, name: String,
@ -125,7 +174,7 @@ class DnsConnection(val resolver: DNS.SimpleResolver) extends DnsConversions {
case _ => Right(change) case _ => Right(change)
} }
private[dns] def addRecord(change: RecordSetChange): Result[DnsResponse] = result { private[dns] def addRecord(change: RecordSetChange): IO[DnsResponse] = IO.fromEither {
for { for {
change <- recordsArePresent(change) change <- recordsArePresent(change)
addRecord <- toDnsRRset(change.recordSet, change.zone.name) addRecord <- toDnsRRset(change.recordSet, change.zone.name)
@ -134,7 +183,7 @@ class DnsConnection(val resolver: DNS.SimpleResolver) extends DnsConversions {
} yield response } yield response
} }
private[dns] def updateRecord(change: RecordSetChange): Result[DnsResponse] = result { private[dns] def updateRecord(change: RecordSetChange): IO[DnsResponse] = IO.fromEither {
for { for {
change <- recordsArePresent(change) change <- recordsArePresent(change)
dnsRecord <- toDnsRRset(change.recordSet, change.zone.name) dnsRecord <- toDnsRRset(change.recordSet, change.zone.name)
@ -144,7 +193,7 @@ class DnsConnection(val resolver: DNS.SimpleResolver) extends DnsConversions {
} yield response } yield response
} }
private[dns] def deleteRecord(change: RecordSetChange): Result[DnsResponse] = result { private[dns] def deleteRecord(change: RecordSetChange): IO[DnsResponse] = IO.fromEither {
for { for {
change <- recordsArePresent(change) change <- recordsArePresent(change)
dnsRecord <- toDnsRRset(change.recordSet, change.zone.name) dnsRecord <- toDnsRRset(change.recordSet, change.zone.name)
@ -198,21 +247,41 @@ class DnsConnection(val resolver: DNS.SimpleResolver) extends DnsConversions {
} }
} }
object DnsConnection { object DnsBackend {
def apply(conn: ZoneConnection): DnsConnection = new DnsConnection(createResolver(conn)) def apply(
id: String,
conn: ZoneConnection,
xfrConn: Option[ZoneConnection],
crypto: CryptoAlgebra
): DnsBackend = {
val tsig = createTsig(conn, crypto)
val resolver = createResolver(conn, tsig)
def createResolver(conn: ZoneConnection): DNS.SimpleResolver = { val xfrInfo = xfrConn
// IMPORTANT! Make sure we decrypt the zone connection before creating the resolver .map { xc =>
val decryptedConnection = conn.decrypted(Crypto.instance) val xt = createTsig(xc, crypto)
val (host, port) = parseHostAndPort(decryptedConnection.primaryServer) val xr = createResolver(xc, xt)
TransferInfo(xr.getAddress, xt)
}
.getOrElse(TransferInfo(resolver.getAddress, tsig))
new DnsBackend(id, resolver, xfrInfo)
}
def createResolver(conn: ZoneConnection, tsig: DNS.TSIG): DNS.SimpleResolver = {
val (host, port) = parseHostAndPort(conn.primaryServer)
val resolver = new DNS.SimpleResolver(host) val resolver = new DNS.SimpleResolver(host)
resolver.setPort(port) resolver.setPort(port)
resolver.setTSIGKey(new DNS.TSIG(decryptedConnection.keyName, decryptedConnection.key)) resolver.setTSIGKey(tsig)
resolver resolver
} }
def createTsig(conn: ZoneConnection, crypto: CryptoAlgebra): DNS.TSIG = {
val decryptedConnection = conn.decrypted(crypto)
new DNS.TSIG(decryptedConnection.keyName, decryptedConnection.key)
}
def parseHostAndPort(primaryServer: String): (String, Int) = { def parseHostAndPort(primaryServer: String): (String, Int) = {
val parts = primaryServer.trim().split(':') val parts = primaryServer.trim().split(':')
if (parts.length < 2) if (parts.length < 2)

View File

@ -0,0 +1,29 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.api.backend.dns
import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.zone.ZoneConnection
final case class DnsBackendConfig(
id: String,
zoneConnection: ZoneConnection,
transferConnection: Option[ZoneConnection]
) {
def toDnsConnection(crypto: CryptoAlgebra): DnsBackend =
DnsBackend.apply(id, zoneConnection, transferConnection, crypto)
}

View File

@ -0,0 +1,59 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.api.backend.dns
import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.backend.{Backend, BackendProvider}
import vinyldns.core.domain.zone.Zone
class DnsBackendProvider(connections: List[DnsBackend], crypto: CryptoAlgebra)
extends BackendProvider {
private val connMap: Map[String, DnsBackend] =
connections.map { c =>
c.id -> c
}.toMap
/**
* Given a zone, returns a connection to the zone, returns None if cannot connect
*
* @param zone The zone to attempt to connect to
* @return A backend that is usable, or None if it could not connect
*/
def connect(zone: Zone): Option[Backend] =
// Use the connection info on the zone if present
zone.connection
.map { conn =>
DnsBackend.apply("unknown", conn, zone.transferConnection, crypto)
}
.orElse {
zone.backendId.flatMap(connectById)
}
/**
* Given a backend id, looks up the backend for this provider if it exists
*
* @return A backend that is usable, or None if could not connect
*/
def connectById(backendId: String): Option[Backend] =
connMap.get(backendId)
/**
* @return The backend ids loaded with this provider
*/
def ids: List[String] = connMap.keys.toList
}

View File

@ -0,0 +1,32 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.api.backend.dns
import cats.effect.{Blocker, ContextShift, IO}
import com.typesafe.config.Config
import pureconfig.ConfigSource
import pureconfig.generic.auto._
import pureconfig.module.catseffect.syntax.CatsEffectConfigSource
final case class DnsBackendProviderConfig(legacy: Boolean, backends: List[DnsBackendConfig])
object DnsBackendProviderConfig {
def load(config: Config)(implicit cs: ContextShift[IO]): IO[DnsBackendProviderConfig] =
Blocker[IO].use(
ConfigSource.fromConfig(config).loadF[IO, DnsBackendProviderConfig](_)
)
}

View File

@ -0,0 +1,63 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.api.backend.dns
import cats.effect.{ContextShift, IO}
import vinyldns.api.VinylDNSConfig
import vinyldns.api.crypto.Crypto
import vinyldns.core.domain.backend.{BackendProvider, BackendProviderConfig, BackendProviderLoader}
class DnsBackendProviderLoader extends BackendProviderLoader {
private implicit val cs: ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.global)
/**
* Loads a backend based on the provided config so that it is ready to use
* This is internally used typically during startup
*
* @param config The BackendConfig, has settings that are specific to this backend
* @return A ready-to-use Backend instance, or does an IO.raiseError if something bad occurred.
*/
def load(config: BackendProviderConfig): IO[BackendProvider] =
// if legacy = true, load from the old configured dns connections
// otherwise, load new stuff
DnsBackendProviderConfig.load(config.settings).map { bec =>
if (bec.legacy) {
// legacy adds a backend id named "default" with the default configuration
// and loads the backend connections from the legacy YAML config
val conns = VinylDNSConfig.configuredDnsConnections.dnsBackends.map { be =>
DnsBackend
.apply(be.id, be.zoneConnection, Some(be.transferConnection), Crypto.instance)
}
val defaultConn =
DnsBackend.apply(
"default",
VinylDNSConfig.configuredDnsConnections.defaultZoneConnection,
Some(VinylDNSConfig.configuredDnsConnections.defaultTransferConnection),
Crypto.instance
)
new DnsBackendProvider(defaultConn :: conns, Crypto.instance)
} else {
// Assumes the "new" YAML config
new DnsBackendProvider(
bec.backends.map(_.toDnsConnection(Crypto.instance)),
Crypto.instance
)
}
}
}

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
package vinyldns.api.domain.dns package vinyldns.api.backend.dns
import java.net.InetAddress import java.net.InetAddress
@ -22,10 +22,10 @@ import cats.syntax.either._
import org.joda.time.DateTime import org.joda.time.DateTime
import org.xbill.DNS import org.xbill.DNS
import scodec.bits.ByteVector import scodec.bits.ByteVector
import vinyldns.api.domain.dns.DnsProtocol._ import vinyldns.api.backend.dns.DnsProtocol._
import vinyldns.core.domain.{DomainHelpers, Fqdn, record}
import vinyldns.core.domain.record.RecordType._ import vinyldns.core.domain.record.RecordType._
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.domain.{DomainHelpers, Fqdn, record}
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.util.Try import scala.util.Try

View File

@ -20,7 +20,7 @@ import cats.implicits._
import com.aaronbedra.orchard.CIDR import com.aaronbedra.orchard.CIDR
import vinyldns.api.domain.zone.InvalidRequest import vinyldns.api.domain.zone.InvalidRequest
import vinyldns.core.domain.zone.Zone import vinyldns.core.domain.zone.Zone
import vinyldns.api.domain.dns.DnsConversions._ import vinyldns.api.backend.dns.DnsConversions._
import scala.util.Try import scala.util.Try

View File

@ -26,7 +26,7 @@ import vinyldns.api.domain.DomainValidations._
import vinyldns.api.domain.auth.AuthPrincipalProvider import vinyldns.api.domain.auth.AuthPrincipalProvider
import vinyldns.api.domain.batch.BatchChangeInterfaces._ import vinyldns.api.domain.batch.BatchChangeInterfaces._
import vinyldns.api.domain.batch.BatchTransformations._ import vinyldns.api.domain.batch.BatchTransformations._
import vinyldns.api.domain.dns.DnsConversions._ import vinyldns.api.backend.dns.DnsConversions._
import vinyldns.api.repository.ApiDataAccessor import vinyldns.api.repository.ApiDataAccessor
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.batch.BatchChangeApprovalStatus.BatchChangeApprovalStatus import vinyldns.core.domain.batch.BatchChangeApprovalStatus.BatchChangeApprovalStatus

View File

@ -23,7 +23,7 @@ import vinyldns.api.VinylDNSConfig
import vinyldns.api.domain.ReverseZoneHelpers import vinyldns.api.domain.ReverseZoneHelpers
import vinyldns.api.domain.batch.BatchChangeInterfaces.ValidatedBatch import vinyldns.api.domain.batch.BatchChangeInterfaces.ValidatedBatch
import vinyldns.api.domain.batch.BatchTransformations.LogicalChangeType.LogicalChangeType import vinyldns.api.domain.batch.BatchTransformations.LogicalChangeType.LogicalChangeType
import vinyldns.api.domain.dns.DnsConversions.getIPv6FullReverseName import vinyldns.api.backend.dns.DnsConversions.getIPv6FullReverseName
import vinyldns.core.domain.batch._ import vinyldns.core.domain.batch._
import vinyldns.core.domain.record.{AAAAData, RecordData, RecordSet, RecordSetChange} import vinyldns.core.domain.record.{AAAAData, RecordData, RecordSet, RecordSetChange}
import vinyldns.core.domain.record.RecordType._ import vinyldns.core.domain.record.RecordType._

View File

@ -19,8 +19,8 @@ package vinyldns.api.domain.record
import java.util.UUID import java.util.UUID
import org.joda.time.DateTime import org.joda.time.DateTime
import vinyldns.api.backend.dns.DnsConversions
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.api.domain.dns.DnsConversions
import vinyldns.core.domain.zone.Zone import vinyldns.core.domain.zone.Zone
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._

View File

@ -16,7 +16,7 @@
package vinyldns.api.domain.record package vinyldns.api.domain.record
import vinyldns.api.domain.dns.DnsConversions import vinyldns.api.backend.dns.DnsConversions
import vinyldns.core.domain.record.RecordSet import vinyldns.core.domain.record.RecordSet
object RecordSetHelpers { object RecordSetHelpers {

View File

@ -23,25 +23,24 @@ import vinyldns.api.domain.zone._
import vinyldns.api.repository.ApiDataAccessor import vinyldns.api.repository.ApiDataAccessor
import vinyldns.api.route.{ListGlobalRecordSetsResponse, ListRecordSetsByZoneResponse} import vinyldns.api.route.{ListGlobalRecordSetsResponse, ListRecordSetsByZoneResponse}
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.domain.zone.{ConfiguredDnsConnections, Zone, ZoneCommandResult, ZoneRepository} import vinyldns.core.domain.zone.{Zone, ZoneCommandResult, ZoneRepository}
import vinyldns.core.queue.MessageQueue import vinyldns.core.queue.MessageQueue
import cats.data._ import cats.data._
import cats.effect.IO import cats.effect.IO
import org.xbill.DNS.ReverseMap import org.xbill.DNS.ReverseMap
import vinyldns.api.domain.DomainValidations.{validateIpv4Address, validateIpv6Address} import vinyldns.api.domain.DomainValidations.{validateIpv4Address, validateIpv6Address}
import vinyldns.api.domain.access.AccessValidationsAlgebra import vinyldns.api.domain.access.AccessValidationsAlgebra
import vinyldns.api.domain.dns.DnsConnection
import vinyldns.core.domain.record.NameSort.NameSort import vinyldns.core.domain.record.NameSort.NameSort
import vinyldns.core.domain.record.RecordType.RecordType import vinyldns.core.domain.record.RecordType.RecordType
import vinyldns.core.domain.DomainHelpers.ensureTrailingDot import vinyldns.core.domain.DomainHelpers.ensureTrailingDot
import vinyldns.core.domain.backend.{Backend, BackendResolver}
object RecordSetService { object RecordSetService {
def apply( def apply(
dataAccessor: ApiDataAccessor, dataAccessor: ApiDataAccessor,
messageQueue: MessageQueue, messageQueue: MessageQueue,
accessValidation: AccessValidationsAlgebra, accessValidation: AccessValidationsAlgebra,
dnsConnection: (Zone, ConfiguredDnsConnections) => DnsConnection, backendResolver: BackendResolver,
configuredDnsConnections: ConfiguredDnsConnections,
validateRecordLookupAgainstDnsBackend: Boolean validateRecordLookupAgainstDnsBackend: Boolean
): RecordSetService = ): RecordSetService =
new RecordSetService( new RecordSetService(
@ -52,8 +51,7 @@ object RecordSetService {
dataAccessor.userRepository, dataAccessor.userRepository,
messageQueue, messageQueue,
accessValidation, accessValidation,
dnsConnection, backendResolver,
configuredDnsConnections,
validateRecordLookupAgainstDnsBackend validateRecordLookupAgainstDnsBackend
) )
} }
@ -66,8 +64,7 @@ class RecordSetService(
userRepository: UserRepository, userRepository: UserRepository,
messageQueue: MessageQueue, messageQueue: MessageQueue,
accessValidation: AccessValidationsAlgebra, accessValidation: AccessValidationsAlgebra,
dnsConnection: (Zone, ConfiguredDnsConnections) => DnsConnection, backendResolver: BackendResolver,
configuredDnsConnections: ConfiguredDnsConnections,
validateRecordLookupAgainstDnsBackend: Boolean validateRecordLookupAgainstDnsBackend: Boolean
) extends RecordSetServiceAlgebra { ) extends RecordSetServiceAlgebra {
@ -82,8 +79,7 @@ class RecordSetService(
rsForValidations = change.recordSet rsForValidations = change.recordSet
_ <- isNotHighValueDomain(recordSet, zone).toResult _ <- isNotHighValueDomain(recordSet, zone).toResult
_ <- recordSetDoesNotExist( _ <- recordSetDoesNotExist(
dnsConnection, backendResolver.resolve,
configuredDnsConnections,
zone, zone,
rsForValidations, rsForValidations,
validateRecordLookupAgainstDnsBackend validateRecordLookupAgainstDnsBackend
@ -120,8 +116,7 @@ class RecordSetService(
.getRecordSetsByName(zone.id, rsForValidations.name) .getRecordSetsByName(zone.id, rsForValidations.name)
.toResult[List[RecordSet]] .toResult[List[RecordSet]]
_ <- isUniqueUpdate( _ <- isUniqueUpdate(
dnsConnection, backendResolver.resolve,
configuredDnsConnections,
rsForValidations, rsForValidations,
existingRecordsWithName, existingRecordsWithName,
zone, zone,
@ -375,8 +370,7 @@ class RecordSetService(
} }
def recordSetDoesNotExist( def recordSetDoesNotExist(
dnsConnection: (Zone, ConfiguredDnsConnections) => DnsConnection, backendConnection: Zone => Backend,
configuredDnsConnections: ConfiguredDnsConnections,
zone: Zone, zone: Zone,
recordSet: RecordSet, recordSet: RecordSet,
validateRecordLookupAgainstDnsBackend: Boolean validateRecordLookupAgainstDnsBackend: Boolean
@ -384,9 +378,9 @@ class RecordSetService(
recordSetDoesNotExistInDatabase(recordSet, zone).value.flatMap { recordSetDoesNotExistInDatabase(recordSet, zone).value.flatMap {
case Left(recordSetAlreadyExists: RecordSetAlreadyExists) case Left(recordSetAlreadyExists: RecordSetAlreadyExists)
if validateRecordLookupAgainstDnsBackend => if validateRecordLookupAgainstDnsBackend =>
dnsConnection(zone, configuredDnsConnections) backendConnection(zone)
.resolve(recordSet.name, zone.name, recordSet.typ) .resolve(recordSet.name, zone.name, recordSet.typ)
.value .attempt
.map { .map {
case Right(existingRecords) => case Right(existingRecords) =>
if (existingRecords.isEmpty) Right(()) if (existingRecords.isEmpty) Right(())
@ -397,8 +391,7 @@ class RecordSetService(
}.toResult }.toResult
def isUniqueUpdate( def isUniqueUpdate(
dnsConnection: (Zone, ConfiguredDnsConnections) => DnsConnection, backendConnection: Zone => Backend,
configuredDnsConnections: ConfiguredDnsConnections,
newRecordSet: RecordSet, newRecordSet: RecordSet,
existingRecordsWithName: List[RecordSet], existingRecordsWithName: List[RecordSet],
zone: Zone, zone: Zone,
@ -408,9 +401,9 @@ class RecordSetService(
.recordSetDoesNotExist(newRecordSet, existingRecordsWithName, zone) match { .recordSetDoesNotExist(newRecordSet, existingRecordsWithName, zone) match {
case Left(recordSetAlreadyExists: RecordSetAlreadyExists) case Left(recordSetAlreadyExists: RecordSetAlreadyExists)
if validateRecordLookupAgainstDnsBackend => if validateRecordLookupAgainstDnsBackend =>
dnsConnection(zone, configuredDnsConnections) backendConnection(zone)
.resolve(newRecordSet.name, zone.name, newRecordSet.typ) .resolve(newRecordSet.name, zone.name, newRecordSet.typ)
.value .attempt
.map { .map {
case Right(existingRecords) => case Right(existingRecords) =>
if (existingRecords.isEmpty) Right(()) if (existingRecords.isEmpty) Right(())

View File

@ -19,8 +19,8 @@ package vinyldns.api.domain.record
import cats.syntax.either._ import cats.syntax.either._
import vinyldns.api.Interfaces._ import vinyldns.api.Interfaces._
import vinyldns.api.VinylDNSConfig import vinyldns.api.VinylDNSConfig
import vinyldns.api.backend.dns.DnsConversions
import vinyldns.api.domain._ import vinyldns.api.domain._
import vinyldns.api.domain.dns.DnsConversions
import vinyldns.core.domain.DomainHelpers._ import vinyldns.core.domain.DomainHelpers._
import vinyldns.core.domain.record.RecordType._ import vinyldns.core.domain.record.RecordType._
import vinyldns.api.domain.zone._ import vinyldns.api.domain.zone._
@ -28,6 +28,7 @@ import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.Group import vinyldns.core.domain.membership.Group
import vinyldns.core.domain.record.{RecordSet, RecordType} import vinyldns.core.domain.record.{RecordSet, RecordType}
import vinyldns.core.domain.zone.Zone import vinyldns.core.domain.zone.Zone
import scala.util.matching.Regex import scala.util.matching.Regex
object RecordSetValidations { object RecordSetValidations {

View File

@ -16,76 +16,31 @@
package vinyldns.api.domain.zone package vinyldns.api.domain.zone
import java.net.{InetSocketAddress, Socket}
import cats.effect._ import cats.effect._
import cats.syntax.all._ import cats.syntax.all._
import org.slf4j.{Logger, LoggerFactory}
import vinyldns.api.Interfaces._ import vinyldns.api.Interfaces._
import vinyldns.api.VinylDNSConfig import vinyldns.api.VinylDNSConfig
import vinyldns.api.domain.dns.DnsConnection import vinyldns.core.domain.backend.{Backend, BackendResolver}
import vinyldns.core.domain.record.{RecordSet, RecordType} import vinyldns.core.domain.record.RecordType
import vinyldns.core.domain.zone.{ConfiguredDnsConnections, DnsBackend, Zone, ZoneConnection} import vinyldns.core.domain.zone.Zone
import vinyldns.core.health.HealthCheck._
import scala.concurrent.duration._ import scala.concurrent.duration._
trait ZoneConnectionValidatorAlgebra { trait ZoneConnectionValidatorAlgebra {
def validateZoneConnections(zone: Zone): Result[Unit] def validateZoneConnections(zone: Zone): Result[Unit]
def isValidBackendId(backendId: Option[String]): Either[Throwable, Unit] def isValidBackendId(backendId: Option[String]): Either[Throwable, Unit]
} }
object ZoneConnectionValidator { class ZoneConnectionValidator(backendResolver: BackendResolver)
val logger: Logger = LoggerFactory.getLogger(classOf[ZoneConnectionValidator])
def getZoneConnection(
zone: Zone,
configuredDnsConnections: ConfiguredDnsConnections
): ZoneConnection =
zone.connection
.orElse(getDnsBackend(zone, configuredDnsConnections).map(_.zoneConnection))
.getOrElse(configuredDnsConnections.defaultZoneConnection)
def getTransferConnection(
zone: Zone,
configuredDnsConnections: ConfiguredDnsConnections
): ZoneConnection =
zone.transferConnection
.orElse(getDnsBackend(zone, configuredDnsConnections).map(_.transferConnection))
.getOrElse(configuredDnsConnections.defaultTransferConnection)
def getDnsBackend(
zone: Zone,
configuredDnsConnections: ConfiguredDnsConnections
): Option[DnsBackend] =
zone.backendId
.flatMap { bid =>
val backend = configuredDnsConnections.dnsBackends.find(_.id == bid)
if (backend.isEmpty) {
logger.error(
s"BackendId [$bid] for zone [${zone.id}: ${zone.name}] is not defined in config"
)
}
backend
}
}
class ZoneConnectionValidator(connections: ConfiguredDnsConnections)
extends ZoneConnectionValidatorAlgebra { extends ZoneConnectionValidatorAlgebra {
import ZoneConnectionValidator._
import ZoneRecordValidations._ import ZoneRecordValidations._
// Takes a long time to load large zones // Takes a long time to load large zones
val opTimeout: FiniteDuration = 60.seconds val opTimeout: FiniteDuration = 60.seconds
val (healthCheckAddress, healthCheckPort) = def loadDns(zone: Zone): IO[ZoneView] =
DnsConnection.parseHostAndPort(connections.defaultZoneConnection.primaryServer) DnsZoneViewLoader(zone, backendResolver.resolve(zone)).load()
def loadDns(zone: Zone): IO[ZoneView] = DnsZoneViewLoader(zone).load()
def hasApexNS(zoneView: ZoneView): Result[Unit] = { def hasApexNS(zoneView: ZoneView): Result[Unit] = {
val apexRecord = zoneView.recordSetsMap.get(zoneView.zone.name, RecordType.NS) match { val apexRecord = zoneView.recordSetsMap.get(zoneView.zone.name, RecordType.NS) match {
@ -107,8 +62,8 @@ class ZoneConnectionValidator(connections: ConfiguredDnsConnections)
.toResult .toResult
} }
def getDnsConnection(zone: Zone): Result[DnsConnection] = def getBackendConnection(zone: Zone): Result[Backend] =
Either.catchNonFatal(dnsConnection(getZoneConnection(zone, connections))).toResult backendResolver.resolve(zone).toResult
def loadZone(zone: Zone): Result[ZoneView] = def loadZone(zone: Zone): Result[ZoneView] =
withTimeout( withTimeout(
@ -117,22 +72,22 @@ class ZoneConnectionValidator(connections: ConfiguredDnsConnections)
ConnectionFailed(zone, "Unable to connect to zone: Transfer connection invalid") ConnectionFailed(zone, "Unable to connect to zone: Transfer connection invalid")
) )
def hasSOA(records: List[RecordSet], zone: Zone): Result[Unit] = { def zoneExists(zone: Zone, backend: Backend): Result[Unit] =
if (records.isEmpty) { backend
ConnectionFailed(zone, "SOA Record for zone not found").asLeft[Unit] .zoneExists(zone)
} else { .ifM(
().asRight[Throwable] IO(Right(())),
} IO(Left(ConnectionFailed(zone, s"Unable to find zone ${zone.name} in backend ${backend.id}")))
}.toResult )
.toResult
def validateZoneConnections(zone: Zone): Result[Unit] = { def validateZoneConnections(zone: Zone): Result[Unit] = {
val result = val result =
for { for {
connection <- getDnsConnection(zone) connection <- getBackendConnection(zone)
resp <- connection.resolve(zone.name, zone.name, RecordType.SOA) _ <- zoneExists(zone, connection)
view <- loadZone(zone) view <- loadZone(zone)
_ <- hasApexNS(view) _ <- hasApexNS(view)
_ <- hasSOA(resp, zone)
} yield () } yield ()
result.leftMap { result.leftMap {
@ -142,20 +97,8 @@ class ZoneConnectionValidator(connections: ConfiguredDnsConnections)
} }
} }
def healthCheck(timeout: Int): HealthCheck =
Resource
.fromAutoCloseable(IO(new Socket()))
.use(
socket =>
IO(socket.connect(new InetSocketAddress(healthCheckAddress, healthCheckPort), timeout))
)
.attempt
.asHealthCheck(classOf[ZoneConnectionValidator])
def isValidBackendId(backendId: Option[String]): Either[Throwable, Unit] = def isValidBackendId(backendId: Option[String]): Either[Throwable, Unit] =
ensuring(InvalidRequest(s"Invalid backendId: [$backendId]; please check system configuration")) { ensuring(InvalidRequest(s"Invalid backendId: [$backendId]; please check system configuration")) {
backendId.forall(id => connections.dnsBackends.exists(_.id == id)) backendId.forall(id => backendResolver.isRegistered(id))
} }
private[domain] def dnsConnection(conn: ZoneConnection): DnsConnection = DnsConnection(conn)
} }

View File

@ -18,13 +18,14 @@ package vinyldns.api.domain.zone
import cats.implicits._ import cats.implicits._
import vinyldns.api.domain.access.AccessValidationsAlgebra import vinyldns.api.domain.access.AccessValidationsAlgebra
import vinyldns.api.{Interfaces, VinylDNSConfig} import vinyldns.api.Interfaces
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.api.repository.ApiDataAccessor import vinyldns.api.repository.ApiDataAccessor
import vinyldns.core.domain.membership.{Group, GroupRepository, User, UserRepository} import vinyldns.core.domain.membership.{Group, GroupRepository, User, UserRepository}
import vinyldns.core.domain.zone._ import vinyldns.core.domain.zone._
import vinyldns.core.queue.MessageQueue import vinyldns.core.queue.MessageQueue
import vinyldns.core.domain.DomainHelpers.ensureTrailingDot import vinyldns.core.domain.DomainHelpers.ensureTrailingDot
import vinyldns.core.domain.backend.BackendResolver
object ZoneService { object ZoneService {
def apply( def apply(
@ -32,7 +33,8 @@ object ZoneService {
connectionValidator: ZoneConnectionValidatorAlgebra, connectionValidator: ZoneConnectionValidatorAlgebra,
messageQueue: MessageQueue, messageQueue: MessageQueue,
zoneValidations: ZoneValidations, zoneValidations: ZoneValidations,
accessValidation: AccessValidationsAlgebra accessValidation: AccessValidationsAlgebra,
backendResolver: BackendResolver
): ZoneService = ): ZoneService =
new ZoneService( new ZoneService(
dataAccessor.zoneRepository, dataAccessor.zoneRepository,
@ -42,7 +44,8 @@ object ZoneService {
connectionValidator, connectionValidator,
messageQueue, messageQueue,
zoneValidations, zoneValidations,
accessValidation accessValidation,
backendResolver
) )
} }
@ -54,7 +57,8 @@ class ZoneService(
connectionValidator: ZoneConnectionValidatorAlgebra, connectionValidator: ZoneConnectionValidatorAlgebra,
messageQueue: MessageQueue, messageQueue: MessageQueue,
zoneValidations: ZoneValidations, zoneValidations: ZoneValidations,
accessValidation: AccessValidationsAlgebra accessValidation: AccessValidationsAlgebra,
backendResolver: BackendResolver
) extends ZoneServiceAlgebra { ) extends ZoneServiceAlgebra {
import accessValidation._ import accessValidation._
@ -233,7 +237,7 @@ class ZoneService(
} }
def getBackendIds(): Result[List[String]] = def getBackendIds(): Result[List[String]] =
VinylDNSConfig.configuredDnsConnections.dnsBackends.map(_.id).toResult backendResolver.ids.toList.toResult
def zoneDoesNotExist(zoneName: String): Result[Unit] = def zoneDoesNotExist(zoneName: String): Result[Unit] =
zoneRepository zoneRepository

View File

@ -16,7 +16,7 @@
package vinyldns.api.domain.zone package vinyldns.api.domain.zone
import vinyldns.api.domain.dns.DnsConversions._ import vinyldns.api.backend.dns.DnsConversions._
import vinyldns.api.domain import vinyldns.api.domain
import vinyldns.api.domain.record.RecordSetChangeGenerator import vinyldns.api.domain.record.RecordSetChangeGenerator
import vinyldns.core.domain.record.{RecordSet, RecordSetChange} import vinyldns.core.domain.record.{RecordSet, RecordSetChange}

View File

@ -18,49 +18,24 @@ package vinyldns.api.domain.zone
import cats.effect._ import cats.effect._
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.xbill.DNS
import org.xbill.DNS.{TSIG, ZoneTransferIn}
import vinyldns.api.VinylDNSConfig import vinyldns.api.VinylDNSConfig
import vinyldns.api.crypto.Crypto import vinyldns.api.backend.dns.DnsConversions
import vinyldns.api.domain.dns.DnsConversions import vinyldns.core.domain.backend.Backend
import vinyldns.core.route.Monitored import vinyldns.core.domain.record.{NameSort, RecordSetRepository}
import scala.collection.JavaConverters._
import vinyldns.core.domain.record.{NameSort, RecordSetRepository, RecordType}
import vinyldns.core.domain.zone.Zone import vinyldns.core.domain.zone.Zone
import vinyldns.core.route.Monitored
trait ZoneViewLoader { trait ZoneViewLoader {
def load: () => IO[ZoneView] def load: () => IO[ZoneView]
} }
object DnsZoneViewLoader extends DnsConversions { object DnsZoneViewLoader extends DnsConversions {
val logger = LoggerFactory.getLogger("DnsZoneViewLoader") val logger = LoggerFactory.getLogger("DnsZoneViewLoader")
def dnsZoneTransfer(zone: Zone): ZoneTransferIn = {
val conn =
ZoneConnectionValidator
.getTransferConnection(zone, VinylDNSConfig.configuredDnsConnections)
.decrypted(Crypto.instance)
val TSIGKey = new TSIG(conn.keyName, conn.key)
val parts = conn.primaryServer.trim().split(':')
val (hostName, port) =
if (parts.length < 2)
(conn.primaryServer, 53)
else
(parts(0), parts(1).toInt)
val dnsZoneName = zoneDnsName(zone.name)
ZoneTransferIn.newAXFR(dnsZoneName, hostName, port, TSIGKey)
}
def apply(zone: Zone): DnsZoneViewLoader =
DnsZoneViewLoader(zone, dnsZoneTransfer)
} }
case class DnsZoneViewLoader( case class DnsZoneViewLoader(
zone: Zone, zone: Zone,
zoneTransfer: Zone => ZoneTransferIn, backendConnection: Backend,
maxZoneSize: Int = VinylDNSConfig.maxZoneSize maxZoneSize: Int = VinylDNSConfig.maxZoneSize
) extends ZoneViewLoader ) extends ZoneViewLoader
with DnsConversions with DnsConversions
@ -70,24 +45,7 @@ case class DnsZoneViewLoader(
() => () =>
monitor("dns.loadZoneView") { monitor("dns.loadZoneView") {
for { for {
zoneXfr <- IO { recordSets <- backendConnection.loadZone(zone, maxZoneSize)
val xfr = zoneTransfer(zone)
xfr.run()
xfr.getAXFR.asScala.map(_.asInstanceOf[DNS.Record]).toList.distinct
}
rawDnsRecords = zoneXfr.filter(
record => fromDnsRecordType(record.getType) != RecordType.UNKNOWN
)
_ <- if (rawDnsRecords.length > maxZoneSize)
IO.raiseError(ZoneTooLargeError(zone, rawDnsRecords.length, maxZoneSize))
else IO.pure(Unit)
dnsZoneName <- IO(zoneDnsName(zone.name))
recordSets <- IO(rawDnsRecords.map(toRecordSet(_, dnsZoneName, zone.id)))
_ <- IO(
DnsZoneViewLoader.logger.info(
s"dns.loadDnsView zoneName=${zone.name}; rawRsCount=${zoneXfr.size}; rsCount=${recordSets.size}"
)
)
} yield ZoneView(zone, recordSets) } yield ZoneView(zone, recordSets)
} }
} }

View File

@ -19,10 +19,10 @@ package vinyldns.api.engine
import cats.effect.{ContextShift, IO, Timer} import cats.effect.{ContextShift, IO, Timer}
import cats.implicits._ import cats.implicits._
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import vinyldns.api.domain.dns.DnsConnection import vinyldns.api.backend.dns.DnsProtocol.TryAgain
import vinyldns.api.domain.dns.DnsProtocol.{NoError, Refused, TryAgain}
import vinyldns.api.domain.record.RecordSetChangeGenerator import vinyldns.api.domain.record.RecordSetChangeGenerator
import vinyldns.api.domain.record.RecordSetHelpers._ import vinyldns.api.domain.record.RecordSetHelpers._
import vinyldns.core.domain.backend.{Backend, BackendResponse}
import vinyldns.core.domain.batch.{BatchChangeRepository, SingleChange} import vinyldns.core.domain.batch.{BatchChangeRepository, SingleChange}
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.domain.zone.Zone import vinyldns.core.domain.zone.Zone
@ -39,7 +39,7 @@ object RecordSetChangeHandler {
recordSetRepository: RecordSetRepository, recordSetRepository: RecordSetRepository,
recordChangeRepository: RecordChangeRepository, recordChangeRepository: RecordChangeRepository,
batchChangeRepository: BatchChangeRepository batchChangeRepository: BatchChangeRepository
)(implicit timer: Timer[IO]): (DnsConnection, RecordSetChange) => IO[RecordSetChange] = )(implicit timer: Timer[IO]): (Backend, RecordSetChange) => IO[RecordSetChange] =
(conn, recordSetChange) => { (conn, recordSetChange) => {
process( process(
recordSetRepository, recordSetRepository,
@ -54,7 +54,7 @@ object RecordSetChangeHandler {
recordSetRepository: RecordSetRepository, recordSetRepository: RecordSetRepository,
recordChangeRepository: RecordChangeRepository, recordChangeRepository: RecordChangeRepository,
batchChangeRepository: BatchChangeRepository, batchChangeRepository: BatchChangeRepository,
conn: DnsConnection, conn: Backend,
recordSetChange: RecordSetChange recordSetChange: RecordSetChange
)(implicit timer: Timer[IO]): IO[RecordSetChange] = )(implicit timer: Timer[IO]): IO[RecordSetChange] =
for { for {
@ -124,7 +124,7 @@ object RecordSetChangeHandler {
def syncAndGetProcessingStatusFromDnsBackend( def syncAndGetProcessingStatusFromDnsBackend(
change: RecordSetChange, change: RecordSetChange,
dnsConn: DnsConnection, conn: Backend,
recordSetRepository: RecordSetRepository, recordSetRepository: RecordSetRepository,
recordChangeRepository: RecordChangeRepository, recordChangeRepository: RecordChangeRepository,
performSync: Boolean = false performSync: Boolean = false
@ -167,7 +167,7 @@ object RecordSetChangeHandler {
} }
} }
dnsConn.resolve(change.recordSet.name, change.zone.name, change.recordSet.typ).value.flatMap { conn.resolve(change.recordSet.name, change.zone.name, change.recordSet.typ).attempt.flatMap {
case Right(existingRecords) => case Right(existingRecords) =>
if (performSync) { if (performSync) {
for { for {
@ -189,7 +189,7 @@ object RecordSetChangeHandler {
private def fsm( private def fsm(
state: ProcessorState, state: ProcessorState,
conn: DnsConnection, conn: Backend,
wildcardExists: Boolean, wildcardExists: Boolean,
recordSetRepository: RecordSetRepository, recordSetRepository: RecordSetRepository,
recordChangeRepository: RecordChangeRepository recordChangeRepository: RecordChangeRepository
@ -310,13 +310,13 @@ object RecordSetChangeHandler {
/* Step 1: Validate the change hasn't already been applied */ /* Step 1: Validate the change hasn't already been applied */
private def validate( private def validate(
change: RecordSetChange, change: RecordSetChange,
dnsConn: DnsConnection, conn: Backend,
recordSetRepository: RecordSetRepository, recordSetRepository: RecordSetRepository,
recordChangeRepository: RecordChangeRepository recordChangeRepository: RecordChangeRepository
): IO[ProcessorState] = ): IO[ProcessorState] =
syncAndGetProcessingStatusFromDnsBackend( syncAndGetProcessingStatusFromDnsBackend(
change, change,
dnsConn, conn,
recordSetRepository, recordSetRepository,
recordChangeRepository, recordChangeRepository,
true true
@ -333,12 +333,12 @@ object RecordSetChangeHandler {
} }
/* Step 2: Apply the change to the dns backend */ /* Step 2: Apply the change to the dns backend */
private def apply(change: RecordSetChange, dnsConn: DnsConnection): IO[ProcessorState] = private def apply(change: RecordSetChange, conn: Backend): IO[ProcessorState] =
dnsConn.applyChange(change).value.map { conn.applyChange(change).attempt.map {
case Right(_: NoError) => case Right(BackendResponse.Retry(_)) =>
Applied(change)
case Left(_: Refused) =>
Retrying(change) Retrying(change)
case Right(BackendResponse.NoError(_)) =>
Applied(change)
case Left(error) => case Left(error) =>
Completed( Completed(
change.failed( change.failed(
@ -350,13 +350,13 @@ object RecordSetChangeHandler {
/* Step 3: Verify the record was created. If the ProcessorState is applied or failed we requeue the record.*/ /* Step 3: Verify the record was created. If the ProcessorState is applied or failed we requeue the record.*/
private def verify( private def verify(
change: RecordSetChange, change: RecordSetChange,
dnsConn: DnsConnection, conn: Backend,
recordSetRepository: RecordSetRepository, recordSetRepository: RecordSetRepository,
recordChangeRepository: RecordChangeRepository recordChangeRepository: RecordChangeRepository
): IO[ProcessorState] = ): IO[ProcessorState] =
syncAndGetProcessingStatusFromDnsBackend( syncAndGetProcessingStatusFromDnsBackend(
change, change,
dnsConn, conn,
recordSetRepository, recordSetRepository,
recordChangeRepository recordChangeRepository
).map { ).map {

View File

@ -20,8 +20,9 @@ import cats.effect.{ContextShift, IO}
import cats.syntax.all._ import cats.syntax.all._
import org.joda.time.DateTime import org.joda.time.DateTime
import org.slf4j.{Logger, LoggerFactory} import org.slf4j.{Logger, LoggerFactory}
import vinyldns.api.domain.dns.DnsConversions import vinyldns.api.backend.dns.DnsConversions
import vinyldns.api.domain.zone.{DnsZoneViewLoader, VinylDNSZoneViewLoader} import vinyldns.api.domain.zone.{DnsZoneViewLoader, VinylDNSZoneViewLoader}
import vinyldns.core.domain.backend.BackendResolver
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.domain.zone.{Zone, ZoneStatus} import vinyldns.core.domain.zone.{Zone, ZoneStatus}
import vinyldns.core.route.Monitored import vinyldns.core.route.Monitored
@ -43,7 +44,7 @@ object ZoneSyncHandler extends DnsConversions with Monitored {
recordChangeRepository: RecordChangeRepository, recordChangeRepository: RecordChangeRepository,
zoneChangeRepository: ZoneChangeRepository, zoneChangeRepository: ZoneChangeRepository,
zoneRepository: ZoneRepository, zoneRepository: ZoneRepository,
dnsLoader: Zone => DnsZoneViewLoader = DnsZoneViewLoader.apply, backendResolver: BackendResolver,
vinyldnsLoader: (Zone, RecordSetRepository) => VinylDNSZoneViewLoader = vinyldnsLoader: (Zone, RecordSetRepository) => VinylDNSZoneViewLoader =
VinylDNSZoneViewLoader.apply VinylDNSZoneViewLoader.apply
): ZoneChange => IO[ZoneChange] = ): ZoneChange => IO[ZoneChange] =
@ -55,7 +56,7 @@ object ZoneSyncHandler extends DnsConversions with Monitored {
recordSetRepository, recordSetRepository,
recordChangeRepository, recordChangeRepository,
zoneChange, zoneChange,
dnsLoader, backendResolver,
vinyldnsLoader vinyldnsLoader
) )
_ <- saveZoneAndChange(zoneRepository, zoneChangeRepository, syncChange) // final save to store zone status _ <- saveZoneAndChange(zoneRepository, zoneChangeRepository, syncChange) // final save to store zone status
@ -83,18 +84,18 @@ object ZoneSyncHandler extends DnsConversions with Monitored {
recordSetRepository: RecordSetRepository, recordSetRepository: RecordSetRepository,
recordChangeRepository: RecordChangeRepository, recordChangeRepository: RecordChangeRepository,
zoneChange: ZoneChange, zoneChange: ZoneChange,
dnsLoader: Zone => DnsZoneViewLoader = DnsZoneViewLoader.apply, backendResolver: BackendResolver,
vinyldnsLoader: (Zone, RecordSetRepository) => VinylDNSZoneViewLoader = vinyldnsLoader: (Zone, RecordSetRepository) => VinylDNSZoneViewLoader =
VinylDNSZoneViewLoader.apply VinylDNSZoneViewLoader.apply
): IO[ZoneChange] = ): IO[ZoneChange] =
monitor("zone.sync") { monitor("zone.sync") {
time(s"zone.sync; zoneName='${zoneChange.zone.name}'") { time(s"zone.sync; zoneName='${zoneChange.zone.name}'") {
val zone = zoneChange.zone val zone = zoneChange.zone
val dnsLoader = DnsZoneViewLoader(zone, backendResolver.resolve(zone))
val dnsView = val dnsView =
time( time(
s"zone.sync.loadDnsView; zoneName='${zone.name}'; zoneChange='${zoneChange.id}'" s"zone.sync.loadDnsView; zoneName='${zone.name}'; zoneChange='${zoneChange.id}'"
)(dnsLoader(zone).load()) )(dnsLoader.load())
val vinyldnsView = time(s"zone.sync.loadVinylDNSView; zoneName='${zone.name}'")( val vinyldnsView = time(s"zone.sync.loadVinylDNSView; zoneName='${zone.name}'")(
vinyldnsLoader(zone, recordSetRepository).load() vinyldnsLoader(zone, recordSetRepository).load()
) )

View File

@ -34,7 +34,7 @@ case class CurrentStatus(
object CurrentStatus { object CurrentStatus {
val color = VinylDNSConfig.vinyldnsConfig.getString("color") val color = VinylDNSConfig.vinyldnsConfig.getString("color")
val vinyldnsKeyName = VinylDNSConfig.configuredDnsConnections.defaultZoneConnection.keyName val vinyldnsKeyName = "vinyldns."
val version = VinylDNSConfig.vinyldnsConfig.getString("version") val version = VinylDNSConfig.vinyldnsConfig.getString("version")
} }

View File

@ -21,20 +21,21 @@ import fs2._
import org.mockito import org.mockito
import org.mockito.Matchers._ import org.mockito.Matchers._
import org.mockito.Mockito._ import org.mockito.Mockito._
import org.mockito.{ArgumentCaptor, Mockito} import org.mockito.Mockito
import org.scalatestplus.mockito.MockitoSugar import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.{BeforeAndAfterEach, EitherValues} import org.scalatest.{BeforeAndAfterEach, EitherValues}
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import vinyldns.api.VinylDNSTestHelpers import vinyldns.api.VinylDNSTestHelpers
import vinyldns.api.backend.CommandHandler.{DeleteMessage, RetryMessage} import vinyldns.api.backend.CommandHandler.{DeleteMessage, RetryMessage}
import vinyldns.api.domain.dns.DnsConnection import vinyldns.api.backend.dns.DnsBackend
import vinyldns.core.domain.batch.{BatchChange, BatchChangeCommand, BatchChangeRepository} import vinyldns.core.domain.batch.{BatchChange, BatchChangeCommand, BatchChangeRepository}
import vinyldns.core.domain.record.{RecordChangeRepository, RecordSetChange, RecordSetRepository} import vinyldns.core.domain.record.{RecordChangeRepository, RecordSetChange, RecordSetRepository}
import vinyldns.core.domain.zone.{ZoneChange, ZoneChangeType, ZoneCommand, _} import vinyldns.core.domain.zone.{ZoneChange, ZoneChangeType, ZoneCommand, _}
import vinyldns.core.queue.{CommandMessage, MessageCount, MessageId, MessageQueue} import vinyldns.core.queue.{CommandMessage, MessageCount, MessageId, MessageQueue}
import vinyldns.core.TestRecordSetData._ import vinyldns.core.TestRecordSetData._
import vinyldns.core.TestZoneData._ import vinyldns.core.TestZoneData._
import vinyldns.core.domain.backend.{Backend, BackendResolver}
import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext
import scala.concurrent.duration._ import scala.concurrent.duration._
@ -66,19 +67,17 @@ class CommandHandlerSpec
private val mockZoneChangeProcessor = mock[ZoneChange => IO[ZoneChange]] private val mockZoneChangeProcessor = mock[ZoneChange => IO[ZoneChange]]
private val mockRecordChangeProcessor = private val mockRecordChangeProcessor =
mock[(DnsConnection, RecordSetChange) => IO[RecordSetChange]] mock[(Backend, RecordSetChange) => IO[RecordSetChange]]
private val mockZoneSyncProcessor = mock[ZoneChange => IO[ZoneChange]] private val mockZoneSyncProcessor = mock[ZoneChange => IO[ZoneChange]]
private val mockBatchChangeProcessor = mock[BatchChangeCommand => IO[Option[BatchChange]]] private val mockBatchChangeProcessor = mock[BatchChangeCommand => IO[Option[BatchChange]]]
private val defaultConn = private val mockBackendResolver = mock[BackendResolver]
ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1")
private val connections = ConfiguredDnsConnections(defaultConn, defaultConn, List())
private val processor = private val processor =
CommandHandler.processChangeRequests( CommandHandler.processChangeRequests(
mockZoneChangeProcessor, mockZoneChangeProcessor,
mockRecordChangeProcessor, mockRecordChangeProcessor,
mockZoneSyncProcessor, mockZoneSyncProcessor,
mockBatchChangeProcessor, mockBatchChangeProcessor,
connections mockBackendResolver
) )
override protected def beforeEach(): Unit = override protected def beforeEach(): Unit =
@ -219,37 +218,12 @@ class CommandHandlerSpec
val change = TestCommandMessage(pendingCreateAAAA, "foo") val change = TestCommandMessage(pendingCreateAAAA, "foo")
doReturn(IO.pure(change)) doReturn(IO.pure(change))
.when(mockRecordChangeProcessor) .when(mockRecordChangeProcessor)
.apply(any[DnsConnection], any[RecordSetChange]) .apply(any[DnsBackend], any[RecordSetChange])
Stream.emit(change).covary[IO].through(processor).compile.drain.unsafeRunSync() Stream.emit(change).covary[IO].through(processor).compile.drain.unsafeRunSync()
verify(mockRecordChangeProcessor).apply(any[DnsConnection], any[RecordSetChange]) verify(mockRecordChangeProcessor).apply(any[DnsBackend], any[RecordSetChange])
verifyZeroInteractions(mockZoneSyncProcessor) verifyZeroInteractions(mockZoneSyncProcessor)
verifyZeroInteractions(mockZoneChangeProcessor) verifyZeroInteractions(mockZoneChangeProcessor)
} }
"use the default zone connection when the change zone connection is not defined" in {
val noConnChange =
pendingCreateAAAA.copy(
zone = pendingCreateAAAA.zone.copy(connection = None, transferConnection = None)
)
val default = defaultConn.copy(primaryServer = "default.conn.test.com")
val defaultConnProcessor =
CommandHandler.processChangeRequests(
mockZoneChangeProcessor,
mockRecordChangeProcessor,
mockZoneSyncProcessor,
mockBatchChangeProcessor,
ConfiguredDnsConnections(default, default, List())
)
val change = TestCommandMessage(noConnChange, "foo")
doReturn(IO.pure(change))
.when(mockRecordChangeProcessor)
.apply(any[DnsConnection], any[RecordSetChange])
Stream.emit(change).covary[IO].through(defaultConnProcessor).compile.drain.unsafeRunSync()
val connCaptor = ArgumentCaptor.forClass(classOf[DnsConnection])
verify(mockRecordChangeProcessor).apply(connCaptor.capture(), any[RecordSetChange])
val resolver = connCaptor.getValue.resolver
resolver.getAddress.getHostName shouldBe default.primaryServer
}
"handle zone creates" in { "handle zone creates" in {
val change = TestCommandMessage(zoneCreate, "foo") val change = TestCommandMessage(zoneCreate, "foo")
doReturn(IO.pure(zoneCreate)) doReturn(IO.pure(zoneCreate))
@ -305,7 +279,11 @@ class CommandHandlerSpec
// stage our record change processing // stage our record change processing
doReturn(IO.pure(cmd)) doReturn(IO.pure(cmd))
.when(mockRecordChangeProcessor) .when(mockRecordChangeProcessor)
.apply(any[DnsConnection], any[RecordSetChange]) .apply(any[Backend], any[RecordSetChange])
doReturn(mock[Backend])
.when(mockBackendResolver)
.resolve(any[Zone])
// stage removing from the queue // stage removing from the queue
doReturn(IO.unit).when(mq).remove(cmd) doReturn(IO.unit).when(mq).remove(cmd)
@ -321,7 +299,7 @@ class CommandHandlerSpec
count, count,
100.millis, 100.millis,
stop, stop,
connections, mockBackendResolver,
1 1
) )
.take(1) .take(1)
@ -332,7 +310,7 @@ class CommandHandlerSpec
// verify our interactions // verify our interactions
verify(mq, atLeastOnce()).receive(count) verify(mq, atLeastOnce()).receive(count)
verify(mockRecordChangeProcessor) verify(mockRecordChangeProcessor)
.apply(any[DnsConnection], mockito.Matchers.eq(pendingCreateAAAA)) .apply(any[DnsBackend], mockito.Matchers.eq(pendingCreateAAAA))
verify(mq).remove(cmd) verify(mq).remove(cmd)
} }
"continue processing on unexpected failure" in { "continue processing on unexpected failure" in {
@ -348,7 +326,11 @@ class CommandHandlerSpec
// stage our record change processing our command // stage our record change processing our command
doReturn(IO.pure(cmd.command)) doReturn(IO.pure(cmd.command))
.when(mockRecordChangeProcessor) .when(mockRecordChangeProcessor)
.apply(any[DnsConnection], any[RecordSetChange]) .apply(any[Backend], any[RecordSetChange])
doReturn(mock[Backend])
.when(mockBackendResolver)
.resolve(any[Zone])
// stage removing from the queue // stage removing from the queue
doReturn(IO.unit).when(mq).remove(cmd) doReturn(IO.unit).when(mq).remove(cmd)
@ -364,7 +346,7 @@ class CommandHandlerSpec
count, count,
100.millis, 100.millis,
stop, stop,
connections mockBackendResolver
) )
.take(1) .take(1)
@ -403,6 +385,9 @@ class CommandHandlerSpec
.receive(count) .receive(count)
// stage processing for a zone update, the simplest of cases // stage processing for a zone update, the simplest of cases
doReturn(mock[Backend])
.when(mockBackendResolver)
.resolve(any[Zone])
doReturn(IO.pure(Right(zoneUpdate.zone))).when(zoneRepo).save(zoneUpdate.zone) doReturn(IO.pure(Right(zoneUpdate.zone))).when(zoneRepo).save(zoneUpdate.zone)
doReturn(IO.pure(zoneUpdate)).when(zoneChangeRepo).save(any[ZoneChange]) doReturn(IO.pure(zoneUpdate)).when(zoneChangeRepo).save(any[ZoneChange])
@ -422,7 +407,7 @@ class CommandHandlerSpec
recordChangeRepo, recordChangeRepo,
batchChangeRepo, batchChangeRepo,
AllNotifiers(List.empty), AllNotifiers(List.empty),
connections mockBackendResolver
) )
// kick off processing of messages // kick off processing of messages

View File

@ -14,37 +14,37 @@
* limitations under the License. * limitations under the License.
*/ */
package vinyldns.api.domain.dns package vinyldns.api.backend.dns
import java.net.InetAddress import java.net.{InetAddress, SocketAddress}
import cats.scalatest.EitherMatchers import cats.scalatest.EitherMatchers
import org.joda.time.DateTime import org.joda.time.DateTime
import org.mockito.ArgumentCaptor import org.mockito.ArgumentCaptor
import org.mockito.Matchers._ import org.mockito.Matchers._
import org.mockito.Mockito._ import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.scalatest.{BeforeAndAfterEach, EitherValues}
import org.scalatestplus.mockito.MockitoSugar
import org.xbill.DNS import org.xbill.DNS
import org.xbill.DNS.{Lookup, Name} import org.xbill.DNS.{Lookup, Name, TSIG}
import vinyldns.api.ResultHelpers import vinyldns.api.backend.dns.DnsProtocol._
import vinyldns.api.domain.dns.DnsProtocol._ import vinyldns.core.crypto.{CryptoAlgebra, NoOpCrypto}
import vinyldns.core.domain.backend.BackendResponse
import vinyldns.core.domain.record.RecordType._ import vinyldns.core.domain.record.RecordType._
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.zone.{Zone, ZoneConnection} import vinyldns.core.domain.zone.{Zone, ZoneConnection}
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
class DnsConnectionSpec class DnsBackendSpec
extends AnyWordSpec extends AnyWordSpec
with Matchers with Matchers
with MockitoSugar with MockitoSugar
with ResultHelpers
with BeforeAndAfterEach with BeforeAndAfterEach
with EitherMatchers { with EitherMatchers
with EitherValues {
private val zoneConnection = private val zoneConnection =
ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1") ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1")
@ -82,7 +82,10 @@ class DnsConnectionSpec
private val mockMessage = mock[DNS.Message] private val mockMessage = mock[DNS.Message]
private val messageCaptor = ArgumentCaptor.forClass(classOf[DNS.Message]) private val messageCaptor = ArgumentCaptor.forClass(classOf[DNS.Message])
private val mockDnsQuery = mock[DnsQuery] private val mockDnsQuery = mock[DnsQuery]
private val underTest = new DnsConnection(mockResolver) { private val mockSocketAddress = mock[SocketAddress]
private val mockTsig = mock[TSIG]
private val transferInfo = TransferInfo(mockSocketAddress, mockTsig)
private val underTest = new DnsBackend("test", mockResolver, transferInfo) {
override def toQuery( override def toQuery(
name: String, name: String,
zoneName: String, zoneName: String,
@ -94,7 +97,7 @@ class DnsConnectionSpec
case _ => Right(mockDnsQuery) case _ => Right(mockDnsQuery)
} }
} }
private val dnsQueryTest = new DnsConnection(mockResolver) private val dnsQueryTest = new DnsBackend("query-test", mockResolver, transferInfo)
override def beforeEach(): Unit = { override def beforeEach(): Unit = {
doReturn(mockMessage).when(mockMessage).clone() doReturn(mockMessage).when(mockMessage).clone()
@ -125,7 +128,7 @@ class DnsConnectionSpec
"Creating a Dns Connection" should { "Creating a Dns Connection" should {
"decrypt the zone connection" in { "decrypt the zone connection" in {
val conn = spy(zoneConnection) val conn = spy(zoneConnection)
DnsConnection(conn) DnsBackend("test", conn, None, new NoOpCrypto())
verify(conn).decrypted(any[CryptoAlgebra]) verify(conn).decrypted(any[CryptoAlgebra])
} }
@ -133,7 +136,7 @@ class DnsConnectionSpec
"parse the port when specified on the primary server" in { "parse the port when specified on the primary server" in {
val conn = zoneConnection.copy(primaryServer = "dns.comcast.net:19001") val conn = zoneConnection.copy(primaryServer = "dns.comcast.net:19001")
val dnsConn = DnsConnection(conn) val dnsConn = DnsBackend("test", conn, None, new NoOpCrypto())
val simpleResolver = dnsConn.resolver.asInstanceOf[DNS.SimpleResolver] val simpleResolver = dnsConn.resolver.asInstanceOf[DNS.SimpleResolver]
val address = simpleResolver.getAddress val address = simpleResolver.getAddress
@ -145,7 +148,7 @@ class DnsConnectionSpec
"use default port of 53 when not specified" in { "use default port of 53 when not specified" in {
val conn = zoneConnection.copy(primaryServer = "dns.comcast.net") val conn = zoneConnection.copy(primaryServer = "dns.comcast.net")
val dnsConn = DnsConnection(conn) val dnsConn = DnsBackend("test", conn, None, new NoOpCrypto())
val simpleResolver = dnsConn.resolver.asInstanceOf[DNS.SimpleResolver] val simpleResolver = dnsConn.resolver.asInstanceOf[DNS.SimpleResolver]
val address = simpleResolver.getAddress val address = simpleResolver.getAddress
@ -158,7 +161,7 @@ class DnsConnectionSpec
"Resolving records" should { "Resolving records" should {
"return a single record when only one DNS record is returned" in { "return a single record when only one DNS record is returned" in {
val records: List[RecordSet] = val records: List[RecordSet] =
rightResultOf(underTest.resolve("www", "vinyldns.", RecordType.A).value) underTest.resolve("www", "vinyldns.", RecordType.A).unsafeRunSync()
records.head should have( records.head should have(
'name ("a-record."), 'name ("a-record."),
'typ (RecordType.A), 'typ (RecordType.A),
@ -182,7 +185,7 @@ class DnsConnectionSpec
doReturn(List(a1, a2)).when(mockDnsQuery).run() doReturn(List(a1, a2)).when(mockDnsQuery).run()
val records: List[RecordSet] = val records: List[RecordSet] =
rightResultOf(underTest.resolve("www", "vinyldns.", RecordType.A).value) underTest.resolve("www", "vinyldns.", RecordType.A).unsafeRunSync()
records.head should have( records.head should have(
'name ("a-record."), 'name ("a-record."),
'typ (RecordType.A), 'typ (RecordType.A),
@ -194,14 +197,15 @@ class DnsConnectionSpec
doReturn(DNS.Lookup.HOST_NOT_FOUND).when(mockDnsQuery).result doReturn(DNS.Lookup.HOST_NOT_FOUND).when(mockDnsQuery).result
val records: List[RecordSet] = val records: List[RecordSet] =
rightResultOf(underTest.resolve("www", "vinyldns.", RecordType.A).value) underTest.resolve("www", "vinyldns.", RecordType.A).unsafeRunSync()
records shouldBe empty records shouldBe empty
} }
"return an Uncrecoverable error" in { "return an Uncrecoverable error" in {
doReturn(DNS.Lookup.UNRECOVERABLE).when(mockDnsQuery).result doReturn(DNS.Lookup.UNRECOVERABLE).when(mockDnsQuery).result
val error = leftResultOf(underTest.resolve("www", "vinyldns.", RecordType.A).value) val error =
underTest.resolve("www", "vinyldns.", RecordType.A).attempt.unsafeRunSync().left.value
error shouldBe a[Unrecoverable] error shouldBe a[Unrecoverable]
} }
@ -209,7 +213,8 @@ class DnsConnectionSpec
doReturn("this is bad").when(mockDnsQuery).error doReturn("this is bad").when(mockDnsQuery).error
doReturn(DNS.Lookup.TRY_AGAIN).when(mockDnsQuery).result doReturn(DNS.Lookup.TRY_AGAIN).when(mockDnsQuery).result
val error = leftResultOf(underTest.resolve("www", "vinyldns.", RecordType.A).value) val error =
underTest.resolve("www", "vinyldns.", RecordType.A).attempt.unsafeRunSync().left.value
error shouldBe a[TryAgain] error shouldBe a[TryAgain]
} }
@ -217,7 +222,7 @@ class DnsConnectionSpec
doReturn(DNS.Lookup.TYPE_NOT_FOUND).when(mockDnsQuery).result doReturn(DNS.Lookup.TYPE_NOT_FOUND).when(mockDnsQuery).result
val result: List[RecordSet] = val result: List[RecordSet] =
rightResultOf(underTest.resolve("www", "vinyldns.", RecordType.A).value) underTest.resolve("www", "vinyldns.", RecordType.A).unsafeRunSync()
result shouldBe List() result shouldBe List()
} }
@ -227,14 +232,15 @@ class DnsConnectionSpec
"return an InvalidRecord error if there are no records present" in { "return an InvalidRecord error if there are no records present" in {
val noRecords = testA.copy(records = Nil) val noRecords = testA.copy(records = Nil)
val result = leftResultOf(underTest.addRecord(addRsChange(testZone, noRecords)).value) val result =
underTest.addRecord(addRsChange(testZone, noRecords)).attempt.unsafeRunSync().left.value
result shouldBe a[InvalidRecord] result shouldBe a[InvalidRecord]
} }
"send an appropriate update message to the resolver" in { "send an appropriate update message to the resolver" in {
val change = addRsChange() val change = addRsChange()
val result: DnsResponse = rightResultOf(underTest.addRecord(change).value) val result: DnsResponse = underTest.addRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -253,7 +259,7 @@ class DnsConnectionSpec
"send an appropriate update message to the resolver when multiple record sets are present" in { "send an appropriate update message to the resolver when multiple record sets are present" in {
val change = addRsChange(testZone, testAMultiple) val change = addRsChange(testZone, testAMultiple)
val result: DnsResponse = rightResultOf(underTest.addRecord(change).value) val result: DnsResponse = underTest.addRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -280,14 +286,20 @@ class DnsConnectionSpec
"return an InvalidRecord error if there are no records present" in { "return an InvalidRecord error if there are no records present" in {
val noRecords = testA.copy(records = Nil) val noRecords = testA.copy(records = Nil)
val result = leftResultOf(underTest.updateRecord(updateRsChange(testZone, noRecords)).value) val result =
underTest
.updateRecord(updateRsChange(testZone, noRecords))
.attempt
.unsafeRunSync()
.left
.value
result shouldBe a[InvalidRecord] result shouldBe a[InvalidRecord]
} }
"send an appropriate replace message to the resolver for a name change" in { "send an appropriate replace message to the resolver for a name change" in {
val change = updateRsChange().copy(updates = Some(testA.copy(name = "updated-a-record"))) val change = updateRsChange().copy(updates = Some(testA.copy(name = "updated-a-record")))
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -313,7 +325,7 @@ class DnsConnectionSpec
"send an appropriate replace message to the resolver for a TTL change" in { "send an appropriate replace message to the resolver for a TTL change" in {
val change = updateRsChange(rs = testA.copy(ttl = 300)).copy(updates = Some(testA)) val change = updateRsChange(rs = testA.copy(ttl = 300)).copy(updates = Some(testA))
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -339,7 +351,7 @@ class DnsConnectionSpec
"send an appropriate replace message in the event that the record being replaced is None" in { "send an appropriate replace message in the event that the record being replaced is None" in {
val change = updateRsChange().copy(updates = None) val change = updateRsChange().copy(updates = None)
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -353,7 +365,7 @@ class DnsConnectionSpec
updates = Some(testAMultiple.copy(name = "updated-a-record")) updates = Some(testAMultiple.copy(name = "updated-a-record"))
) )
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -391,14 +403,20 @@ class DnsConnectionSpec
"return an InvalidRecord error if there are no records present" in { "return an InvalidRecord error if there are no records present" in {
val noRecords = testA.copy(records = Nil) val noRecords = testA.copy(records = Nil)
val result = leftResultOf(underTest.updateRecord(updateRsChange(testZone, noRecords)).value) val result =
underTest
.updateRecord(updateRsChange(testZone, noRecords))
.attempt
.unsafeRunSync()
.left
.value
result shouldBe a[InvalidRecord] result shouldBe a[InvalidRecord]
} }
"send a message with an empty body to the resolver when no changes have occurred" in { "send a message with an empty body to the resolver when no changes have occurred" in {
val change = updateRsChange().copy(updates = Some(testA)) val change = updateRsChange().copy(updates = Some(testA))
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -411,7 +429,7 @@ class DnsConnectionSpec
val change = val change =
updateRsChange().copy(updates = Some(testA.copy(records = List(AData("127.0.0.1"))))) updateRsChange().copy(updates = Some(testA.copy(records = List(AData("127.0.0.1")))))
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -437,7 +455,7 @@ class DnsConnectionSpec
"send an appropriate replace message in the event that the record being replaced is None" in { "send an appropriate replace message in the event that the record being replaced is None" in {
val change = updateRsChange().copy(updates = None) val change = updateRsChange().copy(updates = None)
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -451,7 +469,7 @@ class DnsConnectionSpec
updates = Some(testAMultiple.copy(records = List(AData("4.4.4.4"), AData("3.3.3.3")))) updates = Some(testAMultiple.copy(records = List(AData("4.4.4.4"), AData("3.3.3.3"))))
) )
val result: DnsResponse = rightResultOf(underTest.updateRecord(change).value) val result: DnsResponse = underTest.updateRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -498,14 +516,20 @@ class DnsConnectionSpec
"return an InvalidRecord error if there are no records present in the delete" in { "return an InvalidRecord error if there are no records present in the delete" in {
val noRecords = testA.copy(records = Nil) val noRecords = testA.copy(records = Nil)
val result = leftResultOf(underTest.updateRecord(deleteRsChange(testZone, noRecords)).value) val result =
underTest
.updateRecord(deleteRsChange(testZone, noRecords))
.attempt
.unsafeRunSync()
.left
.value
result shouldBe a[InvalidRecord] result shouldBe a[InvalidRecord]
} }
"send an appropriate delete message to the resolver" in { "send an appropriate delete message to the resolver" in {
val change = deleteRsChange() val change = deleteRsChange()
val result: DnsResponse = rightResultOf(underTest.deleteRecord(change).value) val result: DnsResponse = underTest.deleteRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -524,7 +548,7 @@ class DnsConnectionSpec
"send an appropriate delete message to the resolver for multiple records" in { "send an appropriate delete message to the resolver for multiple records" in {
val change = deleteRsChange(testZone, testAMultiple) val change = deleteRsChange(testZone, testAMultiple)
val result: DnsResponse = rightResultOf(underTest.deleteRecord(change).value) val result: DnsResponse = underTest.deleteRecord(change).unsafeRunSync()
val sentMessage = messageCaptor.getValue val sentMessage = messageCaptor.getValue
@ -544,21 +568,17 @@ class DnsConnectionSpec
"applyChange" should { "applyChange" should {
"yield a successful DNS response for a create if there are no errors" in { "yield a successful DNS response for a create if there are no errors" in {
underTest.applyChange(addRsChange()).value.unsafeRunSync() shouldBe Right( underTest.applyChange(addRsChange()).unsafeRunSync() shouldBe a[BackendResponse.NoError]
NoError(mockMessage)
)
} }
"yield a successful DNS response for an update if there are no errors" in { "yield a successful DNS response for an update if there are no errors" in {
underTest.applyChange(updateRsChange()).value.unsafeRunSync() shouldBe Right( underTest.applyChange(updateRsChange()).unsafeRunSync() shouldBe a[BackendResponse.NoError]
NoError(mockMessage)
)
} }
"yield a successful DNS response for a delete if there are no errors" in { "yield a successful DNS response for a delete if there are no errors" in {
underTest.applyChange(deleteRsChange()).value.unsafeRunSync() shouldBe Right( underTest
NoError(mockMessage) .applyChange(deleteRsChange())
) .unsafeRunSync() shouldBe a[BackendResponse.NoError]
} }
} }
@ -568,7 +588,7 @@ class DnsConnectionSpec
underTest underTest
.resolve(rsc.recordSet.name, rsc.zone.name, rsc.recordSet.typ) .resolve(rsc.recordSet.name, rsc.zone.name, rsc.recordSet.typ)
.value .attempt
.unsafeRunSync() shouldBe left .unsafeRunSync() shouldBe left
} }
} }

View File

@ -14,23 +14,22 @@
* limitations under the License. * limitations under the License.
*/ */
package vinyldns.api.domain.dns package vinyldns.api.backend.dns
import java.net.InetAddress import java.net.InetAddress
import org.joda.time.DateTime import org.joda.time.DateTime
import org.mockito.Mockito._ import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.scalatest.{BeforeAndAfterEach, EitherValues}
import org.scalatestplus.mockito.MockitoSugar
import org.xbill.DNS import org.xbill.DNS
import vinyldns.api.ResultHelpers import vinyldns.api.backend.dns.DnsProtocol._
import vinyldns.api.domain.dns.DnsProtocol._
import vinyldns.core.domain.record._
import vinyldns.core.domain.zone.Zone
import vinyldns.core.TestRecordSetData.ds import vinyldns.core.TestRecordSetData.ds
import vinyldns.core.domain.Fqdn import vinyldns.core.domain.Fqdn
import vinyldns.core.domain.record._
import vinyldns.core.domain.zone.Zone
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
@ -38,9 +37,9 @@ class DnsConversionsSpec
extends AnyWordSpec extends AnyWordSpec
with Matchers with Matchers
with MockitoSugar with MockitoSugar
with ResultHelpers
with BeforeAndAfterEach with BeforeAndAfterEach
with DnsConversions { with DnsConversions
with EitherValues {
private val testZoneName = "vinyldns." private val testZoneName = "vinyldns."
private val testZone = Zone(testZoneName, "test@test.com") private val testZone = Zone(testZoneName, "test@test.com")
@ -283,7 +282,7 @@ class DnsConversionsSpec
} }
private def roundTrip(rs: RecordSet): RecordSet = { private def roundTrip(rs: RecordSet): RecordSet = {
val recordList = rightValue(toDnsRecords(rs, testZoneName)).map(toRecordSet(_, testZoneDnsName)) val recordList = toDnsRecords(rs, testZoneName).right.value.map(toRecordSet(_, testZoneDnsName))
recordList.head.copy(records = recordList.flatMap(_.records)) recordList.head.copy(records = recordList.flatMap(_.records))
} }
@ -324,83 +323,83 @@ class DnsConversionsSpec
"Converting to a DNS RRset" should { "Converting to a DNS RRset" should {
"convert A record set" in { "convert A record set" in {
val result = rightValue(toDnsRRset(testA, testZoneName)) val result = toDnsRRset(testA, testZoneName).right.value
verifyMatch(result, testA) verifyMatch(result, testA)
} }
"convert multiple record set" in { "convert multiple record set" in {
val result = rightValue(toDnsRRset(testAMultiple, testZoneName)) val result = toDnsRRset(testAMultiple, testZoneName).right.value
verifyMatch(result, testAMultiple) verifyMatch(result, testAMultiple)
} }
"convert AAAA record set" in { "convert AAAA record set" in {
val result = rightValue(toDnsRRset(testAAAA, testZoneName)) val result = toDnsRRset(testAAAA, testZoneName).right.value
verifyMatch(result, testAAAA) verifyMatch(result, testAAAA)
} }
"convert CNAME record set" in { "convert CNAME record set" in {
val result = rightValue(toDnsRRset(testCNAME, testZoneName)) val result = toDnsRRset(testCNAME, testZoneName).right.value
verifyMatch(result, testCNAME) verifyMatch(result, testCNAME)
} }
"convert DS record set" in { "convert DS record set" in {
val result = rightValue(toDnsRRset(testDS, testZoneName)) val result = toDnsRRset(testDS, testZoneName).right.value
verifyMatch(result, testDS) verifyMatch(result, testDS)
} }
"convert MX record set" in { "convert MX record set" in {
val result = rightValue(toDnsRRset(testMX, testZoneName)) val result = toDnsRRset(testMX, testZoneName).right.value
verifyMatch(result, testMX) verifyMatch(result, testMX)
} }
"convert NS record set" in { "convert NS record set" in {
val result = rightValue(toDnsRRset(testNS, testZoneName)) val result = toDnsRRset(testNS, testZoneName).right.value
verifyMatch(result, testNS) verifyMatch(result, testNS)
} }
"convert PTR record set" in { "convert PTR record set" in {
val result = rightValue(toDnsRRset(testPTR, testZoneName)) val result = toDnsRRset(testPTR, testZoneName).right.value
verifyMatch(result, testPTR) verifyMatch(result, testPTR)
} }
"convert SOA record set" in { "convert SOA record set" in {
val result = rightValue(toDnsRRset(testSOA, testZoneName)) val result = toDnsRRset(testSOA, testZoneName).right.value
verifyMatch(result, testSOA) verifyMatch(result, testSOA)
} }
"convert SPF record set" in { "convert SPF record set" in {
val result = rightValue(toDnsRRset(testSPF, testZoneName)) val result = toDnsRRset(testSPF, testZoneName).right.value
verifyMatch(result, testSPF) verifyMatch(result, testSPF)
} }
"convert SSHFP record set" in { "convert SSHFP record set" in {
val result = rightValue(toDnsRRset(testSSHFP, testZoneName)) val result = toDnsRRset(testSSHFP, testZoneName).right.value
verifyMatch(result, testSSHFP) verifyMatch(result, testSSHFP)
} }
"convert SRV record set" in { "convert SRV record set" in {
val result = rightValue(toDnsRRset(testSRV, testZoneName)) val result = toDnsRRset(testSRV, testZoneName).right.value
verifyMatch(result, testSRV) verifyMatch(result, testSRV)
} }
"convert NAPTR record set" in { "convert NAPTR record set" in {
val result = rightValue(toDnsRRset(testNAPTR, testZoneName)) val result = toDnsRRset(testNAPTR, testZoneName).right.value
verifyMatch(result, testNAPTR) verifyMatch(result, testNAPTR)
} }
"convert TXT record set" in { "convert TXT record set" in {
val result = rightValue(toDnsRRset(testTXT, testZoneName)) val result = toDnsRRset(testTXT, testZoneName).right.value
verifyMatch(result, testTXT) verifyMatch(result, testTXT)
} }
"convert long TXT record set" in { "convert long TXT record set" in {
val result = rightValue(toDnsRRset(testLongTXT, testZoneName)) val result = toDnsRRset(testLongTXT, testZoneName).right.value
verifyMatch(result, testLongTXT) verifyMatch(result, testLongTXT)
} }
"fail to convert a bad SPF record set" in { "fail to convert a bad SPF record set" in {
val result = leftValue(toDnsRRset(testLongSPF, testZoneName)) val result = toDnsRRset(testLongSPF, testZoneName).left.value
result shouldBe a[java.lang.IllegalArgumentException] result shouldBe a[java.lang.IllegalArgumentException]
} }
} }
@ -408,67 +407,67 @@ class DnsConversionsSpec
"Converting to a Dns Response" should { "Converting to a Dns Response" should {
"return the message when NoError" in { "return the message when NoError" in {
doReturn(DNS.Rcode.NOERROR).when(mockMessage).getRcode doReturn(DNS.Rcode.NOERROR).when(mockMessage).getRcode
rightValue(toDnsResponse(mockMessage)) shouldBe NoError(mockMessage) toDnsResponse(mockMessage).right.value shouldBe NoError(mockMessage)
} }
"return a BadKey" in { "return a BadKey" in {
doReturn(DNS.Rcode.BADKEY).when(mockMessage).getRcode doReturn(DNS.Rcode.BADKEY).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[BadKey] toDnsResponse(mockMessage).left.value shouldBe a[BadKey]
} }
"return a BadMode" in { "return a BadMode" in {
doReturn(DNS.Rcode.BADMODE).when(mockMessage).getRcode doReturn(DNS.Rcode.BADMODE).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[BadMode] toDnsResponse(mockMessage).left.value shouldBe a[BadMode]
} }
"return a BadSig" in { "return a BadSig" in {
doReturn(DNS.Rcode.BADSIG).when(mockMessage).getRcode doReturn(DNS.Rcode.BADSIG).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[BadSig] toDnsResponse(mockMessage).left.value shouldBe a[BadSig]
} }
"return a BadTime" in { "return a BadTime" in {
doReturn(DNS.Rcode.BADTIME).when(mockMessage).getRcode doReturn(DNS.Rcode.BADTIME).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[BadTime] toDnsResponse(mockMessage).left.value shouldBe a[BadTime]
} }
"return a FormatError" in { "return a FormatError" in {
doReturn(DNS.Rcode.FORMERR).when(mockMessage).getRcode doReturn(DNS.Rcode.FORMERR).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[FormatError] toDnsResponse(mockMessage).left.value shouldBe a[FormatError]
} }
"return a NotAuthorized" in { "return a NotAuthorized" in {
doReturn(DNS.Rcode.NOTAUTH).when(mockMessage).getRcode doReturn(DNS.Rcode.NOTAUTH).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[NotAuthorized] toDnsResponse(mockMessage).left.value shouldBe a[NotAuthorized]
} }
"return a NotImplemented" in { "return a NotImplemented" in {
doReturn(DNS.Rcode.NOTIMP).when(mockMessage).getRcode doReturn(DNS.Rcode.NOTIMP).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[NotImplemented] toDnsResponse(mockMessage).left.value shouldBe a[NotImplemented]
} }
"return a NotZone" in { "return a NotZone" in {
doReturn(DNS.Rcode.NOTZONE).when(mockMessage).getRcode doReturn(DNS.Rcode.NOTZONE).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[NotZone] toDnsResponse(mockMessage).left.value shouldBe a[NotZone]
} }
"return a NameNotFound" in { "return a NameNotFound" in {
doReturn(DNS.Rcode.NXDOMAIN).when(mockMessage).getRcode doReturn(DNS.Rcode.NXDOMAIN).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[NameNotFound] toDnsResponse(mockMessage).left.value shouldBe a[NameNotFound]
} }
"return a RecordSetNotFound" in { "return a RecordSetNotFound" in {
doReturn(DNS.Rcode.NXRRSET).when(mockMessage).getRcode doReturn(DNS.Rcode.NXRRSET).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[RecordSetNotFound] toDnsResponse(mockMessage).left.value shouldBe a[RecordSetNotFound]
} }
"return a Refused" in { "return a Refused" in {
doReturn(DNS.Rcode.REFUSED).when(mockMessage).getRcode doReturn(DNS.Rcode.REFUSED).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[Refused] toDnsResponse(mockMessage).left.value shouldBe a[Refused]
} }
"return a ServerFailure" in { "return a ServerFailure" in {
doReturn(DNS.Rcode.SERVFAIL).when(mockMessage).getRcode doReturn(DNS.Rcode.SERVFAIL).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[ServerFailure] toDnsResponse(mockMessage).left.value shouldBe a[ServerFailure]
} }
"return a NameExists" in { "return a NameExists" in {
doReturn(DNS.Rcode.YXDOMAIN).when(mockMessage).getRcode doReturn(DNS.Rcode.YXDOMAIN).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[NameExists] toDnsResponse(mockMessage).left.value shouldBe a[NameExists]
} }
"return a RecordSetExists" in { "return a RecordSetExists" in {
doReturn(DNS.Rcode.YXRRSET).when(mockMessage).getRcode doReturn(DNS.Rcode.YXRRSET).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[RecordSetExists] toDnsResponse(mockMessage).left.value shouldBe a[RecordSetExists]
} }
"return a UnrecognizedResponse" in { "return a UnrecognizedResponse" in {
doReturn(999).when(mockMessage).getRcode doReturn(999).when(mockMessage).getRcode
leftValue(toDnsResponse(mockMessage)) shouldBe a[UnrecognizedResponse] toDnsResponse(mockMessage).left.value shouldBe a[UnrecognizedResponse]
} }
} }
@ -572,7 +571,7 @@ class DnsConversionsSpec
"Converting to an update message" should { "Converting to an update message" should {
"work for an Add message" in { "work for an Add message" in {
val dnsMessage = rightValue(toAddRecordMessage(rrset(testDnsA), testZoneName)) val dnsMessage = toAddRecordMessage(rrset(testDnsA), testZoneName).right.value
val dnsRecord = dnsMessage.getSectionArray(DNS.Section.UPDATE)(0) val dnsRecord = dnsMessage.getSectionArray(DNS.Section.UPDATE)(0)
dnsRecord.getName.toString shouldBe "a-record." dnsRecord.getName.toString shouldBe "a-record."
dnsRecord.getTTL shouldBe testA.ttl dnsRecord.getTTL shouldBe testA.ttl
@ -585,7 +584,7 @@ class DnsConversionsSpec
} }
"work for an Update message" in { "work for an Update message" in {
val dnsMessage = val dnsMessage =
rightValue(toUpdateRecordMessage(rrset(testDnsA), rrset(testDnsAReplace), testZoneName)) toUpdateRecordMessage(rrset(testDnsA), rrset(testDnsAReplace), testZoneName).right.value
// Update record issues a replace, the first section is an EmptyRecord containing the name and type to replace // Update record issues a replace, the first section is an EmptyRecord containing the name and type to replace
val emptyRecord = dnsMessage.getSectionArray(DNS.Section.UPDATE)(0) val emptyRecord = dnsMessage.getSectionArray(DNS.Section.UPDATE)(0)
emptyRecord.getName.toString shouldBe "a-record-2." emptyRecord.getName.toString shouldBe "a-record-2."
@ -604,7 +603,7 @@ class DnsConversionsSpec
zoneRRset.getName.toString shouldBe "vinyldns." zoneRRset.getName.toString shouldBe "vinyldns."
} }
"work for a Delete message" in { "work for a Delete message" in {
val dnsMessage = rightValue(toDeleteRecordMessage(rrset(testDnsA), testZoneName)) val dnsMessage = toDeleteRecordMessage(rrset(testDnsA), testZoneName).right.value
val dnsRecord = dnsMessage.getSectionArray(DNS.Section.UPDATE)(0) val dnsRecord = dnsMessage.getSectionArray(DNS.Section.UPDATE)(0)
dnsRecord.getName.toString shouldBe "a-record." dnsRecord.getName.toString shouldBe "a-record."
@ -629,7 +628,7 @@ class DnsConversionsSpec
"convert zone name to @" in { "convert zone name to @" in {
val actual = toDnsRecords(testAt, testZoneName) val actual = toDnsRecords(testAt, testZoneName)
val omitFinalDot = false val omitFinalDot = false
rightValue(actual).head.getName.toString(omitFinalDot) shouldBe testZoneName actual.right.value.head.getName.toString(omitFinalDot) shouldBe testZoneName
} }
} }

View File

@ -24,10 +24,8 @@ import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.scalatest.BeforeAndAfterEach import org.scalatest.BeforeAndAfterEach
import vinyldns.api.Interfaces._
import vinyldns.api.ResultHelpers import vinyldns.api.ResultHelpers
import vinyldns.api.domain.access.AccessValidations import vinyldns.api.domain.access.AccessValidations
import vinyldns.api.domain.dns.DnsConnection
import vinyldns.api.domain.record.RecordSetHelpers._ import vinyldns.api.domain.record.RecordSetHelpers._
import vinyldns.api.domain.zone._ import vinyldns.api.domain.zone._
import vinyldns.api.route.{ListGlobalRecordSetsResponse, ListRecordSetsByZoneResponse} import vinyldns.api.route.{ListGlobalRecordSetsResponse, ListRecordSetsByZoneResponse}
@ -36,6 +34,7 @@ import vinyldns.core.TestRecordSetData._
import vinyldns.core.TestZoneData._ import vinyldns.core.TestZoneData._
import vinyldns.core.domain.HighValueDomainError import vinyldns.core.domain.HighValueDomainError
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.backend.{Backend, BackendResolver}
import vinyldns.core.domain.membership.{GroupRepository, ListUsersResults, UserRepository} import vinyldns.core.domain.membership.{GroupRepository, ListUsersResults, UserRepository}
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.domain.zone._ import vinyldns.core.domain.zone._
@ -55,12 +54,9 @@ class RecordSetServiceSpec
private val mockRecordChangeRepo = mock[RecordChangeRepository] private val mockRecordChangeRepo = mock[RecordChangeRepository]
private val mockUserRepo = mock[UserRepository] private val mockUserRepo = mock[UserRepository]
private val mockMessageQueue = mock[MessageQueue] private val mockMessageQueue = mock[MessageQueue]
private val zoneConnection = private val mockBackend =
ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1") mock[Backend]
private val configuredDnsConnections = private val mockBackendResolver = mock[BackendResolver]
ConfiguredDnsConnections(zoneConnection, zoneConnection, List())
private val mockDnsConnection =
mock[DnsConnection]
doReturn(IO.pure(Some(okZone))).when(mockZoneRepo).getZone(okZone.id) doReturn(IO.pure(Some(okZone))).when(mockZoneRepo).getZone(okZone.id)
doReturn(IO.pure(Some(zoneNotAuthorized))) doReturn(IO.pure(Some(zoneNotAuthorized)))
@ -70,6 +66,7 @@ class RecordSetServiceSpec
doReturn(IO.pure(Some(sharedZoneRecord.copy(status = RecordSetStatus.Active)))) doReturn(IO.pure(Some(sharedZoneRecord.copy(status = RecordSetStatus.Active))))
.when(mockRecordRepo) .when(mockRecordRepo)
.getRecordSet(sharedZoneRecord.id) .getRecordSet(sharedZoneRecord.id)
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val underTest = new RecordSetService( val underTest = new RecordSetService(
mockZoneRepo, mockZoneRepo,
@ -79,8 +76,7 @@ class RecordSetServiceSpec
mockUserRepo, mockUserRepo,
mockMessageQueue, mockMessageQueue,
new AccessValidations(), new AccessValidations(),
(_, _) => mockDnsConnection, mockBackendResolver,
configuredDnsConnections,
false false
) )
@ -92,8 +88,7 @@ class RecordSetServiceSpec
mockUserRepo, mockUserRepo,
mockMessageQueue, mockMessageQueue,
new AccessValidations(), new AccessValidations(),
(_, _) => mockDnsConnection, mockBackendResolver,
configuredDnsConnections,
true true
) )
@ -140,8 +135,8 @@ class RecordSetServiceSpec
.when(mockRecordRepo) .when(mockRecordRepo)
.getRecordSets(okZone.id, record.name, record.typ) .getRecordSets(okZone.id, record.name, record.typ)
doReturn(IO(List(aaaa)).toResult) doReturn(IO(List(aaaa)))
.when(mockDnsConnection) .when(mockBackend)
.resolve(record.name, okZone.name, record.typ) .resolve(record.name, okZone.name, record.typ)
val result = leftResultOf(underTest.addRecordSet(aaaa, okAuth).value) val result = leftResultOf(underTest.addRecordSet(aaaa, okAuth).value)
@ -298,8 +293,8 @@ class RecordSetServiceSpec
doReturn(IO.pure(List(record))) doReturn(IO.pure(List(record)))
.when(mockRecordRepo) .when(mockRecordRepo)
.getRecordSets(okZone.id, record.name, record.typ) .getRecordSets(okZone.id, record.name, record.typ)
doReturn(IO(List()).toResult) doReturn(IO(List()))
.when(mockDnsConnection) .when(mockBackend)
.resolve(record.name, okZone.name, record.typ) .resolve(record.name, okZone.name, record.typ)
doReturn(IO.pure(List())) doReturn(IO.pure(List()))
.when(mockRecordRepo) .when(mockRecordRepo)
@ -653,8 +648,8 @@ class RecordSetServiceSpec
doReturn(IO.pure(List(newRecord))) doReturn(IO.pure(List(newRecord)))
.when(mockRecordRepo) .when(mockRecordRepo)
.getRecordSetsByName(okZone.id, newRecord.name) .getRecordSetsByName(okZone.id, newRecord.name)
doReturn(IO(List()).toResult) doReturn(IO(List()))
.when(mockDnsConnection) .when(mockBackend)
.resolve(newRecord.name, okZone.name, newRecord.typ) .resolve(newRecord.name, okZone.name, newRecord.typ)
val result: RecordSetChange = rightResultOf( val result: RecordSetChange = rightResultOf(

View File

@ -23,14 +23,13 @@ import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.scalatest.BeforeAndAfterEach import org.scalatest.BeforeAndAfterEach
import vinyldns.api.Interfaces._
import vinyldns.api.domain.dns.DnsConnection
import vinyldns.api.domain.dns.DnsProtocol.TypeNotFound
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.api.ResultHelpers import vinyldns.api.ResultHelpers
import cats.effect._ import cats.effect._
import org.mockito.Matchers.any
import vinyldns.core.domain.Fqdn import vinyldns.core.domain.Fqdn
import vinyldns.core.domain.zone.{ConfiguredDnsConnections, DnsBackend, Zone, ZoneConnection} import vinyldns.core.domain.backend.{Backend, BackendResolver}
import vinyldns.core.domain.zone.{ConfiguredDnsConnections, LegacyDnsBackend, Zone, ZoneConnection}
import scala.concurrent.duration._ import scala.concurrent.duration._
@ -43,18 +42,12 @@ class ZoneConnectionValidatorSpec
with EitherMatchers with EitherMatchers
with EitherValues { with EitherValues {
private val mockDnsConnection = mock[DnsConnection] private val mockBackend = mock[Backend]
private val mockZoneView = mock[ZoneView] private val mockZoneView = mock[ZoneView]
private val mockBackendResolver = mock[BackendResolver]
override protected def beforeEach(): Unit = override protected def beforeEach(): Unit =
reset(mockDnsConnection, mockZoneView) reset(mockBackend, mockZoneView)
private def testDnsConnection(conn: ZoneConnection) =
if (conn.keyName == "error.") {
throw new RuntimeException("main connection failure!")
} else {
mockDnsConnection
}
private def testLoadDns(zone: Zone) = zone.name match { private def testLoadDns(zone: Zone) = zone.name match {
case "error." => IO.raiseError(new RuntimeException("transfer connection failure!")) case "error." => IO.raiseError(new RuntimeException("transfer connection failure!"))
@ -67,21 +60,14 @@ class ZoneConnectionValidatorSpec
IO.pure(mockZoneView) IO.pure(mockZoneView)
} }
private def testDefaultConnection: ZoneConnection =
ZoneConnection("name", "key-name", "key", "localhost:19001")
private def testConfiguredConnections: ConfiguredDnsConnections =
ConfiguredDnsConnections(testDefaultConnection, testDefaultConnection, List())
private def generateZoneView(zone: Zone, recordSets: RecordSet*): ZoneView = private def generateZoneView(zone: Zone, recordSets: RecordSet*): ZoneView =
ZoneView( ZoneView(
zone = zone, zone = zone,
recordSets = recordSets.toList recordSets = recordSets.toList
) )
class TestConnectionValidator() extends ZoneConnectionValidator(testConfiguredConnections) { class TestConnectionValidator() extends ZoneConnectionValidator(mockBackendResolver) {
override val opTimeout: FiniteDuration = 10.milliseconds override val opTimeout: FiniteDuration = 10.milliseconds
override def dnsConnection(conn: ZoneConnection): DnsConnection = testDnsConnection(conn)
override def loadDns(zone: Zone): IO[ZoneView] = testLoadDns(zone) override def loadDns(zone: Zone): IO[ZoneView] = testLoadDns(zone)
override def isValidBackendId(backendId: Option[String]): Either[Throwable, Unit] = override def isValidBackendId(backendId: Option[String]): Either[Throwable, Unit] =
Right(()) Right(())
@ -142,11 +128,9 @@ class ZoneConnectionValidatorSpec
List(NSData(Fqdn("sub.some.test.ns."))) List(NSData(Fqdn("sub.some.test.ns.")))
) )
private val mockRecordSet = mock[RecordSet]
val zc = ZoneConnection("zc.", "zc.", "zc", "10.1.1.1") val zc = ZoneConnection("zc.", "zc.", "zc", "10.1.1.1")
val transfer = ZoneConnection("transfer.", "transfer.", "transfer", "10.1.1.1") val transfer = ZoneConnection("transfer.", "transfer.", "transfer", "10.1.1.1")
val backend = DnsBackend( val backend = LegacyDnsBackend(
"some-backend-id", "some-backend-id",
zc.copy(name = "backend-conn"), zc.copy(name = "backend-conn"),
transfer.copy(name = "backend-transfer") transfer.copy(name = "backend-transfer")
@ -159,9 +143,8 @@ class ZoneConnectionValidatorSpec
doReturn(generateZoneView(testZone, successSoa, successNS, delegatedNS).recordSetsMap) doReturn(generateZoneView(testZone, successSoa, successNS, delegatedNS).recordSetsMap)
.when(mockZoneView) .when(mockZoneView)
.recordSetsMap .recordSetsMap
doReturn(List(successSoa).toResult) doReturn(IO.pure(true)).when(mockBackend).zoneExists(any[Zone])
.when(mockDnsConnection) doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
.resolve(testZone.name, testZone.name, RecordType.SOA)
val result = awaitResultOf(underTest.validateZoneConnections(testZone).value) val result = awaitResultOf(underTest.validateZoneConnections(testZone).value)
result should be(right) result should be(right)
@ -172,9 +155,8 @@ class ZoneConnectionValidatorSpec
doReturn(generateZoneView(testZone, successSoa, failureNs, delegatedNS).recordSetsMap) doReturn(generateZoneView(testZone, successSoa, failureNs, delegatedNS).recordSetsMap)
.when(mockZoneView) .when(mockZoneView)
.recordSetsMap .recordSetsMap
doReturn(List(successSoa).toResult) doReturn(IO.pure(true)).when(mockBackend).zoneExists(any[Zone])
.when(mockDnsConnection) doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
.resolve(testZone.name, testZone.name, RecordType.SOA)
val result = leftResultOf(underTest.validateZoneConnections(testZone).value) val result = leftResultOf(underTest.validateZoneConnections(testZone).value)
result shouldBe ZoneValidationFailed( result shouldBe ZoneValidationFailed(
@ -187,9 +169,11 @@ class ZoneConnectionValidatorSpec
"respond with a failure if no records are returned from the backend" in { "respond with a failure if no records are returned from the backend" in {
doReturn(testZone).when(mockZoneView).zone doReturn(testZone).when(mockZoneView).zone
doReturn(generateZoneView(testZone).recordSetsMap).when(mockZoneView).recordSetsMap doReturn(generateZoneView(testZone).recordSetsMap).when(mockZoneView).recordSetsMap
doReturn(List.empty[RecordSet].toResult) doReturn(IO.pure(List.empty[RecordSet]))
.when(mockDnsConnection) .when(mockBackend)
.resolve(testZone.name, testZone.name, RecordType.SOA) .resolve(testZone.name, testZone.name, RecordType.SOA)
doReturn(IO.pure(true)).when(mockBackend).zoneExists(any[Zone])
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val result = leftResultOf(underTest.validateZoneConnections(testZone).value) val result = leftResultOf(underTest.validateZoneConnections(testZone).value)
result shouldBe a[ZoneValidationFailed] result shouldBe a[ZoneValidationFailed]
@ -200,28 +184,20 @@ class ZoneConnectionValidatorSpec
) )
} }
"respond with a failure if any failure is returned from the backend" in {
doReturn(result(TypeNotFound("fail")))
.when(mockDnsConnection)
.resolve(testZone.name, testZone.name, RecordType.SOA)
val error = leftResultOf(underTest.validateZoneConnections(testZone).value)
error shouldBe ConnectionFailed(testZone, s"Unable to connect to zone: fail")
}
"respond with a failure if connection cant be made" in { "respond with a failure if connection cant be made" in {
val badZone = Zone( val badZone = Zone(
"vinyldns.", "error.",
"test@test.com", "test@test.com",
connection = connection =
Some(ZoneConnection("error.", "error.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1")), Some(ZoneConnection("error.", "error.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1")),
transferConnection = transferConnection =
Some(ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1")) Some(ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1"))
) )
doReturn(IO.pure(true)).when(mockBackend).zoneExists(any[Zone])
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val result = leftResultOf(underTest.validateZoneConnections(badZone).value) val result = leftResultOf(underTest.validateZoneConnections(badZone).value)
result shouldBe a[ConnectionFailed] result shouldBe a[ConnectionFailed]
result.getMessage should include("main connection failure!")
} }
"respond with a failure if loadDns throws an error" in { "respond with a failure if loadDns throws an error" in {
@ -234,9 +210,8 @@ class ZoneConnectionValidatorSpec
Some(ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1")) Some(ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1"))
) )
doReturn(List(mockRecordSet).toResult) doReturn(IO.pure(true)).when(mockBackend).zoneExists(any[Zone])
.when(mockDnsConnection) doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
.resolve(badZone.name, badZone.name, RecordType.SOA)
val result = leftResultOf(underTest.validateZoneConnections(badZone).value) val result = leftResultOf(underTest.validateZoneConnections(badZone).value)
result shouldBe a[ConnectionFailed] result shouldBe a[ConnectionFailed]
@ -244,11 +219,11 @@ class ZoneConnectionValidatorSpec
} }
"isValidBackendId" should { "isValidBackendId" should {
val backend = DnsBackend("some-test-backend", testDefaultConnection, testDefaultConnection) doReturn(true).when(mockBackendResolver).isRegistered("some-test-backend")
doReturn(false).when(mockBackendResolver).isRegistered("bad")
val underTest = val underTest =
new ZoneConnectionValidator( new ZoneConnectionValidator(mockBackendResolver)
ConfiguredDnsConnections(testDefaultConnection, testDefaultConnection, List(backend))
)
"return success if the backendId exists" in { "return success if the backendId exists" in {
underTest.isValidBackendId(Some("some-test-backend")) shouldBe right underTest.isValidBackendId(Some("some-test-backend")) shouldBe right
@ -260,43 +235,5 @@ class ZoneConnectionValidatorSpec
underTest.isValidBackendId(Some("bad")) shouldBe left underTest.isValidBackendId(Some("bad")) shouldBe left
} }
} }
"getZoneConnection" should {
"get the specified zone connection if provided" in {
// both backendId and connection info specified; prefer connection info
val zone = testZone.copy(backendId = Some("some-backend-id"))
ZoneConnectionValidator.getZoneConnection(zone, connections) shouldBe zone.connection.get
}
"get a zone connection by backendID" in {
val zone = Zone("name.", "email", backendId = Some("some-backend-id"))
ZoneConnectionValidator.getZoneConnection(zone, connections) shouldBe backend.zoneConnection
}
"fall to default without connection info" in {
val zone = Zone("name.", "email")
ZoneConnectionValidator.getZoneConnection(zone, connections) shouldBe zc
}
"fall to default with an invalid backendId" in {
val zone = Zone("name.", "email", backendId = Some("bad-id"))
ZoneConnectionValidator.getZoneConnection(zone, connections) shouldBe zc
}
}
"getTransferConnection" should {
"get the specified transfer connection if provided" in {
// both backendId and connection info specified; prefer connection info
val zone = testZone.copy(backendId = Some("some-backend-id"))
ZoneConnectionValidator.getTransferConnection(zone, connections) shouldBe zone.transferConnection.get
}
"get a transfer connection by backendID" in {
val zone = Zone("name.", "email", backendId = Some("some-backend-id"))
ZoneConnectionValidator.getTransferConnection(zone, connections) shouldBe backend.transferConnection
}
"fall to default without connection info" in {
val zone = Zone("name.", "email")
ZoneConnectionValidator.getTransferConnection(zone, connections) shouldBe transfer
}
"fall to default with an invalid backendId" in {
val zone = Zone("name.", "email", backendId = Some("bad-id"))
ZoneConnectionValidator.getTransferConnection(zone, connections) shouldBe transfer
}
}
} }
} }

View File

@ -34,6 +34,7 @@ import vinyldns.core.domain.zone._
import vinyldns.core.queue.MessageQueue import vinyldns.core.queue.MessageQueue
import vinyldns.core.TestMembershipData._ import vinyldns.core.TestMembershipData._
import vinyldns.core.TestZoneData._ import vinyldns.core.TestZoneData._
import vinyldns.core.domain.backend.BackendResolver
import scala.concurrent.duration._ import scala.concurrent.duration._
@ -50,6 +51,7 @@ class ZoneServiceSpec
private val mockUserRepo = mock[UserRepository] private val mockUserRepo = mock[UserRepository]
private val mockZoneChangeRepo = mock[ZoneChangeRepository] private val mockZoneChangeRepo = mock[ZoneChangeRepository]
private val mockMessageQueue = mock[MessageQueue] private val mockMessageQueue = mock[MessageQueue]
private val mockBackendResolver = mock[BackendResolver]
private val badConnection = ZoneConnection("bad", "bad", "bad", "bad") private val badConnection = ZoneConnection("bad", "bad", "bad", "bad")
private val abcZoneSummary = ZoneSummaryInfo(abcZone, abcGroup.name, AccessLevel.Delete) private val abcZoneSummary = ZoneSummaryInfo(abcZone, abcGroup.name, AccessLevel.Delete)
private val xyzZoneSummary = ZoneSummaryInfo(xyzZone, xyzGroup.name, AccessLevel.NoAccess) private val xyzZoneSummary = ZoneSummaryInfo(xyzZone, xyzGroup.name, AccessLevel.NoAccess)
@ -76,7 +78,8 @@ class ZoneServiceSpec
TestConnectionValidator, TestConnectionValidator,
mockMessageQueue, mockMessageQueue,
new ZoneValidations(1000), new ZoneValidations(1000),
new AccessValidations() new AccessValidations(),
mockBackendResolver
) )
private val createZoneAuthorized = CreateZoneInput( private val createZoneAuthorized = CreateZoneInput(

View File

@ -26,25 +26,29 @@ import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.xbill.DNS import org.xbill.DNS
import org.xbill.DNS.{Name, ZoneTransferIn} import org.xbill.DNS.Name
import vinyldns.api.domain.dns.DnsConversions
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import scala.collection.JavaConverters._
import scala.collection.mutable import scala.collection.mutable
import cats.effect._ import cats.effect._
import vinyldns.api.backend.dns.DnsConversions
import vinyldns.core.domain.Fqdn import vinyldns.core.domain.Fqdn
import vinyldns.core.domain.backend.{Backend, BackendResolver}
import vinyldns.core.domain.record.NameSort.NameSort import vinyldns.core.domain.record.NameSort.NameSort
import vinyldns.core.domain.record.RecordType.RecordType import vinyldns.core.domain.record.RecordType.RecordType
import vinyldns.core.domain.zone.{Zone, ZoneConnection, ZoneStatus} import vinyldns.core.domain.zone.{Zone, ZoneConnection, ZoneStatus}
class ZoneViewLoaderSpec extends AnyWordSpec with Matchers with MockitoSugar with DnsConversions { class ZoneViewLoaderSpec extends AnyWordSpec with Matchers with MockitoSugar with DnsConversions {
val testZoneName = "vinyldns."
val testZoneConnection: Option[ZoneConnection] = Some( private val testZoneName = "vinyldns."
private val testZoneConnection: Option[ZoneConnection] = Some(
ZoneConnection(testZoneName, testZoneName, "nzisn+4G2ldMn0q1CV3vsg==", "127.0.0.1:19001") ZoneConnection(testZoneName, testZoneName, "nzisn+4G2ldMn0q1CV3vsg==", "127.0.0.1:19001")
) )
private val mockBackendResolver = mock[BackendResolver]
private val mockBackend = mock[Backend]
private val testZone = Zone("vinyldns.", "test@test.com") private val testZone = Zone("vinyldns.", "test@test.com")
private val records = List( private val records = List(
RecordSet( RecordSet(
@ -121,8 +125,6 @@ class ZoneViewLoaderSpec extends AnyWordSpec with Matchers with MockitoSugar wit
transferConnection = testZoneConnection transferConnection = testZoneConnection
) )
val mockTransfer = mock[ZoneTransferIn]
val expectedRecords = List( val expectedRecords = List(
RecordSet( RecordSet(
zoneId = testZone.id, zoneId = testZone.id,
@ -196,12 +198,10 @@ class ZoneViewLoaderSpec extends AnyWordSpec with Matchers with MockitoSugar wit
) )
) )
doReturn(dnsRecords.asJava).when(mockTransfer).getAXFR doReturn(IO.pure(expectedRecords)).when(mockBackend).loadZone(any[Zone], any[Int])
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val mockTransferFunc = mock[Zone => ZoneTransferIn] val underTest = DnsZoneViewLoader(testZone, mockBackend, 1000)
doReturn(mockTransfer).when(mockTransferFunc).apply(testZone)
val underTest = DnsZoneViewLoader(testZone, mockTransferFunc)
val actual = underTest.load().unsafeToFuture() val actual = underTest.load().unsafeToFuture()
@ -235,8 +235,6 @@ class ZoneViewLoaderSpec extends AnyWordSpec with Matchers with MockitoSugar wit
transferConnection = testZoneConnection transferConnection = testZoneConnection
) )
val mockTransfer = mock[ZoneTransferIn]
val expectedRecords = List( val expectedRecords = List(
RecordSet( RecordSet(
zoneId = testZone.id, zoneId = testZone.id,
@ -346,12 +344,10 @@ class ZoneViewLoaderSpec extends AnyWordSpec with Matchers with MockitoSugar wit
) )
) )
doReturn(dnsRecords.asJava).when(mockTransfer).getAXFR doReturn(IO.pure(expectedRecords)).when(mockBackend).loadZone(any[Zone], any[Int])
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val mockTransferFunc = mock[Zone => ZoneTransferIn] val underTest = DnsZoneViewLoader(testZone, mockBackend, 1000)
doReturn(mockTransfer).when(mockTransferFunc).apply(testZone)
val underTest = DnsZoneViewLoader(testZone, mockTransferFunc)
val actual = underTest.load().unsafeToFuture() val actual = underTest.load().unsafeToFuture()

View File

@ -24,13 +24,11 @@ import org.mockito.Mockito._
import org.scalatestplus.mockito.MockitoSugar import org.scalatestplus.mockito.MockitoSugar
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.scalatest.BeforeAndAfterEach import org.scalatest.{BeforeAndAfterEach, EitherValues}
import org.xbill.DNS import vinyldns.api.backend.dns.DnsProtocol.{NotAuthorized, TryAgain}
import vinyldns.api.domain.dns.DnsConnection
import vinyldns.api.domain.dns.DnsProtocol.{NoError, NotAuthorized, Refused, TryAgain}
import vinyldns.api.engine.RecordSetChangeHandler.{AlreadyApplied, ReadyToApply, Requeue} import vinyldns.api.engine.RecordSetChangeHandler.{AlreadyApplied, ReadyToApply, Requeue}
import vinyldns.api.repository.InMemoryBatchChangeRepository import vinyldns.api.repository.InMemoryBatchChangeRepository
import vinyldns.api.{CatsHelpers, Interfaces} import vinyldns.api.CatsHelpers
import vinyldns.core.domain.batch.{ import vinyldns.core.domain.batch.{
BatchChange, BatchChange,
BatchChangeApprovalStatus, BatchChangeApprovalStatus,
@ -43,19 +41,20 @@ import vinyldns.core.TestRecordSetData._
import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext
import cats.effect.ContextShift import cats.effect.ContextShift
import vinyldns.core.domain.backend.{Backend, BackendResponse}
class RecordSetChangeHandlerSpec class RecordSetChangeHandlerSpec
extends AnyWordSpec extends AnyWordSpec
with Matchers with Matchers
with MockitoSugar with MockitoSugar
with BeforeAndAfterEach with BeforeAndAfterEach
with CatsHelpers { with CatsHelpers
with EitherValues {
private implicit val timer: Timer[IO] = IO.timer(ExecutionContext.global) private implicit val timer: Timer[IO] = IO.timer(ExecutionContext.global)
private val mockConn = mock[DnsConnection] private val mockBackend = mock[Backend]
private val mockRsRepo = mock[RecordSetRepository] private val mockRsRepo = mock[RecordSetRepository]
private val mockChangeRepo = mock[RecordChangeRepository] private val mockChangeRepo = mock[RecordChangeRepository]
private val mockDnsMessage = mock[DNS.Message]
private val rsRepoCaptor = ArgumentCaptor.forClass(classOf[ChangeSet]) private val rsRepoCaptor = ArgumentCaptor.forClass(classOf[ChangeSet])
private val changeRepoCaptor = ArgumentCaptor.forClass(classOf[ChangeSet]) private val changeRepoCaptor = ArgumentCaptor.forClass(classOf[ChangeSet])
@ -112,7 +111,7 @@ class RecordSetChangeHandlerSpec
RecordSetChangeHandler(mockRsRepo, mockChangeRepo, batchRepo) RecordSetChangeHandler(mockRsRepo, mockChangeRepo, batchRepo)
override protected def beforeEach(): Unit = { override protected def beforeEach(): Unit = {
reset(mockConn, mockRsRepo, mockChangeRepo) reset(mockBackend, mockRsRepo, mockChangeRepo)
batchRepo.clear() batchRepo.clear()
// seed the linked batch change in the DB // seed the linked batch change in the DB
@ -126,14 +125,14 @@ class RecordSetChangeHandlerSpec
"Handling Pending Changes" should { "Handling Pending Changes" should {
"complete the change successfully if already applied" in { "complete the change successfully if already applied" in {
doReturn(Interfaces.result(List(rs))) doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List(rs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List(rs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -162,17 +161,17 @@ class RecordSetChangeHandlerSpec
"apply the change if not yet applied" in { "apply the change if not yet applied" in {
// The second return is for verify // The second return is for verify
doReturn(Interfaces.result(List())) doReturn(IO.pure(List()))
.doReturn(Interfaces.result(List(rs))) .doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(NoError(mockDnsMessage))).when(mockConn).applyChange(rsChange) doReturn(IO.pure(BackendResponse.NoError("test"))).when(mockBackend).applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -188,8 +187,8 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete
// make sure the record was applied and then verified // make sure the record was applied and then verified
verify(mockConn).applyChange(rsChange) verify(mockBackend).applyChange(rsChange)
verify(mockConn, times(2)).resolve(rs.name, rsChange.zone.name, rs.typ) verify(mockBackend, times(2)).resolve(rs.name, rsChange.zone.name, rs.typ)
val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id)) val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id))
val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch => val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch =>
@ -205,19 +204,19 @@ class RecordSetChangeHandlerSpec
"bypass verify and fail if the dns update fails" in { "bypass verify and fail if the dns update fails" in {
// The second return is for verify // The second return is for verify
doReturn(Interfaces.result(List())) doReturn(IO.pure(List()))
.doReturn(Interfaces.result(List(rs))) .doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(Left(NotAuthorized("dns failure")))) doReturn(IO.raiseError(NotAuthorized("dns failure")))
.when(mockConn) .when(mockBackend)
.applyChange(rsChange) .applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -234,10 +233,10 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed
// make sure the record was applied // make sure the record was applied
verify(mockConn).applyChange(rsChange) verify(mockBackend).applyChange(rsChange)
// make sure we only called resolve once when validating, ensures that verify was not called // make sure we only called resolve once when validating, ensures that verify was not called
verify(mockConn, times(1)).resolve(rs.name, rsChange.zone.name, rs.typ) verify(mockBackend, times(1)).resolve(rs.name, rsChange.zone.name, rs.typ)
val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id)) val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id))
val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch => val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch =>
@ -253,17 +252,17 @@ class RecordSetChangeHandlerSpec
"fail the change in verify if verify errors" in { "fail the change in verify if verify errors" in {
// All returns after first are for verify. Retry 2 times and succeed // All returns after first are for verify. Retry 2 times and succeed
doReturn(Interfaces.result(List())) doReturn(IO.pure(List()))
.doReturn(Interfaces.result(Left(NotAuthorized("dns-fail")))) .doReturn(IO.raiseError(NotAuthorized("dns-fail")))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(NoError(mockDnsMessage))).when(mockConn).applyChange(rsChange) doReturn(IO.pure(BackendResponse.NoError("test"))).when(mockBackend).applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -279,10 +278,10 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed
// make sure the record was applied and then verified // make sure the record was applied and then verified
verify(mockConn).applyChange(rsChange) verify(mockBackend).applyChange(rsChange)
// we will retry the verify 3 times based on the mock setup // we will retry the verify 3 times based on the mock setup
verify(mockConn, times(2)).resolve(rs.name, rsChange.zone.name, rs.typ) verify(mockBackend, times(2)).resolve(rs.name, rsChange.zone.name, rs.typ)
val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id)) val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id))
val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch => val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch =>
@ -297,30 +296,30 @@ class RecordSetChangeHandlerSpec
} }
"requeue the change in verify if permissible errors" in { "requeue the change in verify if permissible errors" in {
doReturn(Interfaces.result(List())) doReturn(IO.pure(List()))
.doReturn(Interfaces.result(Left(TryAgain("dns-fail")))) .doReturn(IO.raiseError(TryAgain("dns-fail")))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(NoError(mockDnsMessage))).when(mockConn).applyChange(rsChange) doReturn(IO.pure(BackendResponse.NoError("test"))).when(mockBackend).applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
a[Requeue] shouldBe thrownBy(test.unsafeRunSync()) a[Requeue] shouldBe thrownBy(test.unsafeRunSync())
} }
"fail the change if validating fails with an error" in { "fail the change if validating fails with an error" in {
// Stage an error on the first resolve, which will cause validate to fail // Stage an error on the first resolve, which will cause validate to fail
doReturn(Interfaces.result(Left(NotAuthorized("dns-failure")))) doReturn(IO.raiseError(NotAuthorized("dns-failure")))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -336,8 +335,8 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed
// we failed in validation, so we should never issue a dns update // we failed in validation, so we should never issue a dns update
verify(mockConn, never()).applyChange(rsChange) verify(mockBackend, never()).applyChange(rsChange)
verify(mockConn, times(1)).resolve(rs.name, rsChange.zone.name, rs.typ) verify(mockBackend, times(1)).resolve(rs.name, rsChange.zone.name, rs.typ)
val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id)) val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id))
val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch => val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch =>
@ -352,17 +351,17 @@ class RecordSetChangeHandlerSpec
} }
"fail the change if applying fails with an error" in { "fail the change if applying fails with an error" in {
doReturn(Interfaces.result(List())) doReturn(IO.pure(List()))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(Left(NotAuthorized("dns-fail")))) doReturn(IO.raiseError(NotAuthorized("dns-fail")))
.when(mockConn) .when(mockBackend)
.applyChange(rsChange) .applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -378,8 +377,8 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed savedCs.changes.head.status shouldBe RecordSetChangeStatus.Failed
// we failed in apply, we should only resolve once // we failed in apply, we should only resolve once
verify(mockConn, times(1)).applyChange(rsChange) verify(mockBackend, times(1)).applyChange(rsChange)
verify(mockConn, times(1)).resolve(rs.name, rsChange.zone.name, rs.typ) verify(mockBackend, times(1)).resolve(rs.name, rsChange.zone.name, rs.typ)
val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id)) val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id))
val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch => val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch =>
@ -393,21 +392,6 @@ class RecordSetChangeHandlerSpec
batchChangeUpdates.get.changes shouldBe scExpected batchChangeUpdates.get.changes shouldBe scExpected
} }
"requeue the change in apply if permissible errors" in {
doReturn(Interfaces.result(List()))
.when(mockConn)
.resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(Left(Refused("dns-fail"))))
.when(mockConn)
.applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, rsChange)
a[Requeue] shouldBe thrownBy(test.unsafeRunSync())
}
"bypass the validate and verify steps if a wildcard record exists" in { "bypass the validate and verify steps if a wildcard record exists" in {
// Return a wildcard record // Return a wildcard record
doReturn(IO.pure(List(rsChange.recordSet))) doReturn(IO.pure(List(rsChange.recordSet)))
@ -415,18 +399,18 @@ class RecordSetChangeHandlerSpec
.getRecordSets(anyString, anyString, any(classOf[RecordType])) .getRecordSets(anyString, anyString, any(classOf[RecordType]))
// The second return is for verify // The second return is for verify
doReturn(Interfaces.result(List())) doReturn(IO.pure(List()))
.doReturn(Interfaces.result(List(rs))) .doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(Right(NoError(mockDnsMessage)))) doReturn(IO.pure(BackendResponse.NoError("test")))
.when(mockConn) .when(mockBackend)
.applyChange(rsChange) .applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
val res = test.unsafeRunSync() val res = test.unsafeRunSync()
res.status shouldBe RecordSetChangeStatus.Complete res.status shouldBe RecordSetChangeStatus.Complete
@ -445,10 +429,10 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete
// make sure the record was applied // make sure the record was applied
verify(mockConn).applyChange(rsChange) verify(mockBackend).applyChange(rsChange)
// make sure we never called resolve, as we skip validate step and verify // make sure we never called resolve, as we skip validate step and verify
verify(mockConn, never).resolve(rs.name, rsChange.zone.name, rs.typ) verify(mockBackend, never).resolve(rs.name, rsChange.zone.name, rs.typ)
val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id)) val batchChangeUpdates = await(batchRepo.getBatchChange(batchChange.id))
val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch => val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch =>
@ -474,18 +458,18 @@ class RecordSetChangeHandlerSpec
.getRecordSets(rsChange.recordSet.zoneId, "*", RecordType.CNAME) .getRecordSets(rsChange.recordSet.zoneId, "*", RecordType.CNAME)
// The second return is for verify // The second return is for verify
doReturn(Interfaces.result(List())) doReturn(IO.pure(List()))
.doReturn(Interfaces.result(List(rs))) .doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(Interfaces.result(Right(NoError(mockDnsMessage)))) doReturn(IO.pure(BackendResponse.NoError("test")))
.when(mockConn) .when(mockBackend)
.applyChange(rsChange) .applyChange(rsChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
val test = underTest.apply(mockConn, rsChange) val test = underTest.apply(mockBackend, rsChange)
val res = test.unsafeRunSync() val res = test.unsafeRunSync()
res.status shouldBe RecordSetChangeStatus.Complete res.status shouldBe RecordSetChangeStatus.Complete
@ -504,10 +488,10 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete
// make sure the record was applied // make sure the record was applied
verify(mockConn).applyChange(rsChange) verify(mockBackend).applyChange(rsChange)
// make sure we never called resolve, as we skip validate step and verify // make sure we never called resolve, as we skip validate step and verify
verify(mockConn, never).resolve(rs.name, rsChange.zone.name, rs.typ) verify(mockBackend, never).resolve(rs.name, rsChange.zone.name, rs.typ)
val batchChangeUpdates = batchRepo.getBatchChange(batchChange.id).unsafeRunSync() val batchChangeUpdates = batchRepo.getBatchChange(batchChange.id).unsafeRunSync()
val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch => val updatedSingleChanges = completeCreateAAAASingleChanges.map { ch =>
@ -532,18 +516,18 @@ class RecordSetChangeHandlerSpec
.getRecordSets(anyString, anyString, any(classOf[RecordType])) .getRecordSets(anyString, anyString, any(classOf[RecordType]))
// The second return is for verify // The second return is for verify
doReturn(Interfaces.result(Right(List()))) doReturn(IO.pure(List()))
.doReturn(Interfaces.result(Right(List(rsNs)))) .doReturn(IO.pure(List(rsNs)))
.when(mockConn) .when(mockBackend)
.resolve(rsNs.name, rsChangeNs.zone.name, rsNs.typ) .resolve(rsNs.name, rsChangeNs.zone.name, rsNs.typ)
doReturn(Interfaces.result(Right(NoError(mockDnsMessage)))) doReturn(IO.pure(BackendResponse.NoError("test")))
.when(mockConn) .when(mockBackend)
.applyChange(rsChangeNs) .applyChange(rsChangeNs)
doReturn(IO.pure(csNs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(csNs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(csNs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(csNs)).when(mockRsRepo).apply(any[ChangeSet])
val test = underTest.apply(mockConn, rsChangeNs) val test = underTest.apply(mockBackend, rsChangeNs)
val res = test.unsafeRunSync() val res = test.unsafeRunSync()
res.status shouldBe RecordSetChangeStatus.Complete res.status shouldBe RecordSetChangeStatus.Complete
@ -562,10 +546,10 @@ class RecordSetChangeHandlerSpec
savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete savedCs.changes.head.status shouldBe RecordSetChangeStatus.Complete
// make sure the record was applied // make sure the record was applied
verify(mockConn).applyChange(rsChangeNs) verify(mockBackend).applyChange(rsChangeNs)
// make sure we never called resolve, as we skip validate step and verify // make sure we never called resolve, as we skip validate step and verify
verify(mockConn, never).resolve(rsNs.name, rsChangeNs.zone.name, rsNs.typ) verify(mockBackend, never).resolve(rsNs.name, rsChangeNs.zone.name, rsNs.typ)
} }
"complete an update successfully if the requested record set change matches the DNS backend" in { "complete an update successfully if the requested record set change matches the DNS backend" in {
@ -573,11 +557,11 @@ class RecordSetChangeHandlerSpec
changeType = RecordSetChangeType.Update, changeType = RecordSetChangeType.Update,
updates = Some(rsChange.recordSet.copy(ttl = 87)) updates = Some(rsChange.recordSet.copy(ttl = 87))
) )
doReturn(Interfaces.result(Right(List(updateChange.recordSet)))) doReturn(IO.pure(List(updateChange.recordSet)))
.when(mockConn) .when(mockBackend)
.resolve(rsChange.recordSet.name, rsChange.zone.name, rsChange.recordSet.typ) .resolve(rsChange.recordSet.name, rsChange.zone.name, rsChange.recordSet.typ)
doReturn(Interfaces.result(Right(NoError(mockDnsMessage)))) doReturn(IO.pure(BackendResponse.NoError("test")))
.when(mockConn) .when(mockBackend)
.applyChange(updateChange) .applyChange(updateChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
@ -585,7 +569,7 @@ class RecordSetChangeHandlerSpec
.when(mockRsRepo) .when(mockRsRepo)
.getRecordSetsByName(cs.zoneId, rs.name) .getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, updateChange) val test = underTest.apply(mockBackend, updateChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -618,17 +602,17 @@ class RecordSetChangeHandlerSpec
updates = Some(rsChange.recordSet.copy(ttl = 87)) updates = Some(rsChange.recordSet.copy(ttl = 87))
) )
val dnsBackendRs = updateChange.recordSet.copy(ttl = 30) val dnsBackendRs = updateChange.recordSet.copy(ttl = 30)
doReturn(Interfaces.result(Right(List(dnsBackendRs)))) doReturn(IO.pure(List(dnsBackendRs)))
.when(mockConn) .when(mockBackend)
.resolve(rsChange.recordSet.name, rsChange.zone.name, rsChange.recordSet.typ) .resolve(rsChange.recordSet.name, rsChange.zone.name, rsChange.recordSet.typ)
doReturn(Interfaces.result(Right(NoError(mockDnsMessage)))) doReturn(IO.pure(BackendResponse.NoError("test")))
.when(mockConn) .when(mockBackend)
.applyChange(updateChange) .applyChange(updateChange)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockRsRepo).apply(any[ChangeSet])
doReturn(IO.pure(List(dnsBackendRs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List(dnsBackendRs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val test = underTest.apply(mockConn, updateChange) val test = underTest.apply(mockBackend, updateChange)
test.unsafeRunSync() test.unsafeRunSync()
verify(mockRsRepo).apply(rsRepoCaptor.capture()) verify(mockRsRepo).apply(rsRepoCaptor.capture())
@ -654,8 +638,8 @@ class RecordSetChangeHandlerSpec
"getProcessingStatus for Create" should { "getProcessingStatus for Create" should {
"return ReadyToApply if there are no records in the DNS backend" in { "return ReadyToApply if there are no records in the DNS backend" in {
doReturn(Interfaces.result(Right(List()))) doReturn(IO.pure(List()))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
@ -663,7 +647,7 @@ class RecordSetChangeHandlerSpec
RecordSetChangeHandler RecordSetChangeHandler
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange, rsChange,
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -673,8 +657,8 @@ class RecordSetChangeHandlerSpec
} }
"return AlreadyApplied if the change already exists in the DNS backend" in { "return AlreadyApplied if the change already exists in the DNS backend" in {
doReturn(Interfaces.result(Right(List(rs)))) doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(List(rs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List(rs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
@ -682,7 +666,7 @@ class RecordSetChangeHandlerSpec
RecordSetChangeHandler RecordSetChangeHandler
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange, rsChange,
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -692,8 +676,8 @@ class RecordSetChangeHandlerSpec
} }
"remove record from database for Add if record does not exist in DNS backend" in { "remove record from database for Add if record does not exist in DNS backend" in {
doReturn(Interfaces.result(Right(List()))) doReturn(IO.pure(List()))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
@ -704,7 +688,7 @@ class RecordSetChangeHandlerSpec
RecordSetChangeHandler RecordSetChangeHandler
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange, rsChange,
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -723,8 +707,8 @@ class RecordSetChangeHandlerSpec
val storedRs = rs.copy(ttl = 300) val storedRs = rs.copy(ttl = 300)
val syncedRsChange = val syncedRsChange =
rsChange.copy(changeType = RecordSetChangeType.Update, updates = Some(storedRs)) rsChange.copy(changeType = RecordSetChangeType.Update, updates = Some(storedRs))
doReturn(Interfaces.result(Right(List(syncedRsChange.updates.get)))) doReturn(IO.pure(List(syncedRsChange.updates.get)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(List(storedRs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List(storedRs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
@ -732,7 +716,7 @@ class RecordSetChangeHandlerSpec
RecordSetChangeHandler RecordSetChangeHandler
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
syncedRsChange, syncedRsChange,
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -742,8 +726,8 @@ class RecordSetChangeHandlerSpec
} }
"return ReadyToApply if current record set doesn't match DNS backend and DNS backend has no records" in { "return ReadyToApply if current record set doesn't match DNS backend and DNS backend has no records" in {
doReturn(Interfaces.result(Right(List()))) doReturn(IO.pure(List()))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
@ -751,7 +735,7 @@ class RecordSetChangeHandlerSpec
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange rsChange
.copy(changeType = RecordSetChangeType.Update, updates = Some(rs.copy(ttl = 300))), .copy(changeType = RecordSetChangeType.Update, updates = Some(rs.copy(ttl = 300))),
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -761,8 +745,8 @@ class RecordSetChangeHandlerSpec
} }
"return AlreadyApplied if the change already exists in the DNS backend" in { "return AlreadyApplied if the change already exists in the DNS backend" in {
doReturn(Interfaces.result(Right(List(rsChange.recordSet)))) doReturn(IO.pure(List(rsChange.recordSet)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(List(rsChange.recordSet))) doReturn(IO.pure(List(rsChange.recordSet)))
.when(mockRsRepo) .when(mockRsRepo)
@ -771,7 +755,7 @@ class RecordSetChangeHandlerSpec
val processorStatus = RecordSetChangeHandler val processorStatus = RecordSetChangeHandler
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange.copy(changeType = RecordSetChangeType.Update), rsChange.copy(changeType = RecordSetChangeType.Update),
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -781,8 +765,8 @@ class RecordSetChangeHandlerSpec
} }
"sync in the DNS backend for update if record does not exist in database" in { "sync in the DNS backend for update if record does not exist in database" in {
doReturn(Interfaces.result(Right(List(rs.copy(ttl = 100))))) doReturn(IO.pure(List(rs.copy(ttl = 100))))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
@ -796,7 +780,7 @@ class RecordSetChangeHandlerSpec
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange rsChange
.copy(changeType = RecordSetChangeType.Update, updates = Some(rs.copy(ttl = 100))), .copy(changeType = RecordSetChangeType.Update, updates = Some(rs.copy(ttl = 100))),
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -812,15 +796,15 @@ class RecordSetChangeHandlerSpec
"getProcessingStatus for Delete" should { "getProcessingStatus for Delete" should {
"return ReadyToApply if there are records in the DNS backend" in { "return ReadyToApply if there are records in the DNS backend" in {
doReturn(Interfaces.result(Right(List(rs)))) doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(List(rs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List(rs))).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val processorStatus = RecordSetChangeHandler val processorStatus = RecordSetChangeHandler
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange.copy(changeType = RecordSetChangeType.Delete), rsChange.copy(changeType = RecordSetChangeType.Delete),
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -830,15 +814,15 @@ class RecordSetChangeHandlerSpec
} }
"return AlreadyApplied if there are no records in the DNS backend" in { "return AlreadyApplied if there are no records in the DNS backend" in {
doReturn(Interfaces.result(Right(List()))) doReturn(IO.pure(List()))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name) doReturn(IO.pure(List.empty)).when(mockRsRepo).getRecordSetsByName(cs.zoneId, rs.name)
val processorStatus = RecordSetChangeHandler val processorStatus = RecordSetChangeHandler
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange.copy(changeType = RecordSetChangeType.Delete), rsChange.copy(changeType = RecordSetChangeType.Delete),
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true
@ -848,8 +832,8 @@ class RecordSetChangeHandlerSpec
} }
"sync in the DNS backend for Delete change if record exists" in { "sync in the DNS backend for Delete change if record exists" in {
doReturn(Interfaces.result(Right(List(rs)))) doReturn(IO.pure(List(rs)))
.when(mockConn) .when(mockBackend)
.resolve(rs.name, rsChange.zone.name, rs.typ) .resolve(rs.name, rsChange.zone.name, rs.typ)
doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet]) doReturn(IO.pure(cs)).when(mockChangeRepo).save(any[ChangeSet])
@ -863,7 +847,7 @@ class RecordSetChangeHandlerSpec
.syncAndGetProcessingStatusFromDnsBackend( .syncAndGetProcessingStatusFromDnsBackend(
rsChange rsChange
.copy(changeType = RecordSetChangeType.Delete), .copy(changeType = RecordSetChangeType.Delete),
mockConn, mockBackend,
mockRsRepo, mockRsRepo,
mockChangeRepo, mockChangeRepo,
true true

View File

@ -29,6 +29,7 @@ import vinyldns.api.VinylDNSTestHelpers
import vinyldns.api.domain.record.RecordSetChangeGenerator import vinyldns.api.domain.record.RecordSetChangeGenerator
import vinyldns.api.domain.zone.{DnsZoneViewLoader, VinylDNSZoneViewLoader, ZoneView} import vinyldns.api.domain.zone.{DnsZoneViewLoader, VinylDNSZoneViewLoader, ZoneView}
import vinyldns.core.domain.Fqdn import vinyldns.core.domain.Fqdn
import vinyldns.core.domain.backend.{Backend, BackendResolver}
import vinyldns.core.domain.record.NameSort.NameSort import vinyldns.core.domain.record.NameSort.NameSort
import vinyldns.core.domain.record.RecordType.RecordType import vinyldns.core.domain.record.RecordType.RecordType
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
@ -42,6 +43,8 @@ class ZoneSyncHandlerSpec
with BeforeAndAfterEach with BeforeAndAfterEach
with VinylDNSTestHelpers { with VinylDNSTestHelpers {
private val mockBackend = mock[Backend]
private val mockBackendResolver = mock[BackendResolver]
private val mockDNSLoader = mock[DnsZoneViewLoader] private val mockDNSLoader = mock[DnsZoneViewLoader]
private val mockVinylDNSLoader = mock[VinylDNSZoneViewLoader] private val mockVinylDNSLoader = mock[VinylDNSZoneViewLoader]
private val recordSetRepo = mock[RecordSetRepository] private val recordSetRepo = mock[RecordSetRepository]
@ -151,7 +154,7 @@ class ZoneSyncHandlerSpec
recordChangeRepo, recordChangeRepo,
zoneChangeRepo, zoneChangeRepo,
zoneRepo, zoneRepo,
_ => mockDNSLoader, mockBackendResolver,
(_, _) => mockVinylDNSLoader (_, _) => mockVinylDNSLoader
) )
@ -159,17 +162,20 @@ class ZoneSyncHandlerSpec
recordSetRepo, recordSetRepo,
recordChangeRepo, recordChangeRepo,
testZoneChange, testZoneChange,
_ => mockDNSLoader, mockBackendResolver,
(_, _) => mockVinylDNSLoader (_, _) => mockVinylDNSLoader
) )
override def beforeEach(): Unit = { override def beforeEach(): Unit = {
reset(recordSetRepo) reset(
reset(recordChangeRepo) recordSetRepo,
reset(zoneRepo) recordChangeRepo,
reset(zoneChangeRepo) zoneRepo,
reset(mockDNSLoader) zoneChangeRepo,
reset(mockVinylDNSLoader) mockDNSLoader,
mockVinylDNSLoader,
mockBackend
)
doReturn( doReturn(
IO(ListRecordSetResults(List(testRecord1), None, None, None, None, None, None, NameSort.ASC)) IO(ListRecordSetResults(List(testRecord1), None, None, None, None, None, None, NameSort.ASC))
@ -190,6 +196,9 @@ class ZoneSyncHandlerSpec
doReturn(() => IO(testDnsView)).when(mockDNSLoader).load doReturn(() => IO(testDnsView)).when(mockDNSLoader).load
doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load
doReturn(IO.pure(List(testRecord1, testRecord2)))
.when(mockBackend)
.loadZone(any[Zone], any[Int])
} }
"ZoneSyncHandler" should { "ZoneSyncHandler" should {
@ -197,6 +206,7 @@ class ZoneSyncHandlerSpec
doReturn(IO.pure(Right(testZoneChange))) doReturn(IO.pure(Right(testZoneChange)))
.when(zoneRepo) .when(zoneRepo)
.save(any[Zone]) .save(any[Zone])
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val result = zoneSync(testZoneChange).unsafeRunSync() val result = zoneSync(testZoneChange).unsafeRunSync()
@ -222,6 +232,7 @@ class ZoneSyncHandlerSpec
} }
"process successful zone sync with no changes" in { "process successful zone sync with no changes" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
doReturn(IO.pure(Right(testZoneChange))) doReturn(IO.pure(Right(testZoneChange)))
.when(zoneRepo) .when(zoneRepo)
.save(any[Zone]) .save(any[Zone])
@ -252,6 +263,7 @@ class ZoneSyncHandlerSpec
} }
"handle failed zone sync" in { "handle failed zone sync" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
doReturn(() => IO.raiseError(new RuntimeException("Dns Failed"))) doReturn(() => IO.raiseError(new RuntimeException("Dns Failed")))
.when(mockVinylDNSLoader) .when(mockVinylDNSLoader)
.load .load
@ -285,6 +297,7 @@ class ZoneSyncHandlerSpec
"saveZoneAndChange" should { "saveZoneAndChange" should {
"save zone and zoneChange with given statuses" in { "save zone and zoneChange with given statuses" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
doReturn(IO.pure(Right(testZoneChange))).when(zoneRepo).save(testZoneChange.zone) doReturn(IO.pure(Right(testZoneChange))).when(zoneRepo).save(testZoneChange.zone)
ZoneSyncHandler.saveZoneAndChange(zoneRepo, zoneChangeRepo, testZoneChange).unsafeRunSync() ZoneSyncHandler.saveZoneAndChange(zoneRepo, zoneChangeRepo, testZoneChange).unsafeRunSync()
@ -300,6 +313,7 @@ class ZoneSyncHandlerSpec
} }
"handle duplicateZoneError" in { "handle duplicateZoneError" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
doReturn(IO.pure(Left(DuplicateZoneError("error")))).when(zoneRepo).save(testZoneChange.zone) doReturn(IO.pure(Left(DuplicateZoneError("error")))).when(zoneRepo).save(testZoneChange.zone)
ZoneSyncHandler.saveZoneAndChange(zoneRepo, zoneChangeRepo, testZoneChange).unsafeRunSync() ZoneSyncHandler.saveZoneAndChange(zoneRepo, zoneChangeRepo, testZoneChange).unsafeRunSync()
@ -317,42 +331,41 @@ class ZoneSyncHandlerSpec
"runSync" should { "runSync" should {
"send the correct zone to the DNSZoneViewLoader" in { "send the correct zone to the DNSZoneViewLoader" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val captor = ArgumentCaptor.forClass(classOf[Zone]) val captor = ArgumentCaptor.forClass(classOf[Zone])
val dnsLoader = mock[Zone => DnsZoneViewLoader]
doReturn(mockDNSLoader).when(dnsLoader).apply(any[Zone])
ZoneSyncHandler ZoneSyncHandler
.runSync( .runSync(
recordSetRepo, recordSetRepo,
recordChangeRepo, recordChangeRepo,
testZoneChange, testZoneChange,
dnsLoader, mockBackendResolver,
(_, _) => mockVinylDNSLoader (_, _) => mockVinylDNSLoader
) )
.unsafeRunSync() .unsafeRunSync()
verify(dnsLoader).apply(captor.capture()) verify(mockBackend).loadZone(captor.capture(), any[Int])
val req = captor.getValue val req = captor.getValue
req shouldBe testZone.copy(status = ZoneStatus.Syncing) req shouldBe testZone.copy(status = ZoneStatus.Syncing)
} }
"load the dns zone from DNSZoneViewLoader" in { "load the dns zone from DNSZoneViewLoader" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
ZoneSyncHandler ZoneSyncHandler
.runSync( .runSync(
recordSetRepo, recordSetRepo,
recordChangeRepo, recordChangeRepo,
testZoneChange, testZoneChange,
_ => mockDNSLoader, mockBackendResolver,
(_, _) => mockVinylDNSLoader (_, _) => mockVinylDNSLoader
) )
.unsafeRunSync() .unsafeRunSync()
verify(mockDNSLoader, times(1)).load verify(mockBackend, times(1)).loadZone(any[Zone], any[Int])
} }
"Send the correct zone to the VinylDNSZoneViewLoader" in { "Send the correct zone to the VinylDNSZoneViewLoader" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val zoneCaptor = ArgumentCaptor.forClass(classOf[Zone]) val zoneCaptor = ArgumentCaptor.forClass(classOf[Zone])
val repoCaptor = ArgumentCaptor.forClass(classOf[RecordSetRepository]) val repoCaptor = ArgumentCaptor.forClass(classOf[RecordSetRepository])
@ -364,7 +377,7 @@ class ZoneSyncHandlerSpec
recordSetRepo, recordSetRepo,
recordChangeRepo, recordChangeRepo,
testZoneChange, testZoneChange,
_ => mockDNSLoader, mockBackendResolver,
vinyldnsLoader vinyldnsLoader
) )
.unsafeRunSync() .unsafeRunSync()
@ -375,26 +388,14 @@ class ZoneSyncHandlerSpec
} }
"load the dns zone from VinylDNSZoneViewLoader" in { "load the dns zone from VinylDNSZoneViewLoader" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
runSync.unsafeRunSync() runSync.unsafeRunSync()
verify(mockVinylDNSLoader, times(1)).load verify(mockVinylDNSLoader, times(1)).load
} }
"compute the diff correctly" in {
val captor = ArgumentCaptor.forClass(classOf[ZoneView])
val testVinylDNSView = mock[ZoneView]
doReturn(List(testRecordSetChange)).when(testVinylDNSView).diff(any[ZoneView])
doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load
runSync.unsafeRunSync()
verify(testVinylDNSView).diff(captor.capture())
val req = captor.getValue
req shouldBe testDnsView
}
"save the record changes to the recordChangeRepo" in { "save the record changes to the recordChangeRepo" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val captor = ArgumentCaptor.forClass(classOf[ChangeSet]) val captor = ArgumentCaptor.forClass(classOf[ChangeSet])
runSync.unsafeRunSync() runSync.unsafeRunSync()
@ -404,6 +405,7 @@ class ZoneSyncHandlerSpec
} }
"save the record sets to the recordSetRepo" in { "save the record sets to the recordSetRepo" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val captor = ArgumentCaptor.forClass(classOf[ChangeSet]) val captor = ArgumentCaptor.forClass(classOf[ChangeSet])
runSync.unsafeRunSync() runSync.unsafeRunSync()
@ -413,6 +415,7 @@ class ZoneSyncHandlerSpec
} }
"returns the zone as active and sets the latest sync" in { "returns the zone as active and sets the latest sync" in {
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val testVinylDNSView = ZoneView(testZone, List(testRecord1, testRecord2)) val testVinylDNSView = ZoneView(testZone, List(testRecord1, testRecord2))
doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load
val result = runSync.unsafeRunSync() val result = runSync.unsafeRunSync()
@ -436,6 +439,7 @@ class ZoneSyncHandlerSpec
doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load
doReturn(IO(correctChangeSet)).when(recordSetRepo).apply(captor.capture()) doReturn(IO(correctChangeSet)).when(recordSetRepo).apply(captor.capture())
doReturn(IO(correctChangeSet)).when(recordChangeRepo).save(any[ChangeSet]) doReturn(IO(correctChangeSet)).when(recordChangeRepo).save(any[ChangeSet])
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
runSync.unsafeRunSync() runSync.unsafeRunSync()
@ -457,6 +461,7 @@ class ZoneSyncHandlerSpec
doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load doReturn(() => IO(testVinylDNSView)).when(mockVinylDNSLoader).load
doReturn(IO(correctChangeSet)).when(recordSetRepo).apply(captor.capture()) doReturn(IO(correctChangeSet)).when(recordSetRepo).apply(captor.capture())
doReturn(IO(correctChangeSet)).when(recordChangeRepo).save(any[ChangeSet]) doReturn(IO(correctChangeSet)).when(recordChangeRepo).save(any[ChangeSet])
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val zoneChange = ZoneChange(testReverseZone, testReverseZone.account, ZoneChangeType.Sync) val zoneChange = ZoneChange(testReverseZone, testReverseZone.account, ZoneChangeType.Sync)
@ -465,7 +470,7 @@ class ZoneSyncHandlerSpec
recordSetRepo, recordSetRepo,
recordChangeRepo, recordChangeRepo,
zoneChange, zoneChange,
_ => mockDNSLoader, mockBackendResolver,
(_, _) => mockVinylDNSLoader (_, _) => mockVinylDNSLoader
) )
.unsafeRunSync() .unsafeRunSync()
@ -477,6 +482,7 @@ class ZoneSyncHandlerSpec
doReturn(() => IO.raiseError(new RuntimeException("Dns Failed"))) doReturn(() => IO.raiseError(new RuntimeException("Dns Failed")))
.when(mockVinylDNSLoader) .when(mockVinylDNSLoader)
.load .load
doReturn(mockBackend).when(mockBackendResolver).resolve(any[Zone])
val result = runSync.unsafeRunSync() val result = runSync.unsafeRunSync()
result.status shouldBe ZoneChangeStatus.Failed result.status shouldBe ZoneChangeStatus.Failed

View File

@ -33,3 +33,7 @@ class NoOpCrypto(val config: Config) extends CryptoAlgebra {
def encrypt(value: String): String = value def encrypt(value: String): String = value
def decrypt(value: String): String = value def decrypt(value: String): String = value
} }
object NoOpCrypto {
val instance = new NoOpCrypto()
}

View File

@ -19,6 +19,13 @@ package vinyldns.core.domain
import DomainHelpers.{ensureTrailingDot, removeWhitespace} import DomainHelpers.{ensureTrailingDot, removeWhitespace}
case class Fqdn(fqdn: String) { case class Fqdn(fqdn: String) {
// Everything up to the first dot / period
def firstLabel: String = fqdn.substring(0, fqdn.indexOf('.'))
// Everything up to the first dot, includes the dot to make it absolute
def firstLabelAbsolute: String = fqdn.substring(0, fqdn.indexOf('.') + 1)
override def equals(obj: Any): Boolean = override def equals(obj: Any): Boolean =
obj match { obj match {
case Fqdn(otherFqdn) => otherFqdn.toLowerCase == fqdn.toLowerCase case Fqdn(otherFqdn) => otherFqdn.toLowerCase == fqdn.toLowerCase
@ -31,4 +38,22 @@ case class Fqdn(fqdn: String) {
case object Fqdn { case object Fqdn {
def apply(fqdn: String): Fqdn = def apply(fqdn: String): Fqdn =
new Fqdn(ensureTrailingDot(removeWhitespace(fqdn))) new Fqdn(ensureTrailingDot(removeWhitespace(fqdn)))
// Combines record name and zone name to create a valid fqdn
def merge(recordName: String, zoneName: String): Fqdn = {
def dropTrailingDot(value: String): String =
if (value.endsWith(".")) value.dropRight(1) else value
val rname = dropTrailingDot(recordName)
val zname = dropTrailingDot(zoneName)
val zIndex = rname.lastIndexOf(zname)
if (zIndex > 0) {
// zone name already there, or record name = zone name, so just return
Fqdn(rname + ".")
} else {
// zone name not in record name so combine
Fqdn(s"$rname.$zname.")
}
}
} }

View File

@ -0,0 +1,77 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
import cats.effect.IO
import vinyldns.core.domain.record.{RecordSet, RecordSetChange}
import vinyldns.core.domain.record.RecordType.RecordType
import vinyldns.core.domain.zone.Zone
/**
* Provides the backend interface to work with any kind of DNS backend
*
* Implement this interface for your own backend. The default backend is the DnsBackend that talks DDNS.
*/
trait Backend {
/**
* Identifies this backend
* @return The id for the backend
*/
def id: String
/**
* Does a lookup for a record given the record name, zone name, and record type
*
* The record name + zone name should form the FQDN
*
* @param name The name of the record (without the zone - e.g. www)
* @param zoneName The full domain name (e.g. example.com)
* @param typ The type of record (e.g. AAAA)
* @return A list of record sets matching the name, empty if not found
*/
def resolve(name: String, zoneName: String, typ: RecordType): IO[List[RecordSet]]
/**
* Applies a single record set change against the DNS backend
*
* @param change A RecordSetChange to apply. Note: the key for a record set is the record name + type.
* A single RecordSetChange can add or remove multiple individual records in a record set at one time.
* @return A BackendResponse that is backend provider specific
*/
def applyChange(change: RecordSetChange): IO[BackendResponse]
/**
* Loads all record sets in a zone. Used typically for zone syncs.
*
* Note, this will cause memory issues for large zones (100,000s of records). Need to make
* zone sync memory safe before changing this
*
* @param zone The zone to load
* @param maxZoneSize The maximum number of records that we allow loading, typically configured
* @return All record sets in the zone
*/
def loadZone(zone: Zone, maxZoneSize: Int): IO[List[RecordSet]]
/**
* Indicates if the zone is present in the backend
*
* @param zone The zone to check if exists
* @return true if it exists; false otherwise
*/
def zoneExists(zone: Zone): IO[Boolean]
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
import cats.effect.{Blocker, ContextShift, IO}
import com.typesafe.config.Config
import pureconfig._
import pureconfig.generic.auto._
import pureconfig.module.catseffect.syntax._
/* The main VinylDNS backend configs, loaded by the BackendRegistry */
final case class BackendConfigs(
defaultBackendId: String,
backendProviders: List[BackendProviderConfig]
)
object BackendConfigs {
def load(config: Config)(implicit cs: ContextShift[IO]): IO[BackendConfigs] =
Blocker[IO].use(
ConfigSource.fromConfig(config).loadF[IO, BackendConfigs](_)
)
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
import cats.implicits._
import cats.effect.IO
import org.slf4j.LoggerFactory
object BackendLoader {
private val logger = LoggerFactory.getLogger("BackendLoader")
def load(configs: List[BackendProviderConfig]): IO[List[BackendProvider]] = {
def loadOne(config: BackendProviderConfig): IO[BackendProvider] =
for {
_ <- IO(logger.error(s"Attempting to load backend ${config.className}"))
provider <- IO(
Class
.forName(config.className)
.getDeclaredConstructor()
.newInstance()
.asInstanceOf[BackendProviderLoader]
)
backend <- provider.load(config)
} yield backend
configs.traverse(loadOne)
}
}

View File

@ -0,0 +1,46 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
import vinyldns.core.domain.zone.Zone
/**
* Implemented by each provider, provides a means of looking up a `BackendConnection`
* as well as showing which backend ids are registered on this provider
*/
trait BackendProvider {
/**
* Given a zone, returns a connection to the zone, returns None if cannot connect
*
* @param zone The zone to attempt to connect to
* @return A backend that is usable, or None if it could not connect
*/
def connect(zone: Zone): Option[Backend]
/**
* Given a backend id, looks up the backend for this provider if it exists
*
* @return A backend that is usable, or None if could not connect
*/
def connectById(backendId: String): Option[Backend]
/**
* @return The backend ids loaded with this provider
*/
def ids: List[String]
}

View File

@ -0,0 +1,29 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
import com.typesafe.config.Config
/**
* Config section loaded for a specific backend
* @param className The fully qualified class name of the provider to be loaded
* @param settings A generic typesafe Config object that holds settings to be interpreted by the provider
*/
final case class BackendProviderConfig(
className: String,
settings: Config
)

View File

@ -0,0 +1,39 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
import cats.effect.IO
/**
* To be implemented by other DNS Backend providers. This handles the loading of the backend config,
* typically comprised of multiple connections.
*
* All takes place inside IO, allowing implementers to do anything they need to ready the backend
* for integration with VinylDNS
*/
trait BackendProviderLoader {
/**
* Loads a backend based on the provided config so that it is ready to use
* This is internally used typically during startup
*
* @param config The BackendConfig, has settings that are specific to this backend
*
* @return A ready-to-use Backend instance, or does an IO.raiseError if something bad occurred.
*/
def load(config: BackendProviderConfig): IO[BackendProvider]
}

View File

@ -0,0 +1,109 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
import cats.data.NonEmptyList
import cats.effect.IO
import cats.implicits._
import vinyldns.core.domain.zone.Zone
import vinyldns.core.health.HealthCheck
import vinyldns.core.health.HealthCheck.HealthCheck
/**
* Provides the means to discover backends for zones
*/
trait BackendResolver {
/**
* Attempts to get the backend for a given zone, falls back to
* using the default-backend-id if no zones can be found
*
* @param zone A `Zone` to get a backend for
* @return A working `Backend`, the default if necessary
*/
def resolve(zone: Zone): Backend
/**
* Performs whatever health check considered necessary to ensure that the backends are in good health
*
* @param timeout Timeout in seconds to wait before raising an error
*
* @return A HealthCheck that can be run to determine the health of the registered backends
*/
def healthCheck(timeout: Int): HealthCheck
/**
* Determines if a given backend id is registered
*
* @param backendId The id to lookup
*
* @return true if it is registered; false otherwise
*/
def isRegistered(backendId: String): Boolean
/**
* @return All of the backend ids registered
*/
def ids: NonEmptyList[String]
}
object BackendResolver {
def apply(configs: BackendConfigs): IO[BackendResolver] =
for {
backends <- BackendLoader.load(configs.backendProviders)
defaultConn <- IO.fromOption(
backends.collectFirstSome(_.connectById(configs.defaultBackendId))
)(
new RuntimeException(
s"Unable to find default backend for configured id '${configs.defaultBackendId}''"
)
)
} yield new BackendResolver {
/**
* Attempts to get the backend for a given zone, returns `None` if not found
*
* @param zone A `Zone` to get a backend for
* @return A working `Backend`, or `None` if the backend could not be found for this zone
*/
def resolve(zone: Zone): Backend =
backends.collectFirstSome(_.connect(zone)).getOrElse(defaultConn)
/**
* Performs whatever health check considered necessary to ensure that the backends are in good health
*
* @param timeout Timeout in seconds to wait before raising an error
* @return A HealthCheck that can be run to determine the health of the registered backends
*/
def healthCheck(timeout: Int): HealthCheck =
IO.pure(().asRight[HealthCheck.HealthCheckError])
/**
* Determines if a given backend id is registered
*
* @param backendId The id to lookup
* @return true if it is registered; false otherwise
*/
def isRegistered(backendId: String): Boolean =
backends.collectFirstSome(_.connectById(backendId)).isDefined
/**
* @return All of the backend ids registered
*/
val ids: NonEmptyList[String] =
NonEmptyList(defaultConn.id, backends.toList.flatMap(_.ids)).distinct
}
}

View File

@ -0,0 +1,27 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.domain.backend
/* Response from applying a change to a backend */
sealed trait BackendResponse
object BackendResponse {
/* Indicates the backend request worked properly */
final case class NoError(message: String) extends BackendResponse
/* Indicates there was a failure that maybe recoverable with a try again */
final case class Retry(message: String) extends BackendResponse
}

View File

@ -19,38 +19,106 @@ package vinyldns.core.domain.record
import scodec.bits.ByteVector import scodec.bits.ByteVector
import vinyldns.core.domain.Fqdn import vinyldns.core.domain.Fqdn
sealed trait RecordData import scala.util.Try
import RecordData._
import vinyldns.core.domain.record.RecordType._
final case class AData(address: String) extends RecordData sealed trait RecordData {
def toString: String
}
object RecordData {
def toInt(value: String): Option[Int] =
Try(value.toInt).toOption
final case class AAAAData(address: String) extends RecordData def toLong(value: String): Option[Long] =
Try(value.toLong).toOption
final case class CNAMEData(cname: Fqdn) extends RecordData def fromString(value: String, typ: RecordType): Option[RecordData] = typ match {
case A => AData.fromString(value)
case AAAA => AAAAData.fromString(value)
case CNAME => CNAMEData.fromString(value)
case DS => DSData.fromString(value)
case MX => MXData.fromString(value)
case NAPTR => NAPTRData.fromString(value)
case NS => NSData.fromString(value)
case PTR => PTRData.fromString(value)
case SPF => SPFData.fromString(value)
case SRV => SRVData.fromString(value)
case SSHFP => SSHFPData.fromString(value)
case TXT => TXTData.fromString(value)
case UNKNOWN => None
}
}
final case class AData(address: String) extends RecordData {
override def toString: String = address
}
object AData {
def fromString(value: String): Option[AData] = Option(value).map(AData(_))
}
final case class AAAAData(address: String) extends RecordData {
override def toString: String = address
}
object AAAAData {
def fromString(value: String): Option[AAAAData] = Option(value).map(AAAAData(_))
}
final case class CNAMEData(cname: Fqdn) extends RecordData {
override def toString: String = cname.fqdn
}
object CNAMEData { object CNAMEData {
def apply(cname: Fqdn): CNAMEData = def apply(cname: Fqdn): CNAMEData =
new CNAMEData(cname) new CNAMEData(cname)
def fromString(value: String): Option[CNAMEData] =
Option(value).map(Fqdn.apply).map(CNAMEData.apply)
} }
final case class MXData(preference: Integer, exchange: Fqdn) extends RecordData final case class MXData(preference: Integer, exchange: Fqdn) extends RecordData {
override def toString: String = s"$preference ${exchange.fqdn}"
}
object MXData { object MXData {
def apply(preference: Integer, exchange: Fqdn): MXData = def apply(preference: Integer, exchange: Fqdn): MXData =
new MXData(preference, exchange) new MXData(preference, exchange)
/* Assumes format preference fqdn, e.g. 10 www.example.com; otherwise returns None */
def fromString(value: String): Option[MXData] =
Option(value).flatMap { v =>
val parts = v.split(' ')
if (parts.length != 2) {
None
} else {
toInt(parts(0)).map { pref =>
new MXData(pref, Fqdn(parts(1)))
}
}
}
} }
final case class NSData(nsdname: Fqdn) extends RecordData final case class NSData(nsdname: Fqdn) extends RecordData {
override def toString: String = nsdname.fqdn
}
object NSData { object NSData {
def apply(nsdname: Fqdn): NSData = def apply(nsdname: Fqdn): NSData =
new NSData(nsdname) new NSData(nsdname)
def fromString(value: String): Option[NSData] =
Option(value).map(Fqdn.apply).map(NSData.apply)
} }
final case class PTRData(ptrdname: Fqdn) extends RecordData final case class PTRData(ptrdname: Fqdn) extends RecordData {
override def toString: String = ptrdname.fqdn
}
object PTRData { object PTRData {
def apply(ptrdname: Fqdn): PTRData = def apply(ptrdname: Fqdn): PTRData =
new PTRData(ptrdname) new PTRData(ptrdname)
def fromString(value: String): Option[PTRData] =
Option(value).map(Fqdn.apply).map(PTRData.apply)
} }
final case class SOAData( final case class SOAData(
@ -61,16 +129,67 @@ final case class SOAData(
retry: Long, retry: Long,
expire: Long, expire: Long,
minimum: Long minimum: Long
) extends RecordData ) extends RecordData {
override def toString: String = s"${mname.fqdn} ${rname} $serial $refresh $retry $expire $minimum"
}
object SOAData {
def fromString(value: String): Option[SOAData] =
Option(value).flatMap { v =>
val parts = v.split(' ')
if (parts.length != 7) {
None
} else {
for {
serial <- toLong(parts(2))
refresh <- toLong(parts(3))
retry <- toLong(parts(4))
expire <- toLong(parts(5))
minimum <- toLong(parts(6))
} yield SOAData(
Fqdn(parts(0)),
parts(1),
serial,
refresh,
retry,
expire,
minimum
)
}
}
}
final case class SPFData(text: String) extends RecordData final case class SPFData(text: String) extends RecordData {
override def toString: String = text
}
object SPFData {
def fromString(value: String): Option[SPFData] = Option(value).map(SPFData(_))
}
final case class SRVData(priority: Integer, weight: Integer, port: Integer, target: Fqdn) final case class SRVData(priority: Integer, weight: Integer, port: Integer, target: Fqdn)
extends RecordData extends RecordData {
override def toString: String = s"$priority $weight $port ${target.fqdn}"
}
object SRVData { object SRVData {
def apply(priority: Integer, weight: Integer, port: Integer, target: Fqdn): SRVData = def fromString(value: String): Option[SRVData] =
new SRVData(priority, weight, port, target) Option(value).flatMap { v =>
val parts = v.split(' ')
if (parts.length != 7) {
None
} else {
for {
priority <- toInt(parts(0))
weight <- toInt(parts(1))
port <- toInt(parts(2))
target = Fqdn(parts(3))
} yield SRVData(
priority,
weight,
port,
target
)
}
}
} }
final case class NAPTRData( final case class NAPTRData(
@ -80,7 +199,9 @@ final case class NAPTRData(
service: String, service: String,
regexp: String, regexp: String,
replacement: Fqdn replacement: Fqdn
) extends RecordData ) extends RecordData {
override def toString: String = s"$order $preference $flags $service $regexp ${replacement.fqdn}"
}
object NAPTRData { object NAPTRData {
def apply( def apply(
@ -92,11 +213,51 @@ object NAPTRData {
replacement: Fqdn replacement: Fqdn
): NAPTRData = ): NAPTRData =
new NAPTRData(order, preference, flags, service, regexp, replacement) new NAPTRData(order, preference, flags, service, regexp, replacement)
def fromString(value: String): Option[NAPTRData] =
Option(value).flatMap { v =>
val parts = v.split(' ')
if (parts.length != 6) {
None
} else {
for {
order <- toInt(parts(0))
pref <- toInt(parts(1))
flags = parts(2)
service = parts(3)
reg = parts(4)
rep = Fqdn(parts(5))
} yield NAPTRData(order, pref, flags, service, reg, rep)
}
}
} }
final case class SSHFPData(algorithm: Integer, typ: Integer, fingerprint: String) extends RecordData final case class SSHFPData(algorithm: Integer, typ: Integer, fingerprint: String)
extends RecordData {
override def toString: String = s"$algorithm $typ $fingerprint"
}
object SSHFPData {
def fromString(value: String): Option[SSHFPData] =
Option(value).flatMap { v =>
val parts = v.split(' ')
if (parts.length != 3) {
None
} else {
for {
alg <- toInt(parts(0))
typ <- toInt(parts(1))
fp = parts(2)
} yield SSHFPData(alg, typ, fp)
}
}
}
final case class TXTData(text: String) extends RecordData final case class TXTData(text: String) extends RecordData {
override def toString: String = text
}
object TXTData {
def fromString(value: String): Option[TXTData] = Option(value).map(TXTData(_))
}
sealed abstract class DigestType(val value: Int) sealed abstract class DigestType(val value: Int)
object DigestType { object DigestType {
@ -159,4 +320,22 @@ final case class DSData(
algorithm: DnsSecAlgorithm, algorithm: DnsSecAlgorithm,
digestType: DigestType, //digestid in DNSJava digestType: DigestType, //digestid in DNSJava
digest: ByteVector digest: ByteVector
) extends RecordData ) extends RecordData {
override def toString: String = s"$keyTag $algorithm $digestType $digest"
}
object DSData {
def fromString(value: String): Option[DSData] =
Option(value).flatMap { v =>
val parts = v.split(' ')
if (parts.length != 3) {
None
} else {
for {
kt <- toInt(parts(0))
alg <- toInt(parts(1)).map(DnsSecAlgorithm.apply)
dt <- toInt(parts(2)).map(DigestType.apply)
dig <- Some(ByteVector(parts(3).getBytes))
} yield DSData(kt, alg, dt, dig)
}
}
}

View File

@ -148,13 +148,13 @@ case class ZoneConnection(name: String, keyName: String, key: String, primarySer
copy(key = crypto.decrypt(key)) copy(key = crypto.decrypt(key))
} }
final case class DnsBackend( final case class LegacyDnsBackend(
id: String, id: String,
zoneConnection: ZoneConnection, zoneConnection: ZoneConnection,
transferConnection: ZoneConnection transferConnection: ZoneConnection
) { ) {
def encrypted(crypto: CryptoAlgebra): DnsBackend = copy( def encrypted(crypto: CryptoAlgebra): LegacyDnsBackend = copy(
zoneConnection = zoneConnection.encrypted(crypto), zoneConnection = zoneConnection.encrypted(crypto),
transferConnection = transferConnection.encrypted(crypto) transferConnection = transferConnection.encrypted(crypto)
) )
@ -163,5 +163,5 @@ final case class DnsBackend(
final case class ConfiguredDnsConnections( final case class ConfiguredDnsConnections(
defaultZoneConnection: ZoneConnection, defaultZoneConnection: ZoneConnection,
defaultTransferConnection: ZoneConnection, defaultTransferConnection: ZoneConnection,
dnsBackends: List[DnsBackend] dnsBackends: List[LegacyDnsBackend]
) )

View File

@ -0,0 +1,22 @@
vinyldns {
backend {
default-backend-id = "r53"
backend-providers = [
{
class-name = "vinyldns.route53.backend.Route53BackendProviderLoader"
settings = {
backends = [
{
id = "test"
access-key = "vinyldnsTest"
secret-key = "notNeededForSnsLocal"
service-endpoint = "http://127.0.0.1:19009"
signing-region = "us-east-1"
}
]
}
}
]
}
}

View File

@ -0,0 +1,131 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.route53.backend
import com.amazonaws.services.route53.model.DeleteHostedZoneRequest
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import vinyldns.core.domain.zone.Zone
import scala.collection.JavaConverters._
import org.scalatest.OptionValues._
import org.scalatest.EitherValues._
import vinyldns.core.domain.backend.BackendResponse
import vinyldns.core.domain.{Fqdn, record}
import vinyldns.core.domain.record.{RecordSet, RecordType}
class Route53IntegrationSpec
extends AnyWordSpec
with BeforeAndAfterAll
with BeforeAndAfterEach
with Matchers {
import vinyldns.core.TestRecordSetData._
import vinyldns.core.TestZoneData._
private val testZone = Zone("example.com.", "test@test.com", backendId = Some("test"))
override def beforeAll(): Unit = {
deleteZone()
createZone()
}
private def testConnection: Route53Backend =
Route53Backend
.load(
Route53BackendConfig("test", "access", "secret", "http://127.0.0.1:19009", "us-east-1")
)
.unsafeRunSync()
private def deleteZone(): Unit = {
val zoneIds = testConnection.client.listHostedZones().getHostedZones.asScala.map(_.getId).toList
zoneIds.foreach { id =>
testConnection.client.deleteHostedZone(new DeleteHostedZoneRequest().withId(id))
}
}
private def createZone(): Unit =
testConnection.createZone(testZone).unsafeRunSync()
private def checkRecordExists(rs: RecordSet, zone: Zone): Unit = {
val resolveResult =
testConnection.resolve(rs.name, zone.name, rs.typ).unsafeRunSync().headOption.value
resolveResult.records should contain theSameElementsAs rs.records
resolveResult.name shouldBe rs.name
resolveResult.ttl shouldBe rs.ttl
resolveResult.typ shouldBe rs.typ
}
private def checkRecordNotExists(rs: RecordSet, zone: Zone): Unit =
testConnection.resolve(rs.name, zone.name, rs.typ).unsafeRunSync() shouldBe empty
private def testRecordSet(rs: RecordSet, zone: Zone): Unit = {
val conn = testConnection
val testRecord = rs.copy(zoneId = zone.id)
val change = makeTestAddChange(testRecord, zone, "test-user")
val result = conn.applyChange(change).unsafeRunSync()
result shouldBe a[BackendResponse.NoError]
// We should be able to resolve now
checkRecordExists(testRecord, zone)
val del = makePendingTestDeleteChange(testRecord, zone, "test-user")
conn.applyChange(del).unsafeRunSync()
// Record should not be found
checkRecordNotExists(testRecord, zone)
}
"Route53 Connections" should {
"return nothing if the zone does not exist" in {
testConnection.resolve("foo", "bar", RecordType.A).unsafeRunSync() shouldBe empty
}
"work for a" in {
testRecordSet(rsOk, testZone)
}
"work for aaaa" in {
testRecordSet(aaaa, testZone)
}
"work for cname" in {
testRecordSet(cname, testZone)
}
"work for naptr" in {
testRecordSet(naptr, testZone)
}
"work for mx" in {
val testMxData = record.MXData(10, Fqdn("mx.example.com."))
val testMx = mx.copy(records = List(testMxData))
testRecordSet(testMx, testZone)
}
"work for txt" in {
testRecordSet(txt, testZone)
}
"check if zone exists" in {
val notFound = Zone("blah.foo.", "test@test.com", backendId = Some("test"))
testConnection.zoneExists(notFound).unsafeRunSync() shouldBe false
testConnection.zoneExists(testZone).unsafeRunSync() shouldBe true
}
"fail when applying the change and it does not exist" in {
val testRecord = aaaa
val testZone = okZone
val change = makeTestAddChange(testRecord, testZone, "test-user")
val result = testConnection.applyChange(change).attempt.unsafeRunSync()
result.left.value shouldBe a[Route53BackendResponse.ZoneNotFoundError]
}
}
}

View File

@ -0,0 +1,284 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.route53.backend
import cats.data.OptionT
import cats.effect.IO
import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicAWSCredentials}
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.handlers.AsyncHandler
import com.amazonaws.services.route53.{AmazonRoute53Async, AmazonRoute53AsyncClientBuilder}
import com.amazonaws.services.route53.model._
import com.amazonaws.{AmazonWebServiceRequest, AmazonWebServiceResult}
import org.slf4j.LoggerFactory
import vinyldns.core.domain.Fqdn
import vinyldns.core.domain.backend.{Backend, BackendResponse}
import vinyldns.core.domain.record.RecordSetChangeType.RecordSetChangeType
import vinyldns.core.domain.record.RecordType.RecordType
import vinyldns.core.domain.record.{RecordSet, RecordSetChange, RecordSetChangeType}
import vinyldns.core.domain.zone.{Zone, ZoneStatus}
import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
/**
* Backend for a single AWS account
*
* @param id VinylDNS backend identifier used to connect to route 53
* @param hostedZones A list of hosted zones, loaded when the application is started. Necessary
* as most interactions with Route53 go through the zone id, not the zone name.
* This will be used as a cache, and on cache miss will lookup the zone in real time
* @param client A route 53 client with credentials that can talk to this route 53 aws account
*/
class Route53Backend(
val id: String,
hostedZones: List[HostedZone],
val client: AmazonRoute53Async
) extends Backend
with Route53Conversions {
import Route53Backend.r53
private val logger = LoggerFactory.getLogger(classOf[Route53Backend])
/* Concurrent friendly map */
private val zoneMap: TrieMap[String, String] = TrieMap(
hostedZones.map(z => z.getName -> z.getId): _*
)
/* Lookup in the local cache, if a new zone is added since start, we have to retrieve it in real time */
private def lookupHostedZone(zoneName: String): OptionT[IO, String] = {
// pain but we must parse to use the hosted zone ids from the cache
def parseHostedZoneId(hzid: String): String = {
val lastSlash = hzid.lastIndexOf('/')
if (lastSlash > 0) {
hzid.substring(lastSlash + 1)
} else {
hzid
}
}
OptionT.fromOption[IO](zoneMap.get(zoneName)).orElseF {
r53(
new ListHostedZonesByNameRequest().withDNSName(zoneName),
client.listHostedZonesByNameAsync
).map { result =>
// We must parse the hosted zone id which is annoying
val found = result.getHostedZones.asScala.toList.headOption.map { hz =>
val hzid = parseHostedZoneId(hz.getId)
// adds the hozted zone name and id to our cache if not present
zoneMap.putIfAbsent(hz.getName, hzid)
hzid
}
if (found.isEmpty) {
logger.warn(s"Unable to find hosted zone for '$zoneName'")
}
found
}
}
}
/**
* Does a lookup for a record given the record name, zone name, and record type
*
* The record name + zone name should form the FQDN
*
* @param name The name of the record (without the zone - e.g. www)
* @param zoneName The full domain name (e.g. example.com)
* @param typ The type of record (e.g. AAAA)
* @return A list of record sets matching the name, empty if not found
*/
def resolve(name: String, zoneName: String, typ: RecordType): IO[List[RecordSet]] = {
for {
hostedZoneId <- lookupHostedZone(zoneName)
awsRRType <- OptionT.fromOption[IO](toRoute53RecordType(typ))
fqdn = Fqdn.merge(name, zoneName).fqdn
result <- OptionT.liftF {
r53(
new ListResourceRecordSetsRequest()
.withHostedZoneId(hostedZoneId)
.withStartRecordName(fqdn)
.withStartRecordType(awsRRType),
client.listResourceRecordSetsAsync
)
}
} yield toVinylRecordSets(result.getResourceRecordSets, zoneName: String)
}.getOrElse(Nil)
/**
* Applies a single record set change against the DNS backend
*
* @param change A RecordSetChange to apply. Note: the key for a record set is the record name + type.
* A single RecordSetChange can add or remove multiple individual records in a record set at one time.
* @return A BackendResponse that is backend provider specific
*/
def applyChange(change: RecordSetChange): IO[BackendResponse] = {
def changeAction(typ: RecordSetChangeType): ChangeAction = typ match {
case RecordSetChangeType.Create => ChangeAction.CREATE
case RecordSetChangeType.Update => ChangeAction.UPSERT
case RecordSetChangeType.Delete => ChangeAction.DELETE
}
def changeRequest(
typ: RecordSetChangeType,
rs: ResourceRecordSet
): ChangeResourceRecordSetsRequest = {
logger.debug(s"applying change to zone, record set is $rs")
new ChangeResourceRecordSetsRequest().withChangeBatch(
new ChangeBatch().withChanges(
new Change().withAction(changeAction(typ)).withResourceRecordSet(rs)
)
)
}
// We want to FAIL if unrecoverable errors occur so that the change ultimately is marked as failed
for {
hostedZoneId <- lookupHostedZone(change.zone.name).value.flatMap {
case Some(x) => IO(x)
case None =>
IO.raiseError(
Route53BackendResponse.ZoneNotFoundError(
s"Unable to find hosted zone for zone name ${change.zone.name}"
)
)
}
r53RecordSet <- IO.fromOption(toR53RecordSet(change.zone, change.recordSet))(
Route53BackendResponse.ConversionError(
s"Unable to convert record set to route 53 format for ${change.recordSet}"
)
)
result <- r53(
changeRequest(change.changeType, r53RecordSet).withHostedZoneId(hostedZoneId),
client.changeResourceRecordSetsAsync
).map { response =>
logger.debug(s"applied record change $change, change result is ${response.getChangeInfo}")
BackendResponse.NoError(response.toString)
}
} yield result
}
/**
* Loads all record sets in a zone. Used typically for zone syncs.
*
* Note, this will cause memory issues for large zones (100,000s of records). Need to make
* zone sync memory safe before changing this
*
* @param zone The zone to load
* @param maxZoneSize The maximum number of records that we allow loading, typically configured
* @return All record sets in the zone
*/
def loadZone(zone: Zone, maxZoneSize: Int): IO[List[RecordSet]] = {
// Loads a single page, up to 100 record sets
def loadPage(request: ListResourceRecordSetsRequest): IO[ListResourceRecordSetsResult] =
r53(
request,
client.listResourceRecordSetsAsync
)
// recursively pages through, exits once we hit the last page
def recurseLoadNextPage(
request: ListResourceRecordSetsRequest,
result: ListResourceRecordSetsResult,
acc: List[RecordSet]
): IO[List[RecordSet]] = {
val updatedAcc = acc ++ toVinylRecordSets(result.getResourceRecordSets, zone.name)
// Here is our base case right here, getIsTruncated returns true if there are more records
if (result.getIsTruncated) {
loadPage(
request
.withStartRecordName(result.getNextRecordName)
.withStartRecordType(result.getNextRecordType)
).flatMap(nextResult => recurseLoadNextPage(request, nextResult, updatedAcc))
} else {
IO(updatedAcc)
}
}
for {
hz <- lookupHostedZone(zone.name)
recordSets <- OptionT.liftF {
val req = new ListResourceRecordSetsRequest().withHostedZoneId(hz)
// recurse to load all pages
loadPage(req).flatMap(recurseLoadNextPage(req, _, Nil))
}
} yield recordSets
}.getOrElse(Nil)
/**
* Indicates if the zone is present in the backend
*
* @param zone The zone to check if exists
* @return true if it exists; false otherwise
*/
def zoneExists(zone: Zone): IO[Boolean] = lookupHostedZone(zone.name).isDefined
/* Note: naive implementation to assist in testing, not meant for production yet */
def createZone(zone: Zone): IO[Zone] =
for {
result <- r53(
new CreateHostedZoneRequest().withCallerReference(zone.id).withName(zone.name),
client.createHostedZoneAsync
)
_ <- IO(logger.info(s"create zone result is $result"))
} yield zone.copy(status = ZoneStatus.Active)
}
object Route53Backend {
/* Convenience method for working async with AWS */
def r53[A <: AmazonWebServiceRequest, B <: AmazonWebServiceResult[_]](
request: A,
f: (A, AsyncHandler[A, B]) => java.util.concurrent.Future[B]
): IO[B] =
IO.async[B] { complete: (Either[Throwable, B] => Unit) =>
val asyncHandler = new AsyncHandler[A, B] {
def onError(exception: Exception): Unit = complete(Left(exception))
def onSuccess(request: A, result: B): Unit = complete(Right(result))
}
f(request, asyncHandler)
}
// Loads a Route53 backend
def load(config: Route53BackendConfig): IO[Route53Backend] = {
val clientIO = IO {
AmazonRoute53AsyncClientBuilder.standard
.withEndpointConfiguration(
new EndpointConfiguration(config.serviceEndpoint, config.signingRegion)
)
.withCredentials(
new AWSStaticCredentialsProvider(
new BasicAWSCredentials(config.accessKey, config.secretKey)
)
)
.build()
}
// Connect to the client AND load the zones
for {
client <- clientIO
result <- r53(
new ListHostedZonesRequest(),
client.listHostedZonesAsync
)
} yield new Route53Backend(config.id, result.getHostedZones.asScala.toList, client)
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.route53.backend
import vinyldns.core.domain.backend.{Backend, BackendProvider}
import vinyldns.core.domain.zone.Zone
class Route53BackendProvider(connections: List[Route53Backend]) extends BackendProvider {
private val connMap: Map[String, Route53Backend] = connections.map(c => c.id -> c).toMap
/**
* Given a zone, returns a connection to the zone, returns None if cannot connect
*
* @param zone The zone to attempt to connect to
* @return A backend that is usable, or None if it could not connect
*/
def connect(zone: Zone): Option[Backend] =
// only way to connect is via backend id right now
zone.backendId.flatMap(connectById)
/**
* Given a backend id, looks up the backend for this provider if it exists
*
* @return A backend that is usable, or None if could not connect
*/
def connectById(backendId: String): Option[Backend] =
connMap.get(backendId)
/**
* @return The backend ids loaded with this provider
*/
def ids: List[String] = connMap.keys.toList
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.route53.backend
import cats.implicits._
import cats.effect.{ContextShift, IO}
import vinyldns.core.domain.backend.{BackendProviderConfig, BackendProvider, BackendProviderLoader}
class Route53BackendProviderLoader extends BackendProviderLoader {
private implicit val cs: ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.global)
/**
* Loads a backend based on the provided config so that it is ready to use
* This is internally used typically during startup
*
* @param config The BackendConfig, has settings that are specific to this backend
* @return A ready-to-use Backend instance, or does an IO.raiseError if something bad occurred.
*/
def load(config: BackendProviderConfig): IO[BackendProvider] =
Route53ProviderConfig.load(config.settings).flatMap { bec =>
bec.backends.traverse(Route53Backend.load).map { conns =>
new Route53BackendProvider(conns)
}
}
}

View File

@ -0,0 +1,23 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.route53.backend
sealed trait Route53BackendResponse
object Route53BackendResponse {
final case class ZoneNotFoundError(message: String) extends Throwable(message)
final case class ConversionError(message: String) extends Throwable(message)
}

View File

@ -0,0 +1,93 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.route53.backend
import com.amazonaws.services.route53.model.{RRType, ResourceRecord, ResourceRecordSet}
import org.joda.time.DateTime
import vinyldns.core.domain.Fqdn
import vinyldns.core.domain.record.{RecordData, RecordSet, RecordSetStatus}
import vinyldns.core.domain.record.RecordType.RecordType
import vinyldns.core.domain.record.RecordType._
import vinyldns.core.domain.zone.Zone
import scala.collection.JavaConverters._
trait Route53Conversions {
def toRoute53RecordType(typ: RecordType): Option[RRType] = typ match {
case A => Some(RRType.A)
case AAAA => Some(RRType.AAAA)
case CNAME => Some(RRType.CNAME)
case MX => Some(RRType.MX)
case NAPTR => Some(RRType.NAPTR)
case NS => Some(RRType.NS)
case PTR => Some(RRType.PTR)
case SPF => Some(RRType.SPF)
case SRV => Some(RRType.SRV)
case TXT => Some(RRType.TXT)
case _ => None
}
def toVinylRecordType(typ: RRType): RecordType = typ match {
case RRType.A => A
case RRType.AAAA => AAAA
case RRType.CNAME => CNAME
case RRType.MX => MX
case RRType.NAPTR => NAPTR
case RRType.NS => NS
case RRType.PTR => PTR
case RRType.SPF => SPF
case RRType.SRV => SRV
case RRType.TXT => TXT
case _ => UNKNOWN
}
def toVinyl(typ: RecordType, resourceRecord: ResourceRecord): Option[RecordData] =
RecordData.fromString(resourceRecord.getValue, typ)
def toVinylRecordSet(zoneName: String, r53RecordSet: ResourceRecordSet): RecordSet = {
val typ = toVinylRecordType(RRType.fromValue(r53RecordSet.getType))
RecordSet(
"unknown",
Fqdn.merge(r53RecordSet.getName, zoneName).firstLabel,
typ,
r53RecordSet.getTTL,
RecordSetStatus.Active,
DateTime.now,
Some(DateTime.now),
r53RecordSet.getResourceRecords.asScala.toList.flatMap(toVinyl(typ, _)),
fqdn = Some(r53RecordSet.getName)
)
}
def toVinylRecordSets(
r53RecordSets: java.util.List[ResourceRecordSet],
zoneName: String
): List[RecordSet] =
r53RecordSets.asScala.toList.map(toVinylRecordSet(zoneName, _))
def toR53RecordSet(zone: Zone, vinylRecordSet: RecordSet): Option[ResourceRecordSet] =
toRoute53RecordType(vinylRecordSet.typ).map { typ =>
new ResourceRecordSet()
.withName(Fqdn.merge(vinylRecordSet.name, zone.name).fqdn)
.withTTL(vinylRecordSet.ttl)
.withType(typ)
.withResourceRecords(
vinylRecordSet.records.map(rd => new ResourceRecord().withValue(rd.toString)).asJava
)
}
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.route53.backend
import cats.effect.{Blocker, ContextShift, IO}
import com.typesafe.config.Config
import pureconfig.ConfigSource
import pureconfig.generic.auto._
import pureconfig.module.catseffect.syntax.CatsEffectConfigSource
// TODO: Add delegation set id and VPC options especially wrt CreateZone
final case class Route53BackendConfig(
id: String,
accessKey: String,
secretKey: String,
serviceEndpoint: String,
signingRegion: String
)
final case class Route53ProviderConfig(backends: List[Route53BackendConfig])
object Route53ProviderConfig {
def load(config: Config)(implicit cs: ContextShift[IO]): IO[Route53ProviderConfig] =
Blocker[IO].use(
ConfigSource.fromConfig(config).loadF[IO, Route53ProviderConfig](_)
)
}

View File

@ -94,6 +94,11 @@ object Dependencies {
"com.amazonaws" % "aws-java-sdk-sqs" % awsV withSources() "com.amazonaws" % "aws-java-sdk-sqs" % awsV withSources()
) )
lazy val r53Dependencies = Seq(
"com.amazonaws" % "aws-java-sdk-core" % awsV withSources(),
"com.amazonaws" % "aws-java-sdk-route53" % awsV withSources()
)
lazy val commonTestDependencies = Seq( lazy val commonTestDependencies = Seq(
"org.scalatest" %% "scalatest" % scalaTestV, "org.scalatest" %% "scalatest" % scalaTestV,
"org.scalacheck" %% "scalacheck" % "1.14.3", "org.scalacheck" %% "scalacheck" % "1.14.3",