mirror of
https://github.com/VinylDNS/vinyldns
synced 2025-08-31 14:25:30 +00:00
add locked user account with auth check to api (#172)
add locked user account with auth check to api
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
from utils import *
|
||||
from hamcrest import *
|
||||
from vinyldns_python import VinylDNSClient
|
||||
from dns.resolver import *
|
||||
from vinyldns_context import VinylDNSTestContext
|
||||
|
||||
|
||||
def test_request_fails_when_user_account_is_locked():
|
||||
"""
|
||||
Test request fails with Forbidden (403) when user account is locked
|
||||
"""
|
||||
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'lockedAccessKey', 'lockedSecretKey')
|
||||
client.list_batch_change_summaries(status=403)
|
||||
|
||||
def test_request_fails_when_user_is_not_found():
|
||||
"""
|
||||
Test request fails with Unauthorized (401) when user account is not found
|
||||
"""
|
||||
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'unknownAccessKey', 'anyAccessSecretKey')
|
||||
|
||||
client.list_batch_change_summaries(status=401)
|
||||
|
||||
def test_request_succeeds_when_user_is_found_and_not_locked():
|
||||
"""
|
||||
Test request success with Success (200) when user account is found and not locked
|
||||
"""
|
||||
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'okAccessKey', 'okSecretKey')
|
||||
|
||||
client.list_batch_change_summaries(status=200)
|
@@ -93,7 +93,7 @@ def test_list_group_members_start_from(shared_zone_test_context):
|
||||
|
||||
# members has one more because admins are added as members
|
||||
assert_that(result['members'], has_length(len(members) + 1))
|
||||
assert_that(result['members'], has_item({ 'id': 'ok'}))
|
||||
assert_that(result['members'], has_item({ 'lockStatus': 'Unlocked', 'id': 'ok'}))
|
||||
result_member_ids = map(lambda member: member['id'], result['members'])
|
||||
for user in members:
|
||||
assert_that(result_member_ids, has_item(user['id']))
|
||||
|
@@ -101,7 +101,7 @@ class JdbcZoneRepositoryIntegrationSpec
|
||||
if (num == 1) z.addACLRule(dummyAclRule) else z
|
||||
}
|
||||
|
||||
private val superUserAuth = AuthPrincipal(dummyUser.copy(isSuper = true), Seq())
|
||||
private val jdbcSuperUserAuth = AuthPrincipal(dummyUser.copy(isSuper = true), Seq())
|
||||
|
||||
private def testZone(name: String, adminGroupId: String = testZoneAdminGroupId) =
|
||||
okZone.copy(name = name, id = UUID.randomUUID().toString, adminGroupId = adminGroupId)
|
||||
@@ -410,7 +410,7 @@ class JdbcZoneRepositoryIntegrationSpec
|
||||
val f =
|
||||
for {
|
||||
_ <- saveZones(testZones)
|
||||
retrieved <- repo.listZones(superUserAuth)
|
||||
retrieved <- repo.listZones(jdbcSuperUserAuth)
|
||||
} yield retrieved
|
||||
|
||||
whenReady(f.unsafeToFuture(), timeout) { retrieved =>
|
||||
@@ -431,7 +431,7 @@ class JdbcZoneRepositoryIntegrationSpec
|
||||
val f =
|
||||
for {
|
||||
_ <- saveZones(testZones)
|
||||
retrieved <- repo.listZones(superUserAuth, zoneNameFilter = Some("system"))
|
||||
retrieved <- repo.listZones(jdbcSuperUserAuth, zoneNameFilter = Some("system"))
|
||||
} yield retrieved
|
||||
|
||||
whenReady(f.unsafeToFuture(), timeout) { retrieved =>
|
||||
@@ -471,19 +471,19 @@ class JdbcZoneRepositoryIntegrationSpec
|
||||
|
||||
whenReady(saveZones(testZones).unsafeToFuture(), timeout) { _ =>
|
||||
whenReady(
|
||||
repo.listZones(superUserAuth, offset = None, pageSize = 4).unsafeToFuture(),
|
||||
repo.listZones(jdbcSuperUserAuth, offset = None, pageSize = 4).unsafeToFuture(),
|
||||
timeout) { firstPage =>
|
||||
(firstPage should contain).theSameElementsInOrderAs(expectedFirstPage)
|
||||
}
|
||||
|
||||
whenReady(
|
||||
repo.listZones(superUserAuth, offset = Some(4), pageSize = 4).unsafeToFuture(),
|
||||
repo.listZones(jdbcSuperUserAuth, offset = Some(4), pageSize = 4).unsafeToFuture(),
|
||||
timeout) { secondPage =>
|
||||
(secondPage should contain).theSameElementsInOrderAs(expectedSecondPage)
|
||||
}
|
||||
|
||||
whenReady(
|
||||
repo.listZones(superUserAuth, offset = Some(8), pageSize = 4).unsafeToFuture(),
|
||||
repo.listZones(jdbcSuperUserAuth, offset = Some(8), pageSize = 4).unsafeToFuture(),
|
||||
timeout) { thirdPage =>
|
||||
(thirdPage should contain).theSameElementsInOrderAs(expectedThirdPage)
|
||||
}
|
||||
|
@@ -21,6 +21,7 @@ import java.util.UUID
|
||||
import org.joda.time.DateTime
|
||||
import vinyldns.core.domain.membership.GroupChangeType.GroupChangeType
|
||||
import vinyldns.core.domain.membership.GroupStatus.GroupStatus
|
||||
import vinyldns.core.domain.membership.LockStatus.LockStatus
|
||||
import vinyldns.core.domain.membership._
|
||||
|
||||
/* This is the new View model for Groups, do not surface the Group model directly any more */
|
||||
@@ -73,7 +74,8 @@ case class UserInfo(
|
||||
firstName: Option[String] = None,
|
||||
lastName: Option[String] = None,
|
||||
email: Option[String] = None,
|
||||
created: Option[DateTime] = None
|
||||
created: Option[DateTime] = None,
|
||||
lockStatus: LockStatus = LockStatus.Unlocked
|
||||
)
|
||||
object UserInfo {
|
||||
def apply(user: User): UserInfo =
|
||||
@@ -83,7 +85,8 @@ object UserInfo {
|
||||
firstName = user.firstName,
|
||||
lastName = user.lastName,
|
||||
email = user.email,
|
||||
created = Some(user.created)
|
||||
created = Some(user.created),
|
||||
lockStatus = user.lockStatus
|
||||
)
|
||||
}
|
||||
|
||||
|
@@ -19,6 +19,7 @@ package vinyldns.api.domain.membership
|
||||
import cats.implicits._
|
||||
import vinyldns.api.Interfaces._
|
||||
import vinyldns.core.domain.auth.AuthPrincipal
|
||||
import vinyldns.core.domain.membership.LockStatus.LockStatus
|
||||
import vinyldns.core.domain.zone.ZoneRepository
|
||||
import vinyldns.core.domain.membership._
|
||||
|
||||
@@ -55,7 +56,7 @@ class MembershipService(
|
||||
for {
|
||||
existingGroup <- getExistingGroup(groupId)
|
||||
newGroup = existingGroup.withUpdates(name, email, description, memberIds, adminUserIds)
|
||||
_ <- isAdmin(existingGroup, authPrincipal).toResult
|
||||
_ <- isGroupAdmin(existingGroup, authPrincipal).toResult
|
||||
addedMembers = newGroup.memberIds.diff(existingGroup.memberIds)
|
||||
removedMembers = existingGroup.memberIds.diff(newGroup.memberIds)
|
||||
_ <- hasMembersAndAdmins(newGroup).toResult
|
||||
@@ -72,7 +73,7 @@ class MembershipService(
|
||||
def deleteGroup(groupId: String, authPrincipal: AuthPrincipal): Result[Group] =
|
||||
for {
|
||||
existingGroup <- getExistingGroup(groupId)
|
||||
_ <- isAdmin(existingGroup, authPrincipal).toResult
|
||||
_ <- isGroupAdmin(existingGroup, authPrincipal).toResult
|
||||
_ <- groupCanBeDeleted(existingGroup)
|
||||
_ <- groupChangeRepo
|
||||
.save(GroupChange.forDelete(existingGroup, authPrincipal))
|
||||
@@ -174,6 +175,12 @@ class MembershipService(
|
||||
.getUsers(userIds, startFrom, pageSize)
|
||||
.toResult[ListUsersResults]
|
||||
|
||||
def getExistingUser(userId: String): Result[User] =
|
||||
userRepo
|
||||
.getUser(userId)
|
||||
.orFail(UserNotFoundError(s"User with ID $userId was not found"))
|
||||
.toResult[User]
|
||||
|
||||
def getExistingGroup(groupId: String): Result[Group] =
|
||||
groupRepo
|
||||
.getGroup(groupId)
|
||||
@@ -222,4 +229,15 @@ class MembershipService(
|
||||
}
|
||||
}
|
||||
.toResult
|
||||
|
||||
def updateUserLockStatus(
|
||||
userId: String,
|
||||
lockStatus: LockStatus,
|
||||
authPrincipal: AuthPrincipal): Result[User] =
|
||||
for {
|
||||
_ <- isSuperAdmin(authPrincipal).toResult
|
||||
existingUser <- getExistingUser(userId)
|
||||
newUser = existingUser.updateUserLockStatus(lockStatus)
|
||||
_ <- userRepo.save(newUser).toResult[User]
|
||||
} yield newUser
|
||||
}
|
||||
|
@@ -18,6 +18,7 @@ package vinyldns.api.domain.membership
|
||||
|
||||
import vinyldns.api.Interfaces.Result
|
||||
import vinyldns.core.domain.auth.AuthPrincipal
|
||||
import vinyldns.core.domain.membership.LockStatus.LockStatus
|
||||
import vinyldns.core.domain.membership._
|
||||
|
||||
trait MembershipServiceAlgebra {
|
||||
@@ -56,4 +57,9 @@ trait MembershipServiceAlgebra {
|
||||
startFrom: Option[String],
|
||||
maxItems: Int,
|
||||
authPrincipal: AuthPrincipal): Result[ListGroupChangesResponse]
|
||||
|
||||
def updateUserLockStatus(
|
||||
userId: String,
|
||||
lockStatus: LockStatus,
|
||||
authPrincipal: AuthPrincipal): Result[User]
|
||||
}
|
||||
|
@@ -28,11 +28,16 @@ object MembershipValidations {
|
||||
group.memberIds.nonEmpty && group.adminUserIds.nonEmpty
|
||||
}
|
||||
|
||||
def isAdmin(group: Group, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
|
||||
def isGroupAdmin(group: Group, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
|
||||
ensuring(NotAuthorizedError("Not authorized")) {
|
||||
group.adminUserIds.contains(authPrincipal.userId) || authPrincipal.signedInUser.isSuper
|
||||
}
|
||||
|
||||
def isSuperAdmin(authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
|
||||
ensuring(NotAuthorizedError("Not authorized")) {
|
||||
authPrincipal.signedInUser.isSuper
|
||||
}
|
||||
|
||||
def canSeeGroup(groupId: String, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
|
||||
ensuring(NotAuthorizedError("Not authorized")) {
|
||||
authPrincipal.isAuthorized(groupId)
|
||||
|
@@ -51,7 +51,20 @@ object TestDataLoader {
|
||||
id = "dummy",
|
||||
created = DateTime.now.secondOfDay().roundFloorCopy(),
|
||||
accessKey = "dummyAccessKey",
|
||||
secretKey = "dummySecretKey")
|
||||
secretKey = "dummySecretKey"
|
||||
)
|
||||
final val lockedUser = User(
|
||||
userName = "locked",
|
||||
id = "locked",
|
||||
created = DateTime.now.secondOfDay().roundFloorCopy(),
|
||||
accessKey = "lockedAccessKey",
|
||||
secretKey = "lockedSecretKey",
|
||||
firstName = Some("Locked"),
|
||||
lastName = Some("User"),
|
||||
email = Some("testlocked@test.com"),
|
||||
isSuper = false,
|
||||
lockStatus = LockStatus.Locked
|
||||
)
|
||||
final val listOfDummyUsers: List[User] = List.range(0, 200).map { runner =>
|
||||
User(
|
||||
userName = "name-dummy%03d".format(runner),
|
||||
@@ -117,7 +130,7 @@ object TestDataLoader {
|
||||
)
|
||||
|
||||
def loadTestData(repository: UserRepository): IO[List[User]] =
|
||||
(testUser :: okUser :: dummyUser :: listGroupUser :: listZonesUser :: listBatchChangeSummariesUser ::
|
||||
(testUser :: okUser :: dummyUser :: lockedUser :: listGroupUser :: listZonesUser :: listBatchChangeSummariesUser ::
|
||||
listZeroBatchChangeSummariesUser :: zoneHistoryUser :: listOfDummyUsers).map { user =>
|
||||
val encrypted =
|
||||
if (VinylDNSConfig.encryptUserSecrets)
|
||||
|
@@ -23,7 +23,7 @@ import cats.implicits._
|
||||
import org.joda.time.DateTime
|
||||
import org.json4s._
|
||||
import vinyldns.api.domain.membership._
|
||||
import vinyldns.core.domain.membership.{Group, GroupChangeType, GroupStatus}
|
||||
import vinyldns.core.domain.membership.{Group, GroupChangeType, GroupStatus, LockStatus}
|
||||
|
||||
object MembershipJsonProtocol {
|
||||
final case class CreateGroupInput(
|
||||
@@ -52,6 +52,7 @@ trait MembershipJsonProtocol extends JsonValidation {
|
||||
GroupChangeInfoSerializer,
|
||||
CreateGroupInputSerializer,
|
||||
UpdateGroupInputSerializer,
|
||||
JsonEnumV(LockStatus),
|
||||
JsonEnumV(GroupStatus),
|
||||
JsonEnumV(GroupChangeType)
|
||||
)
|
||||
|
@@ -23,7 +23,7 @@ import vinyldns.api.domain.membership._
|
||||
import vinyldns.api.domain.zone.NotAuthorizedError
|
||||
import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput}
|
||||
import vinyldns.core.domain.auth.AuthPrincipal
|
||||
import vinyldns.core.domain.membership.Group
|
||||
import vinyldns.core.domain.membership.{Group, LockStatus}
|
||||
|
||||
trait MembershipRoute extends Directives {
|
||||
this: VinylDNSJsonProtocol with VinylDNSDirectives with JsonValidationRejection =>
|
||||
@@ -168,6 +168,18 @@ trait MembershipRoute extends Directives {
|
||||
}
|
||||
}
|
||||
}
|
||||
} ~
|
||||
(put & path("users" / Segment / "lock") & monitor("Endpoint.lockUser")) { id =>
|
||||
execute(membershipService.updateUserLockStatus(id, LockStatus.Locked, authPrincipal)) {
|
||||
user =>
|
||||
complete(StatusCodes.OK, UserInfo(user))
|
||||
}
|
||||
} ~
|
||||
(put & path("users" / Segment / "unlock") & monitor("Endpoint.unlockUser")) { id =>
|
||||
execute(membershipService.updateUserLockStatus(id, LockStatus.Unlocked, authPrincipal)) {
|
||||
user =>
|
||||
complete(StatusCodes.OK, UserInfo(user))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -17,8 +17,7 @@
|
||||
package vinyldns.api.route
|
||||
|
||||
import akka.http.scaladsl.model.HttpRequest
|
||||
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
|
||||
import akka.http.scaladsl.server.{AuthenticationFailedRejection, RequestContext}
|
||||
import akka.http.scaladsl.server.RequestContext
|
||||
import cats.effect._
|
||||
import cats.syntax.all._
|
||||
import vinyldns.api.VinylDNSConfig
|
||||
@@ -27,12 +26,14 @@ import vinyldns.api.domain.auth.{AuthPrincipalProvider, MembershipAuthPrincipalP
|
||||
import vinyldns.core.crypto.CryptoAlgebra
|
||||
import vinyldns.core.domain.auth.AuthPrincipal
|
||||
import vinyldns.core.route.Monitored
|
||||
import vinyldns.core.domain.membership.LockStatus
|
||||
|
||||
import scala.util.matching.Regex
|
||||
|
||||
sealed abstract class VinylDNSAuthenticationError(msg: String) extends Throwable(msg)
|
||||
final case class AuthMissing(msg: String) extends VinylDNSAuthenticationError(msg)
|
||||
final case class AuthRejected(reason: String) extends VinylDNSAuthenticationError(reason)
|
||||
final case class AccountLocked(reason: String) extends VinylDNSAuthenticationError(reason)
|
||||
|
||||
trait VinylDNSAuthentication extends Monitored {
|
||||
val authenticator: Aws4Authenticator
|
||||
@@ -131,8 +132,14 @@ trait VinylDNSAuthentication extends Monitored {
|
||||
if (encryptionEnabled) crypto.decrypt(str) else str
|
||||
|
||||
def getAuthPrincipal(accessKey: String): IO[AuthPrincipal] =
|
||||
authPrincipalProvider.getAuthPrincipal(accessKey).map {
|
||||
_.getOrElse(throw AuthRejected(s"Account with accessKey $accessKey specified was not found"))
|
||||
authPrincipalProvider.getAuthPrincipal(accessKey).flatMap {
|
||||
case Some(ok) =>
|
||||
if (ok.signedInUser.lockStatus == LockStatus.Locked) {
|
||||
IO.raiseError(
|
||||
AccountLocked(s"Account with username ${ok.signedInUser.userName} is locked"))
|
||||
} else IO.pure(ok)
|
||||
case None =>
|
||||
IO.raiseError(AuthRejected(s"Account with accessKey $accessKey specified was not found"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,16 +148,14 @@ class VinylDNSAuthenticator(
|
||||
val authPrincipalProvider: AuthPrincipalProvider)
|
||||
extends VinylDNSAuthentication {
|
||||
|
||||
def apply(ctx: RequestContext, content: String): IO[Either[Cause, AuthPrincipal]] =
|
||||
authenticate(ctx, content).attempt.map {
|
||||
case Right(ok) => Right(ok)
|
||||
case Left(_: AuthMissing) =>
|
||||
Left(AuthenticationFailedRejection.CredentialsMissing)
|
||||
case Left(_: AuthRejected) =>
|
||||
Left(AuthenticationFailedRejection.CredentialsRejected)
|
||||
case Left(e: Throwable) =>
|
||||
// throw here as some unexpected exception occurred
|
||||
throw e
|
||||
def apply(
|
||||
ctx: RequestContext,
|
||||
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
|
||||
// Need to refactor authenticate to be an IO[Either[E, A]] instead of how it is implemented, for the time being...
|
||||
authenticate(ctx, content).attempt.flatMap {
|
||||
case Left(e: VinylDNSAuthenticationError) => IO.pure(Left(e))
|
||||
case Right(ok) => IO.pure(Right(ok))
|
||||
case Left(e) => IO.raiseError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,6 +164,8 @@ object VinylDNSAuthenticator {
|
||||
lazy val authPrincipalProvider = MembershipAuthPrincipalProvider()
|
||||
lazy val authenticator = new VinylDNSAuthenticator(aws4Authenticator, authPrincipalProvider)
|
||||
|
||||
def apply(ctx: RequestContext, content: String): IO[Either[Cause, AuthPrincipal]] =
|
||||
def apply(
|
||||
ctx: RequestContext,
|
||||
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
|
||||
authenticator.apply(ctx, content)
|
||||
}
|
||||
|
@@ -17,7 +17,6 @@
|
||||
package vinyldns.api.route
|
||||
|
||||
import akka.http.scaladsl.model.{HttpEntity, HttpResponse, StatusCodes}
|
||||
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
|
||||
import akka.http.scaladsl.server.RouteResult.{Complete, Rejected}
|
||||
import akka.http.scaladsl.server._
|
||||
import akka.http.scaladsl.server.directives.BasicDirectives
|
||||
@@ -42,7 +41,7 @@ trait VinylDNSDirectives extends Directives {
|
||||
*/
|
||||
def vinyldnsAuthenticator(
|
||||
ctx: RequestContext,
|
||||
content: String): IO[Either[Cause, AuthPrincipal]] =
|
||||
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
|
||||
VinylDNSAuthenticator(ctx, content)
|
||||
|
||||
def authenticate: Directive1[AuthPrincipal] =
|
||||
@@ -53,19 +52,27 @@ trait VinylDNSDirectives extends Directives {
|
||||
.flatMap {
|
||||
case Right(authPrincipal) ⇒
|
||||
provide(authPrincipal)
|
||||
case Left(cause) ⇒
|
||||
// we need to finish the result, rejections will proceed and ultimately
|
||||
// we can fail with a different rejection
|
||||
complete(
|
||||
HttpResponse(
|
||||
status = StatusCodes.Unauthorized,
|
||||
entity = HttpEntity(s"Authentication Failed: $cause")
|
||||
))
|
||||
case Left(e) ⇒
|
||||
complete(handleAuthenticateError(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def handleAuthenticateError(error: VinylDNSAuthenticationError): HttpResponse =
|
||||
error match {
|
||||
case AccountLocked(err) =>
|
||||
HttpResponse(
|
||||
status = StatusCodes.Forbidden,
|
||||
entity = HttpEntity(s"Authentication Failed: $err")
|
||||
)
|
||||
case e =>
|
||||
HttpResponse(
|
||||
status = StatusCodes.Unauthorized,
|
||||
entity = HttpEntity(s"Authentication Failed: ${e.getMessage}")
|
||||
)
|
||||
}
|
||||
|
||||
/* Adds monitoring to an Endpoint. The name will be surfaced in JMX */
|
||||
def monitor(name: String): Directive0 =
|
||||
extractExecutionContext.flatMap { implicit ec ⇒
|
||||
|
@@ -30,8 +30,9 @@ trait GroupTestData { this: Matchers =>
|
||||
|
||||
val okUser: User = TestDataLoader.okUser
|
||||
val dummyUser: User = TestDataLoader.dummyUser
|
||||
val listOfDummyUsers: List[User] = TestDataLoader.listOfDummyUsers
|
||||
val lockedUser: User = TestDataLoader.lockedUser
|
||||
|
||||
val listOfDummyUsers: List[User] = TestDataLoader.listOfDummyUsers
|
||||
val okUserInfo: UserInfo = UserInfo(okUser)
|
||||
val dummyUserInfo: UserInfo = UserInfo(dummyUser)
|
||||
|
||||
@@ -91,6 +92,7 @@ trait GroupTestData { this: Matchers =>
|
||||
val noGroupsUserAuth: AuthPrincipal = AuthPrincipal(okUser, Seq())
|
||||
val deletedGroupAuth: AuthPrincipal = AuthPrincipal(okUser, Seq(deletedGroup.id))
|
||||
val dummyUserAuth: AuthPrincipal = AuthPrincipal(dummyUser, Seq(dummyGroup.id))
|
||||
val lockedUserAuth: AuthPrincipal = AuthPrincipal(lockedUser, Seq())
|
||||
val listOfDummyGroupsAuth: AuthPrincipal = AuthPrincipal(dummyUser, listOfDummyGroups.map(_.id))
|
||||
|
||||
val memberOkZoneAuthorized: Zone = Zone(
|
||||
|
@@ -39,6 +39,7 @@ trait VinylDNSTestData {
|
||||
created = DateTime.now.secondOfDay().roundFloorCopy())
|
||||
val okAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.okUser, Seq(grp.id))
|
||||
val notAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.dummyUser, Seq.empty)
|
||||
val lockedAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.lockedUser, Seq.empty)
|
||||
|
||||
val testConnection: Option[ZoneConnection] = Some(
|
||||
ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1"))
|
||||
|
@@ -23,7 +23,7 @@ import org.mockito.Mockito._
|
||||
import org.scalatest.mockito.MockitoSugar
|
||||
import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec}
|
||||
import vinyldns.api.Interfaces._
|
||||
import vinyldns.api.{GroupTestData, ResultHelpers}
|
||||
import vinyldns.api.{GroupTestData, ResultHelpers, VinylDNSTestData}
|
||||
import vinyldns.core.domain.auth.AuthPrincipal
|
||||
import vinyldns.core.domain.zone.{ZoneRepository, _}
|
||||
import cats.effect._
|
||||
@@ -37,6 +37,7 @@ class MembershipServiceSpec
|
||||
with BeforeAndAfterEach
|
||||
with ResultHelpers
|
||||
with GroupTestData
|
||||
with VinylDNSTestData
|
||||
with EitherMatchers {
|
||||
|
||||
private val mockGroupRepo = mock[GroupRepository]
|
||||
@@ -750,5 +751,72 @@ class MembershipServiceSpec
|
||||
error shouldBe a[InvalidGroupRequestError]
|
||||
}
|
||||
}
|
||||
|
||||
"updateUserLockStatus" should {
|
||||
"save the update and lock the user account" in {
|
||||
val superUserAuth = okAuth.copy(
|
||||
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
|
||||
memberGroupIds = Seq.empty)
|
||||
doReturn(IO.pure(Some(okUser))).when(mockUserRepo).getUser(okUser.id)
|
||||
doReturn(IO.pure(okUser)).when(mockUserRepo).save(any[User])
|
||||
|
||||
underTest
|
||||
.updateUserLockStatus(okUser.id, LockStatus.Locked, superUserAuth)
|
||||
.value
|
||||
.unsafeRunSync()
|
||||
|
||||
val userCaptor = ArgumentCaptor.forClass(classOf[User])
|
||||
|
||||
verify(mockUserRepo).save(userCaptor.capture())
|
||||
|
||||
val savedUser = userCaptor.getValue
|
||||
savedUser.lockStatus shouldBe LockStatus.Locked
|
||||
savedUser.id shouldBe okUser.id
|
||||
}
|
||||
|
||||
"save the update and unlock the user account" in {
|
||||
val superUserAuth = okAuth.copy(
|
||||
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
|
||||
memberGroupIds = Seq.empty)
|
||||
doReturn(IO.pure(Some(lockedUser))).when(mockUserRepo).getUser(lockedUser.id)
|
||||
doReturn(IO.pure(okUser)).when(mockUserRepo).save(any[User])
|
||||
|
||||
underTest
|
||||
.updateUserLockStatus(lockedUser.id, LockStatus.Unlocked, superUserAuth)
|
||||
.value
|
||||
.unsafeRunSync()
|
||||
|
||||
val userCaptor = ArgumentCaptor.forClass(classOf[User])
|
||||
|
||||
verify(mockUserRepo).save(userCaptor.capture())
|
||||
|
||||
val savedUser = userCaptor.getValue
|
||||
savedUser.lockStatus shouldBe LockStatus.Unlocked
|
||||
savedUser.id shouldBe lockedUser.id
|
||||
}
|
||||
|
||||
"return an error if the signed in user is not a super user" in {
|
||||
val error = leftResultOf(
|
||||
underTest
|
||||
.updateUserLockStatus(okUser.id, LockStatus.Locked, dummyUserAuth)
|
||||
.value)
|
||||
|
||||
error shouldBe a[NotAuthorizedError]
|
||||
}
|
||||
|
||||
"return an error if the requested user is not found" in {
|
||||
val superUserAuth = okAuth.copy(
|
||||
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
|
||||
memberGroupIds = Seq.empty)
|
||||
doReturn(IO.pure(None)).when(mockUserRepo).getUser(okUser.id)
|
||||
|
||||
val error = leftResultOf(
|
||||
underTest
|
||||
.updateUserLockStatus(okUser.id, LockStatus.Locked, superUserAuth)
|
||||
.value)
|
||||
|
||||
error shouldBe a[UserNotFoundError]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -56,17 +56,17 @@ class MembershipValidationsSpec
|
||||
|
||||
"isAdmin" should {
|
||||
"return true when the user is in admin group" in {
|
||||
isAdmin(okGroup, okUserAuth) should be(right)
|
||||
isGroupAdmin(okGroup, okUserAuth) should be(right)
|
||||
}
|
||||
"return true when the user is a super user" in {
|
||||
val user = User("some", "new", "user", isSuper = true)
|
||||
val superAuth = AuthPrincipal(user, Seq())
|
||||
isAdmin(okGroup, superAuth) should be(right)
|
||||
isGroupAdmin(okGroup, superAuth) should be(right)
|
||||
}
|
||||
"return an error when the user has no access and is not super" in {
|
||||
val user = User("some", "new", "user")
|
||||
val nonSuperAuth = AuthPrincipal(user, Seq())
|
||||
val error = leftValue(isAdmin(okGroup, nonSuperAuth))
|
||||
val error = leftValue(isGroupAdmin(okGroup, nonSuperAuth))
|
||||
error shouldBe an[NotAuthorizedError]
|
||||
}
|
||||
}
|
||||
|
@@ -30,10 +30,11 @@ import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec}
|
||||
import vinyldns.api.Interfaces._
|
||||
import vinyldns.api.domain.membership._
|
||||
import vinyldns.core.domain.auth.AuthPrincipal
|
||||
import vinyldns.core.domain.membership.Group
|
||||
import vinyldns.core.domain.membership.{Group, LockStatus}
|
||||
import vinyldns.api.domain.zone.NotAuthorizedError
|
||||
import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput}
|
||||
import vinyldns.api.{GroupTestData, VinylDNSTestData}
|
||||
import vinyldns.core.domain.membership.LockStatus.LockStatus
|
||||
|
||||
class MembershipRoutingSpec
|
||||
extends WordSpec
|
||||
@@ -671,4 +672,64 @@ class MembershipRoutingSpec
|
||||
}
|
||||
}
|
||||
}
|
||||
"PUT update user lock status" should {
|
||||
"return a 200 response with the user locked" in {
|
||||
val updatedUser = okUser.copy(lockStatus = LockStatus.Locked)
|
||||
val superUserAuth = okAuth.copy(
|
||||
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
|
||||
memberGroupIds = Seq.empty)
|
||||
doReturn(result(updatedUser))
|
||||
.when(membershipService)
|
||||
.updateUserLockStatus("ok", LockStatus.Locked, superUserAuth)
|
||||
|
||||
Put("/users/ok/lock") ~> membershipRoute(superUserAuth) ~> check {
|
||||
status shouldBe StatusCodes.OK
|
||||
|
||||
val result = responseAs[UserInfo]
|
||||
|
||||
result.id shouldBe okUser.id
|
||||
result.lockStatus shouldBe LockStatus.Locked
|
||||
}
|
||||
}
|
||||
|
||||
"return a 200 response with the user unlocked" in {
|
||||
val updatedUser = lockedUser.copy(lockStatus = LockStatus.Unlocked)
|
||||
val superUserAuth = okAuth.copy(
|
||||
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
|
||||
memberGroupIds = Seq.empty)
|
||||
doReturn(result(updatedUser))
|
||||
.when(membershipService)
|
||||
.updateUserLockStatus("locked", LockStatus.Unlocked, superUserAuth)
|
||||
|
||||
Put("/users/locked/unlock") ~> membershipRoute(superUserAuth) ~> check {
|
||||
status shouldBe StatusCodes.OK
|
||||
|
||||
val result = responseAs[UserInfo]
|
||||
|
||||
result.id shouldBe lockedUser.id
|
||||
result.lockStatus shouldBe LockStatus.Unlocked
|
||||
}
|
||||
}
|
||||
|
||||
"return a 404 Not Found when the user is not found" in {
|
||||
val superUserAuth = okAuth.copy(
|
||||
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
|
||||
memberGroupIds = Seq.empty)
|
||||
doReturn(result(UserNotFoundError("fail")))
|
||||
.when(membershipService)
|
||||
.updateUserLockStatus(anyString, any[LockStatus], any[AuthPrincipal])
|
||||
Put("/users/notFound/lock") ~> membershipRoute(superUserAuth) ~> check {
|
||||
status shouldBe StatusCodes.NotFound
|
||||
}
|
||||
}
|
||||
|
||||
"return a 403 Forbidden when not authorized" in {
|
||||
doReturn(result(NotAuthorizedError("fail")))
|
||||
.when(membershipService)
|
||||
.updateUserLockStatus(anyString, any[LockStatus], any[AuthPrincipal])
|
||||
Put("/users/forbidden/lock") ~> membershipRoute(okGroupAuth) ~> check {
|
||||
status shouldBe StatusCodes.Forbidden
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -17,7 +17,6 @@
|
||||
package vinyldns.api.route
|
||||
|
||||
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest, StatusCodes}
|
||||
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
|
||||
import akka.http.scaladsl.server.{Directives, RequestContext, Route}
|
||||
import akka.http.scaladsl.testkit.ScalatestRouteTest
|
||||
import cats.effect._
|
||||
@@ -485,7 +484,7 @@ class RecordSetRoutingSpec
|
||||
|
||||
override def vinyldnsAuthenticator(
|
||||
ctx: RequestContext,
|
||||
content: String): IO[Either[Cause, AuthPrincipal]] =
|
||||
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
|
||||
IO.pure(Right(okAuth))
|
||||
|
||||
private def rsJson(recordSet: RecordSet): String =
|
||||
|
@@ -17,11 +17,6 @@
|
||||
package vinyldns.api.route
|
||||
|
||||
import akka.http.scaladsl.model.{HttpHeader, HttpRequest}
|
||||
import akka.http.scaladsl.server.AuthenticationFailedRejection.{
|
||||
Cause,
|
||||
CredentialsMissing,
|
||||
CredentialsRejected
|
||||
}
|
||||
import akka.http.scaladsl.server.RequestContext
|
||||
import cats.effect._
|
||||
import org.mockito.Matchers._
|
||||
@@ -29,15 +24,13 @@ import org.mockito.Mockito._
|
||||
import org.scalatest.mockito.MockitoSugar
|
||||
import org.scalatest.{Matchers, WordSpec}
|
||||
import vinyldns.api.domain.auth.AuthPrincipalProvider
|
||||
import vinyldns.api.{GroupTestData, ResultHelpers}
|
||||
import vinyldns.api.{GroupTestData}
|
||||
import vinyldns.core.crypto.CryptoAlgebra
|
||||
import vinyldns.core.domain.auth.AuthPrincipal
|
||||
|
||||
class VinylDNSAuthenticatorSpec
|
||||
extends WordSpec
|
||||
with Matchers
|
||||
with MockitoSugar
|
||||
with ResultHelpers
|
||||
with GroupTestData {
|
||||
private val mockAuthenticator = mock[Aws4Authenticator]
|
||||
private val mockAuthPrincipalProvider = mock[AuthPrincipalProvider]
|
||||
@@ -87,8 +80,7 @@ class VinylDNSAuthenticatorSpec
|
||||
.when(mockAuthenticator)
|
||||
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
|
||||
|
||||
val result =
|
||||
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Right(okUserAuth)
|
||||
}
|
||||
"fail if missing Authorization header" in {
|
||||
@@ -109,9 +101,8 @@ class VinylDNSAuthenticatorSpec
|
||||
.when(mockAuthenticator)
|
||||
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
|
||||
|
||||
val result =
|
||||
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
|
||||
result shouldBe Left(CredentialsMissing)
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Left(AuthMissing("Authorization header not found"))
|
||||
}
|
||||
"fail if Authorization header can not be parsed" in {
|
||||
val fakeHttpHeader = mock[HttpHeader]
|
||||
@@ -125,9 +116,8 @@ class VinylDNSAuthenticatorSpec
|
||||
val context: RequestContext = mock[RequestContext]
|
||||
doReturn(httpRequest).when(context).request
|
||||
|
||||
val result =
|
||||
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
|
||||
result shouldBe Left(CredentialsRejected)
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Left(AuthRejected("Authorization header could not be parsed"))
|
||||
}
|
||||
"fail if the access key is missing" in {
|
||||
val fakeHttpHeader = mock[HttpHeader]
|
||||
@@ -149,9 +139,8 @@ class VinylDNSAuthenticatorSpec
|
||||
.when(mockAuthenticator)
|
||||
.extractAccessKey(any[String])
|
||||
|
||||
val result =
|
||||
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
|
||||
result shouldBe Left(CredentialsMissing)
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Left(AuthMissing("accessKey not found"))
|
||||
}
|
||||
"fail if the access key can not be retrieved" in {
|
||||
val fakeHttpHeader = mock[HttpHeader]
|
||||
@@ -173,9 +162,34 @@ class VinylDNSAuthenticatorSpec
|
||||
.when(mockAuthenticator)
|
||||
.extractAccessKey(any[String])
|
||||
|
||||
val result =
|
||||
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
|
||||
result shouldBe Left(CredentialsRejected)
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Left(AuthRejected("Invalid authorization header"))
|
||||
}
|
||||
"fail if the user is locked" in {
|
||||
val fakeHttpHeader = mock[HttpHeader]
|
||||
doReturn("Authorization").when(fakeHttpHeader).name
|
||||
|
||||
val header = "AWS4-HMAC-SHA256" +
|
||||
" Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request," +
|
||||
" SignedHeaders=host;range;x-amz-date," +
|
||||
" Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
doReturn(header).when(fakeHttpHeader).value
|
||||
|
||||
val httpRequest: HttpRequest = HttpRequest().withHeaders(List(fakeHttpHeader))
|
||||
|
||||
val context: RequestContext = mock[RequestContext]
|
||||
doReturn(httpRequest).when(context).request
|
||||
|
||||
doReturn(lockedUser.accessKey)
|
||||
.when(mockAuthenticator)
|
||||
.extractAccessKey(any[String])
|
||||
|
||||
doReturn(IO.pure(Some(lockedUserAuth)))
|
||||
.when(mockAuthPrincipalProvider)
|
||||
.getAuthPrincipal(any[String])
|
||||
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Left(AccountLocked("Account with username locked is locked"))
|
||||
}
|
||||
"fail if the user can not be found" in {
|
||||
val fakeHttpHeader = mock[HttpHeader]
|
||||
@@ -192,7 +206,7 @@ class VinylDNSAuthenticatorSpec
|
||||
val context: RequestContext = mock[RequestContext]
|
||||
doReturn(httpRequest).when(context).request
|
||||
|
||||
doReturn(okUser.accessKey)
|
||||
doReturn("fakeKey")
|
||||
.when(mockAuthenticator)
|
||||
.extractAccessKey(any[String])
|
||||
|
||||
@@ -201,9 +215,8 @@ class VinylDNSAuthenticatorSpec
|
||||
.when(mockAuthPrincipalProvider)
|
||||
.getAuthPrincipal(any[String])
|
||||
|
||||
val result =
|
||||
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
|
||||
result shouldBe Left(CredentialsRejected)
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Left(AuthRejected("Account with accessKey fakeKey specified was not found"))
|
||||
}
|
||||
"fail if signatures can not be validated" in {
|
||||
val fakeHttpHeader = mock[HttpHeader]
|
||||
@@ -233,9 +246,8 @@ class VinylDNSAuthenticatorSpec
|
||||
.when(mockAuthenticator)
|
||||
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
|
||||
|
||||
val result =
|
||||
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
|
||||
result shouldBe Left(CredentialsRejected)
|
||||
val result = underTest.apply(context, "").unsafeRunSync()
|
||||
result shouldBe Left(AuthRejected("Request signature could not be validated"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -18,7 +18,7 @@ package vinyldns.api.route
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import akka.http.scaladsl.model.{HttpResponse, StatusCodes}
|
||||
import akka.http.scaladsl.model.{HttpEntity, HttpResponse, StatusCodes}
|
||||
import akka.http.scaladsl.server.{Directives, Route}
|
||||
import akka.http.scaladsl.testkit.ScalatestRouteTest
|
||||
import nl.grons.metrics.scala.{Histogram, Meter}
|
||||
@@ -66,6 +66,26 @@ class VinylDNSDirectivesSpec
|
||||
override def beforeEach(): Unit =
|
||||
reset(mockLatency, mockErrors)
|
||||
|
||||
".handleAuthenticateError" should {
|
||||
"respond with Forbidden status if account is locked" in {
|
||||
val trythis = handleAuthenticateError(AccountLocked("error"))
|
||||
|
||||
trythis shouldBe HttpResponse(
|
||||
status = StatusCodes.Forbidden,
|
||||
entity = HttpEntity(s"Authentication Failed: error")
|
||||
)
|
||||
}
|
||||
|
||||
"respond with Unauthorized status for other authentication errors" in {
|
||||
val trythis = handleAuthenticateError(AuthRejected("error"))
|
||||
|
||||
trythis shouldBe HttpResponse(
|
||||
status = StatusCodes.Unauthorized,
|
||||
entity = HttpEntity(s"Authentication Failed: error")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
"The monitor directive" should {
|
||||
"record when completing an HttpResponse normally" in {
|
||||
Get("/test") ~> testRoute ~> check {
|
||||
|
@@ -19,7 +19,6 @@ package vinyldns.api.route
|
||||
import akka.actor.ActorSystem
|
||||
import akka.http.scaladsl.model.StatusCodes._
|
||||
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest}
|
||||
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
|
||||
import akka.http.scaladsl.server.{Directives, RequestContext, Route}
|
||||
import akka.http.scaladsl.testkit.ScalatestRouteTest
|
||||
import cats.effect._
|
||||
@@ -332,7 +331,7 @@ class ZoneRoutingSpec
|
||||
|
||||
override def vinyldnsAuthenticator(
|
||||
ctx: RequestContext,
|
||||
content: String): IO[Either[Cause, AuthPrincipal]] =
|
||||
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
|
||||
IO.pure(Right(okAuth))
|
||||
|
||||
def zoneJson(name: String, email: String): String =
|
||||
|
@@ -19,6 +19,12 @@ package vinyldns.core.domain.membership
|
||||
import java.util.UUID
|
||||
|
||||
import org.joda.time.DateTime
|
||||
import vinyldns.core.domain.membership.LockStatus.LockStatus
|
||||
|
||||
object LockStatus extends Enumeration {
|
||||
type LockStatus = Value
|
||||
val Locked, Unlocked = Value
|
||||
}
|
||||
|
||||
case class User(
|
||||
userName: String,
|
||||
@@ -29,5 +35,10 @@ case class User(
|
||||
email: Option[String] = None,
|
||||
created: DateTime = DateTime.now,
|
||||
id: String = UUID.randomUUID().toString,
|
||||
isSuper: Boolean = false
|
||||
)
|
||||
isSuper: Boolean = false,
|
||||
lockStatus: LockStatus = LockStatus.Unlocked
|
||||
) {
|
||||
|
||||
def updateUserLockStatus(lockStatus: LockStatus): User =
|
||||
this.copy(lockStatus = lockStatus)
|
||||
}
|
||||
|
@@ -20,8 +20,7 @@ import cats.implicits._
|
||||
import com.amazonaws.services.dynamodbv2.model.DeleteTableRequest
|
||||
import com.typesafe.config.ConfigFactory
|
||||
import vinyldns.core.crypto.NoOpCrypto
|
||||
import vinyldns.core.domain.membership.User
|
||||
|
||||
import vinyldns.core.domain.membership.{User, LockStatus}
|
||||
import scala.concurrent.duration._
|
||||
|
||||
class DynamoDBUserRepositoryIntegrationSpec extends DynamoDBIntegrationSpec {
|
||||
@@ -135,5 +134,24 @@ class DynamoDBUserRepositoryIntegrationSpec extends DynamoDBIntegrationSpec {
|
||||
result shouldBe Some(testUser)
|
||||
result.get.isSuper shouldBe false
|
||||
}
|
||||
"returns the locked flag when true" in {
|
||||
val testUser = User(
|
||||
userName = "testSuper",
|
||||
accessKey = "testSuper",
|
||||
secretKey = "testUser",
|
||||
lockStatus = LockStatus.Locked)
|
||||
|
||||
val saved = repo.save(testUser).unsafeRunSync()
|
||||
val result = repo.getUser(saved.id).unsafeRunSync()
|
||||
|
||||
result shouldBe Some(testUser)
|
||||
result.get.lockStatus shouldBe LockStatus.Locked
|
||||
}
|
||||
"returns the locked flag when false" in {
|
||||
val f = repo.getUserByAccessKey(users.head.accessKey).unsafeRunSync()
|
||||
|
||||
f shouldBe Some(users.head)
|
||||
f.get.lockStatus shouldBe LockStatus.Unlocked
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -26,10 +26,12 @@ import com.amazonaws.services.dynamodbv2.model._
|
||||
import org.joda.time.DateTime
|
||||
import org.slf4j.{Logger, LoggerFactory}
|
||||
import vinyldns.core.crypto.CryptoAlgebra
|
||||
import vinyldns.core.domain.membership.{ListUsersResults, User, UserRepository}
|
||||
import vinyldns.core.domain.membership.LockStatus.LockStatus
|
||||
import vinyldns.core.domain.membership.{ListUsersResults, LockStatus, User, UserRepository}
|
||||
import vinyldns.core.route.Monitored
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.util.Try
|
||||
|
||||
object DynamoDBUserRepository {
|
||||
|
||||
@@ -42,6 +44,7 @@ object DynamoDBUserRepository {
|
||||
private[repository] val ACCESS_KEY = "accesskey"
|
||||
private[repository] val SECRET_KEY = "secretkey"
|
||||
private[repository] val IS_SUPER = "super"
|
||||
private[repository] val LOCK_STATUS = "lockstatus"
|
||||
private[repository] val USER_NAME_INDEX_NAME = "username_index"
|
||||
private[repository] val ACCESS_KEY_INDEX_NAME = "access_key_index"
|
||||
|
||||
@@ -97,6 +100,7 @@ object DynamoDBUserRepository {
|
||||
item.put(ACCESS_KEY, new AttributeValue(user.accessKey))
|
||||
item.put(SECRET_KEY, new AttributeValue(crypto.encrypt(user.secretKey)))
|
||||
item.put(IS_SUPER, new AttributeValue().withBOOL(user.isSuper))
|
||||
item.put(LOCK_STATUS, new AttributeValue(user.lockStatus.toString))
|
||||
|
||||
val firstName =
|
||||
user.firstName.map(new AttributeValue(_)).getOrElse(new AttributeValue().withNULL(true))
|
||||
@@ -110,6 +114,12 @@ object DynamoDBUserRepository {
|
||||
}
|
||||
|
||||
def fromItem(item: java.util.Map[String, AttributeValue]): IO[User] = IO {
|
||||
def userStatus(str: String): LockStatus = Try(LockStatus.withName(str)).getOrElse {
|
||||
val log: Logger = LoggerFactory.getLogger(classOf[DynamoDBUserRepository])
|
||||
log.error(s"Invalid locked status value '$str'; defaulting to unlocked")
|
||||
LockStatus.Unlocked
|
||||
}
|
||||
|
||||
User(
|
||||
id = item.get(USER_ID).getS,
|
||||
userName = item.get(USER_NAME).getS,
|
||||
@@ -119,7 +129,8 @@ object DynamoDBUserRepository {
|
||||
firstName = Option(item.get(FIRST_NAME)).flatMap(fn => Option(fn.getS)),
|
||||
lastName = Option(item.get(LAST_NAME)).flatMap(ln => Option(ln.getS)),
|
||||
email = Option(item.get(EMAIL)).flatMap(e => Option(e.getS)),
|
||||
isSuper = if (item.get(IS_SUPER) == null) false else item.get(IS_SUPER).getBOOL
|
||||
isSuper = if (item.get(IS_SUPER) == null) false else item.get(IS_SUPER).getBOOL,
|
||||
lockStatus = userStatus(item.get(LOCK_STATUS).getS)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@@ -31,6 +31,7 @@ import scala.collection.JavaConverters._
|
||||
import cats.effect._
|
||||
import com.typesafe.config.ConfigFactory
|
||||
import vinyldns.core.crypto.{CryptoAlgebra, NoOpCrypto}
|
||||
import vinyldns.core.domain.membership.LockStatus
|
||||
import vinyldns.dynamodb.DynamoTestConfig
|
||||
|
||||
class DynamoDBUserRepositorySpec
|
||||
@@ -72,6 +73,7 @@ class DynamoDBUserRepositorySpec
|
||||
items.get(LAST_NAME).getS shouldBe okUser.lastName.get
|
||||
items.get(EMAIL).getS shouldBe okUser.email.get
|
||||
items.get(CREATED).getN shouldBe okUser.created.getMillis.toString
|
||||
items.get(LOCK_STATUS).getS shouldBe okUser.lockStatus.toString
|
||||
}
|
||||
"set the first name to null if it is not present" in {
|
||||
val emptyFirstName = okUser.copy(firstName = None)
|
||||
@@ -131,6 +133,7 @@ class DynamoDBUserRepositorySpec
|
||||
item.put(CREATED, new AttributeValue().withN("0"))
|
||||
item.put(ACCESS_KEY, new AttributeValue("accessKey"))
|
||||
item.put(SECRET_KEY, new AttributeValue("secretkey"))
|
||||
item.put(LOCK_STATUS, new AttributeValue("lockstatus"))
|
||||
val user = fromItem(item).unsafeRunSync()
|
||||
|
||||
user.firstName shouldBe None
|
||||
@@ -151,10 +154,24 @@ class DynamoDBUserRepositorySpec
|
||||
item.put(CREATED, new AttributeValue().withN("0"))
|
||||
item.put(ACCESS_KEY, new AttributeValue("accesskey"))
|
||||
item.put(SECRET_KEY, new AttributeValue("secretkey"))
|
||||
item.put(LOCK_STATUS, new AttributeValue("Locked"))
|
||||
val user = fromItem(item).unsafeRunSync()
|
||||
|
||||
user.isSuper shouldBe false
|
||||
}
|
||||
|
||||
"sets the lockStatus to Unlocked if the given value is invalid" in {
|
||||
val item = new java.util.HashMap[String, AttributeValue]()
|
||||
item.put(USER_ID, new AttributeValue("ok"))
|
||||
item.put(USER_NAME, new AttributeValue("ok"))
|
||||
item.put(CREATED, new AttributeValue().withN("0"))
|
||||
item.put(ACCESS_KEY, new AttributeValue("accesskey"))
|
||||
item.put(SECRET_KEY, new AttributeValue("secretkey"))
|
||||
item.put(LOCK_STATUS, new AttributeValue("lock_status"))
|
||||
val user = fromItem(item).unsafeRunSync()
|
||||
|
||||
user.lockStatus shouldBe LockStatus.Unlocked
|
||||
}
|
||||
}
|
||||
|
||||
"DynamoDBUserRepository.getUser" should {
|
||||
|
Reference in New Issue
Block a user