diff --git a/modules/api/functional_test/live_tests/recordsets/list_recordsets_test.py b/modules/api/functional_test/live_tests/recordsets/list_recordsets_test.py index d4e5f2d06..2cdca68cd 100644 --- a/modules/api/functional_test/live_tests/recordsets/list_recordsets_test.py +++ b/modules/api/functional_test/live_tests/recordsets/list_recordsets_test.py @@ -207,6 +207,40 @@ def test_list_recordsets_default_size_is_100(rs_fixture): rs_fixture.check_recordsets_page_accuracy(list_results, size=17, offset=0, maxItems=100) +def test_list_recordsets_duplicate_names(rs_fixture): + """ + Test that paging keys work for records with duplicate names + """ + client = rs_fixture.client + ok_zone = rs_fixture.test_context + + created = [] + + try: + record_data_a = [{'address': '1.1.1.1'}] + record_data_txt = [{'text': 'some=value'}] + + record_json_a = get_recordset_json(ok_zone, '0', 'A', record_data_a, ttl=100) + record_json_txt = get_recordset_json(ok_zone, '0', 'TXT', record_data_txt, ttl=100) + + create_response = client.create_recordset(record_json_a, status=202) + created.append(client.wait_until_recordset_change_status(create_response, 'Complete')['recordSet']['id']) + + create_response = client.create_recordset(record_json_txt, status=202) + created.append(client.wait_until_recordset_change_status(create_response, 'Complete')['recordSet']['id']) + + list_results = client.list_recordsets(ok_zone['id'], status=200, start_from=None, max_items=1) + assert_that(list_results['recordSets'][0]['id'], is_(created[0])) + + list_results = client.list_recordsets(ok_zone['id'], status=200, start_from=list_results['nextId'], max_items=1) + assert_that(list_results['recordSets'][0]['id'], is_(created[1])) + + finally: + for id in created: + client.delete_recordset(ok_zone['id'], id, status=202) + client.wait_until_recordset_deleted(ok_zone['id'], id) + + def test_list_recordsets_with_record_name_filter_all(rs_fixture): """ Test listing all recordsets whose name contains a substring, all recordsets have substring 'list' in name diff --git a/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlRecordSetRepositoryIntegrationSpec.scala b/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlRecordSetRepositoryIntegrationSpec.scala index a280b6c9a..8d49383d8 100644 --- a/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlRecordSetRepositoryIntegrationSpec.scala +++ b/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlRecordSetRepositoryIntegrationSpec.scala @@ -24,6 +24,7 @@ import scalikejdbc.DB import vinyldns.core.domain.record._ import vinyldns.core.domain.zone.Zone import vinyldns.mysql.TestMySqlInstance +import vinyldns.mysql.repository.MySqlRecordSetRepository.PagingKey class MySqlRecordSetRepositoryIntegrationSpec extends WordSpec @@ -327,7 +328,7 @@ class MySqlRecordSetRepositoryIntegrationSpec "return record sets after the startFrom when set" in { // load 5, start after the 3rd, we should get back the last two val existing = insert(okZone, 5).map(_.recordSet).sortBy(_.name) - val startFrom = Some(existing(2).name) + val startFrom = Some(PagingKey.toNextId(existing(2))) val found = repo.listRecordSets(okZone.id, startFrom, None, None).unsafeRunSync() found.recordSets should contain theSameElementsInOrderAs existing.drop(3) @@ -335,7 +336,7 @@ class MySqlRecordSetRepositoryIntegrationSpec "return the record sets after the startFrom respecting maxItems" in { // load 5, start after the 2nd, take 2, we should get back the 3rd and 4th val existing = insert(okZone, 5).map(_.recordSet).sortBy(_.name) - val startFrom = Some(existing(1).name) + val startFrom = Some(PagingKey.toNextId(existing(1))) val found = repo.listRecordSets(okZone.id, startFrom, Some(2), None).unsafeRunSync() found.recordSets should contain theSameElementsInOrderAs existing.slice(2, 4) @@ -356,7 +357,7 @@ class MySqlRecordSetRepositoryIntegrationSpec val changes = newRecordSets.map(makeTestAddChange(_, okZone)) insert(changes) - val startFrom = Some(newRecordSets(1).name) + val startFrom = Some(PagingKey.toNextId(newRecordSets(1))) val found = repo.listRecordSets(okZone.id, startFrom, Some(3), Some("*z*")).unsafeRunSync() found.recordSets.map(_.name) should contain theSameElementsInOrderAs expectedNames } @@ -415,16 +416,42 @@ class MySqlRecordSetRepositoryIntegrationSpec val existing = insert(okZone, 5).map(_.recordSet).sortBy(_.name) val page1 = repo.listRecordSets(okZone.id, None, Some(2), None).unsafeRunSync() page1.recordSets should contain theSameElementsInOrderAs existing.slice(0, 2) - page1.nextId shouldBe Some(page1.recordSets(1).name) + page1.nextId shouldBe Some(PagingKey.toNextId(page1.recordSets(1))) val page2 = repo.listRecordSets(okZone.id, page1.nextId, Some(2), None).unsafeRunSync() page2.recordSets should contain theSameElementsInOrderAs existing.slice(2, 4) - page2.nextId shouldBe Some(page2.recordSets(1).name) + page2.nextId shouldBe Some(PagingKey.toNextId(page2.recordSets(1))) val page3 = repo.listRecordSets(okZone.id, page2.nextId, Some(2), None).unsafeRunSync() page3.recordSets should contain theSameElementsInOrderAs existing.slice(4, 5) page3.nextId shouldBe None } + + "page properly when records have the same name" in { + val changes = generateInserts(okZone, 5) + val editedChanges = List( + changes(0).copy(recordSet = aaaa.copy(zoneId = okZone.id, name = "a-duplicate")), + changes(2).copy(recordSet = cname.copy(zoneId = okZone.id, name = "a-duplicate")), + changes(4).copy(recordSet = ns.copy(zoneId = okZone.id, name = "a-duplicate")), + changes(1).copy(recordSet = changes(1).recordSet.copy(name = "b-unique")), + changes(3).copy(recordSet = changes(3).recordSet.copy(name = "c-unqiue")) + ) + + insert(editedChanges) + val existing = editedChanges.map(_.recordSet) + + val page1 = repo.listRecordSets(okZone.id, None, Some(2), None).unsafeRunSync() + page1.recordSets should contain theSameElementsInOrderAs List(existing(0), existing(1)) + page1.nextId shouldBe Some(PagingKey.toNextId(page1.recordSets.last)) + + val page2 = repo.listRecordSets(okZone.id, page1.nextId, Some(2), None).unsafeRunSync() + page2.recordSets should contain theSameElementsInOrderAs List(existing(2), existing(3)) + page2.nextId shouldBe Some(PagingKey.toNextId(page2.recordSets.last)) + + val page3 = repo.listRecordSets(okZone.id, page2.nextId, Some(2), None).unsafeRunSync() + page3.recordSets should contain theSameElementsInOrderAs List(existing(4)) + page3.nextId shouldBe None + } } "get record sets by name and type" should { "return a record set when there is a match" in { diff --git a/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlRecordSetRepository.scala b/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlRecordSetRepository.scala index 2fe6a0222..a9e71dd7f 100644 --- a/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlRecordSetRepository.scala +++ b/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlRecordSetRepository.scala @@ -26,6 +26,8 @@ import vinyldns.core.protobuf.ProtobufConversions import vinyldns.core.route.Monitored import vinyldns.proto.VinylDNSProto +import scala.util.Try + class MySqlRecordSetRepository extends RecordSetRepository with Monitored { import MySqlRecordSetRepository._ @@ -183,14 +185,19 @@ class MySqlRecordSetRepository extends RecordSetRepository with Monitored { monitor("repo.RecordSet.listRecordSets") { IO { DB.readOnly { implicit s => + val pagingKey = PagingKey(startFrom) + // make sure we sort ascending, so we can do the correct comparison later - val opts = (startFrom.as("AND name > {startFrom}") ++ - recordNameFilter.as("AND name LIKE {nameFilter}") ++ - Some("ORDER BY name ASC") ++ - maxItems.as("LIMIT {maxItems}")).toList.mkString(" ") + val opts = + (pagingKey.as( + "AND ((name >= {startFromName} AND type > {startFromType}) OR name > {startFromName})") ++ + recordNameFilter.as("AND name LIKE {nameFilter}") ++ + Some("ORDER BY name ASC, type ASC") ++ + maxItems.as("LIMIT {maxItems}")).toList.mkString(" ") val params = (Some('zoneId -> zoneId) ++ - startFrom.map(n => 'startFrom -> n) ++ + pagingKey.map(pk => 'startFromName -> pk.recordName) ++ + pagingKey.map(pk => 'startFromType -> pk.recordType) ++ recordNameFilter.map(f => 'nameFilter -> f.replace('*', '%')) ++ maxItems.map(m => 'maxItems -> m)).toSeq @@ -205,7 +212,9 @@ class MySqlRecordSetRepository extends RecordSetRepository with Monitored { // if size of results is less than the number returned, we don't have a next id // if maxItems is None, we don't have a next id val nextId = - maxItems.filter(_ == results.size).flatMap(_ => results.lastOption.map(_.name)) + maxItems + .filter(_ == results.size) + .flatMap(_ => results.lastOption.map(PagingKey.toNextId)) ListRecordSetResults( recordSets = results, @@ -357,4 +366,26 @@ object MySqlRecordSetRepository extends ProtobufConversions { if (absoluteRecordSetName.equals(absoluteZoneName)) absoluteZoneName else absoluteRecordSetName + absoluteZoneName } + + case class PagingKey(recordName: String, recordType: Int) + + object PagingKey { + val delimiterRegex = "\\.\\.\\.\\." + val delimiter = "...." + + def apply(startFrom: Option[String]): Option[PagingKey] = + for { + sf <- startFrom + tokens = sf.split(delimiterRegex) + recordName <- tokens.headOption + recordType <- Try(tokens(1).toInt).toOption + } yield PagingKey(recordName, recordType) + + def toNextId(last: RecordSet): String = { + val nextIdName = last.name + val nextIdType = MySqlRecordSetRepository.fromRecordType(last.typ) + + s"$nextIdName$delimiter$nextIdType" + } + } } diff --git a/modules/mysql/src/test/scala/vinyldns/mysql/repository/MySqlRecordSetRepositorySpec.scala b/modules/mysql/src/test/scala/vinyldns/mysql/repository/MySqlRecordSetRepositorySpec.scala index ccdccad88..e9304a08a 100644 --- a/modules/mysql/src/test/scala/vinyldns/mysql/repository/MySqlRecordSetRepositorySpec.scala +++ b/modules/mysql/src/test/scala/vinyldns/mysql/repository/MySqlRecordSetRepositorySpec.scala @@ -17,6 +17,7 @@ package vinyldns.mysql.repository import org.scalatest.{Matchers, WordSpec} import vinyldns.core.domain.record.RecordType +import vinyldns.core.TestRecordSetData.aaaa class MySqlRecordSetRepositorySpec extends WordSpec with Matchers { import MySqlRecordSetRepository._ @@ -70,4 +71,43 @@ class MySqlRecordSetRepositorySpec extends WordSpec with Matchers { toFQDN(zoneNameWithDot, recordNameWithDot) shouldBe expected } } + + "PagingKey.fromStartFrom" should { + "return None if startFrom is None" in { + PagingKey(None) shouldBe None + } + + "return None if startFrom is malformed" in { + val empty = "" + val noDelimiter = "nodelim" + val justDelimiter = s"${PagingKey.delimiter}" + val noType = s"name${PagingKey.delimiter}" + + PagingKey(Some(empty)) shouldBe None + PagingKey(Some(noDelimiter)) shouldBe None + PagingKey(Some(justDelimiter)) shouldBe None + PagingKey(Some(noType)) shouldBe None + } + + "return None if type is not an Int" in { + val startFrom = s"name${PagingKey.delimiter}notNumber" + PagingKey(Some(startFrom)) shouldBe None + } + + "return correct PagingKey" in { + val expected = PagingKey("name", 5) + val startFrom = s"${expected.recordName}${PagingKey.delimiter}${expected.recordType}" + PagingKey(Some(startFrom)) shouldBe Some(expected) + } + } + + "PagingKey.toNextId" should { + "return correct NextId" in { + val expectedName = "name" + val expectedType = MySqlRecordSetRepository.fromRecordType(RecordType.CNAME) + val last = aaaa.copy(name = expectedName, typ = RecordType.CNAME) + + PagingKey.toNextId(last) shouldBe s"$expectedName${PagingKey.delimiter}$expectedType" + } + } }