2
0
mirror of https://github.com/VinylDNS/vinyldns synced 2025-09-02 15:25:44 +00:00

add locked user account with auth check to api (#172)

add locked user account with auth check to api
This commit is contained in:
Britney Wright
2018-09-17 12:07:18 -04:00
committed by GitHub
parent 82f292a442
commit d8876d369f
25 changed files with 407 additions and 87 deletions

View File

@@ -0,0 +1,29 @@
from utils import *
from hamcrest import *
from vinyldns_python import VinylDNSClient
from dns.resolver import *
from vinyldns_context import VinylDNSTestContext
def test_request_fails_when_user_account_is_locked():
"""
Test request fails with Forbidden (403) when user account is locked
"""
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'lockedAccessKey', 'lockedSecretKey')
client.list_batch_change_summaries(status=403)
def test_request_fails_when_user_is_not_found():
"""
Test request fails with Unauthorized (401) when user account is not found
"""
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'unknownAccessKey', 'anyAccessSecretKey')
client.list_batch_change_summaries(status=401)
def test_request_succeeds_when_user_is_found_and_not_locked():
"""
Test request success with Success (200) when user account is found and not locked
"""
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'okAccessKey', 'okSecretKey')
client.list_batch_change_summaries(status=200)

View File

@@ -93,7 +93,7 @@ def test_list_group_members_start_from(shared_zone_test_context):
# members has one more because admins are added as members # members has one more because admins are added as members
assert_that(result['members'], has_length(len(members) + 1)) assert_that(result['members'], has_length(len(members) + 1))
assert_that(result['members'], has_item({ 'id': 'ok'})) assert_that(result['members'], has_item({ 'lockStatus': 'Unlocked', 'id': 'ok'}))
result_member_ids = map(lambda member: member['id'], result['members']) result_member_ids = map(lambda member: member['id'], result['members'])
for user in members: for user in members:
assert_that(result_member_ids, has_item(user['id'])) assert_that(result_member_ids, has_item(user['id']))

View File

@@ -101,7 +101,7 @@ class JdbcZoneRepositoryIntegrationSpec
if (num == 1) z.addACLRule(dummyAclRule) else z if (num == 1) z.addACLRule(dummyAclRule) else z
} }
private val superUserAuth = AuthPrincipal(dummyUser.copy(isSuper = true), Seq()) private val jdbcSuperUserAuth = AuthPrincipal(dummyUser.copy(isSuper = true), Seq())
private def testZone(name: String, adminGroupId: String = testZoneAdminGroupId) = private def testZone(name: String, adminGroupId: String = testZoneAdminGroupId) =
okZone.copy(name = name, id = UUID.randomUUID().toString, adminGroupId = adminGroupId) okZone.copy(name = name, id = UUID.randomUUID().toString, adminGroupId = adminGroupId)
@@ -410,7 +410,7 @@ class JdbcZoneRepositoryIntegrationSpec
val f = val f =
for { for {
_ <- saveZones(testZones) _ <- saveZones(testZones)
retrieved <- repo.listZones(superUserAuth) retrieved <- repo.listZones(jdbcSuperUserAuth)
} yield retrieved } yield retrieved
whenReady(f.unsafeToFuture(), timeout) { retrieved => whenReady(f.unsafeToFuture(), timeout) { retrieved =>
@@ -431,7 +431,7 @@ class JdbcZoneRepositoryIntegrationSpec
val f = val f =
for { for {
_ <- saveZones(testZones) _ <- saveZones(testZones)
retrieved <- repo.listZones(superUserAuth, zoneNameFilter = Some("system")) retrieved <- repo.listZones(jdbcSuperUserAuth, zoneNameFilter = Some("system"))
} yield retrieved } yield retrieved
whenReady(f.unsafeToFuture(), timeout) { retrieved => whenReady(f.unsafeToFuture(), timeout) { retrieved =>
@@ -471,19 +471,19 @@ class JdbcZoneRepositoryIntegrationSpec
whenReady(saveZones(testZones).unsafeToFuture(), timeout) { _ => whenReady(saveZones(testZones).unsafeToFuture(), timeout) { _ =>
whenReady( whenReady(
repo.listZones(superUserAuth, offset = None, pageSize = 4).unsafeToFuture(), repo.listZones(jdbcSuperUserAuth, offset = None, pageSize = 4).unsafeToFuture(),
timeout) { firstPage => timeout) { firstPage =>
(firstPage should contain).theSameElementsInOrderAs(expectedFirstPage) (firstPage should contain).theSameElementsInOrderAs(expectedFirstPage)
} }
whenReady( whenReady(
repo.listZones(superUserAuth, offset = Some(4), pageSize = 4).unsafeToFuture(), repo.listZones(jdbcSuperUserAuth, offset = Some(4), pageSize = 4).unsafeToFuture(),
timeout) { secondPage => timeout) { secondPage =>
(secondPage should contain).theSameElementsInOrderAs(expectedSecondPage) (secondPage should contain).theSameElementsInOrderAs(expectedSecondPage)
} }
whenReady( whenReady(
repo.listZones(superUserAuth, offset = Some(8), pageSize = 4).unsafeToFuture(), repo.listZones(jdbcSuperUserAuth, offset = Some(8), pageSize = 4).unsafeToFuture(),
timeout) { thirdPage => timeout) { thirdPage =>
(thirdPage should contain).theSameElementsInOrderAs(expectedThirdPage) (thirdPage should contain).theSameElementsInOrderAs(expectedThirdPage)
} }

View File

@@ -21,6 +21,7 @@ import java.util.UUID
import org.joda.time.DateTime import org.joda.time.DateTime
import vinyldns.core.domain.membership.GroupChangeType.GroupChangeType import vinyldns.core.domain.membership.GroupChangeType.GroupChangeType
import vinyldns.core.domain.membership.GroupStatus.GroupStatus import vinyldns.core.domain.membership.GroupStatus.GroupStatus
import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.membership._ import vinyldns.core.domain.membership._
/* This is the new View model for Groups, do not surface the Group model directly any more */ /* This is the new View model for Groups, do not surface the Group model directly any more */
@@ -73,7 +74,8 @@ case class UserInfo(
firstName: Option[String] = None, firstName: Option[String] = None,
lastName: Option[String] = None, lastName: Option[String] = None,
email: Option[String] = None, email: Option[String] = None,
created: Option[DateTime] = None created: Option[DateTime] = None,
lockStatus: LockStatus = LockStatus.Unlocked
) )
object UserInfo { object UserInfo {
def apply(user: User): UserInfo = def apply(user: User): UserInfo =
@@ -83,7 +85,8 @@ object UserInfo {
firstName = user.firstName, firstName = user.firstName,
lastName = user.lastName, lastName = user.lastName,
email = user.email, email = user.email,
created = Some(user.created) created = Some(user.created),
lockStatus = user.lockStatus
) )
} }

View File

@@ -19,6 +19,7 @@ package vinyldns.api.domain.membership
import cats.implicits._ import cats.implicits._
import vinyldns.api.Interfaces._ import vinyldns.api.Interfaces._
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.zone.ZoneRepository import vinyldns.core.domain.zone.ZoneRepository
import vinyldns.core.domain.membership._ import vinyldns.core.domain.membership._
@@ -55,7 +56,7 @@ class MembershipService(
for { for {
existingGroup <- getExistingGroup(groupId) existingGroup <- getExistingGroup(groupId)
newGroup = existingGroup.withUpdates(name, email, description, memberIds, adminUserIds) newGroup = existingGroup.withUpdates(name, email, description, memberIds, adminUserIds)
_ <- isAdmin(existingGroup, authPrincipal).toResult _ <- isGroupAdmin(existingGroup, authPrincipal).toResult
addedMembers = newGroup.memberIds.diff(existingGroup.memberIds) addedMembers = newGroup.memberIds.diff(existingGroup.memberIds)
removedMembers = existingGroup.memberIds.diff(newGroup.memberIds) removedMembers = existingGroup.memberIds.diff(newGroup.memberIds)
_ <- hasMembersAndAdmins(newGroup).toResult _ <- hasMembersAndAdmins(newGroup).toResult
@@ -72,7 +73,7 @@ class MembershipService(
def deleteGroup(groupId: String, authPrincipal: AuthPrincipal): Result[Group] = def deleteGroup(groupId: String, authPrincipal: AuthPrincipal): Result[Group] =
for { for {
existingGroup <- getExistingGroup(groupId) existingGroup <- getExistingGroup(groupId)
_ <- isAdmin(existingGroup, authPrincipal).toResult _ <- isGroupAdmin(existingGroup, authPrincipal).toResult
_ <- groupCanBeDeleted(existingGroup) _ <- groupCanBeDeleted(existingGroup)
_ <- groupChangeRepo _ <- groupChangeRepo
.save(GroupChange.forDelete(existingGroup, authPrincipal)) .save(GroupChange.forDelete(existingGroup, authPrincipal))
@@ -174,6 +175,12 @@ class MembershipService(
.getUsers(userIds, startFrom, pageSize) .getUsers(userIds, startFrom, pageSize)
.toResult[ListUsersResults] .toResult[ListUsersResults]
def getExistingUser(userId: String): Result[User] =
userRepo
.getUser(userId)
.orFail(UserNotFoundError(s"User with ID $userId was not found"))
.toResult[User]
def getExistingGroup(groupId: String): Result[Group] = def getExistingGroup(groupId: String): Result[Group] =
groupRepo groupRepo
.getGroup(groupId) .getGroup(groupId)
@@ -222,4 +229,15 @@ class MembershipService(
} }
} }
.toResult .toResult
def updateUserLockStatus(
userId: String,
lockStatus: LockStatus,
authPrincipal: AuthPrincipal): Result[User] =
for {
_ <- isSuperAdmin(authPrincipal).toResult
existingUser <- getExistingUser(userId)
newUser = existingUser.updateUserLockStatus(lockStatus)
_ <- userRepo.save(newUser).toResult[User]
} yield newUser
} }

View File

@@ -18,6 +18,7 @@ package vinyldns.api.domain.membership
import vinyldns.api.Interfaces.Result import vinyldns.api.Interfaces.Result
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.membership._ import vinyldns.core.domain.membership._
trait MembershipServiceAlgebra { trait MembershipServiceAlgebra {
@@ -56,4 +57,9 @@ trait MembershipServiceAlgebra {
startFrom: Option[String], startFrom: Option[String],
maxItems: Int, maxItems: Int,
authPrincipal: AuthPrincipal): Result[ListGroupChangesResponse] authPrincipal: AuthPrincipal): Result[ListGroupChangesResponse]
def updateUserLockStatus(
userId: String,
lockStatus: LockStatus,
authPrincipal: AuthPrincipal): Result[User]
} }

View File

@@ -28,11 +28,16 @@ object MembershipValidations {
group.memberIds.nonEmpty && group.adminUserIds.nonEmpty group.memberIds.nonEmpty && group.adminUserIds.nonEmpty
} }
def isAdmin(group: Group, authPrincipal: AuthPrincipal): Either[Throwable, Unit] = def isGroupAdmin(group: Group, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError("Not authorized")) { ensuring(NotAuthorizedError("Not authorized")) {
group.adminUserIds.contains(authPrincipal.userId) || authPrincipal.signedInUser.isSuper group.adminUserIds.contains(authPrincipal.userId) || authPrincipal.signedInUser.isSuper
} }
def isSuperAdmin(authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError("Not authorized")) {
authPrincipal.signedInUser.isSuper
}
def canSeeGroup(groupId: String, authPrincipal: AuthPrincipal): Either[Throwable, Unit] = def canSeeGroup(groupId: String, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError("Not authorized")) { ensuring(NotAuthorizedError("Not authorized")) {
authPrincipal.isAuthorized(groupId) authPrincipal.isAuthorized(groupId)

View File

@@ -51,7 +51,20 @@ object TestDataLoader {
id = "dummy", id = "dummy",
created = DateTime.now.secondOfDay().roundFloorCopy(), created = DateTime.now.secondOfDay().roundFloorCopy(),
accessKey = "dummyAccessKey", accessKey = "dummyAccessKey",
secretKey = "dummySecretKey") secretKey = "dummySecretKey"
)
final val lockedUser = User(
userName = "locked",
id = "locked",
created = DateTime.now.secondOfDay().roundFloorCopy(),
accessKey = "lockedAccessKey",
secretKey = "lockedSecretKey",
firstName = Some("Locked"),
lastName = Some("User"),
email = Some("testlocked@test.com"),
isSuper = false,
lockStatus = LockStatus.Locked
)
final val listOfDummyUsers: List[User] = List.range(0, 200).map { runner => final val listOfDummyUsers: List[User] = List.range(0, 200).map { runner =>
User( User(
userName = "name-dummy%03d".format(runner), userName = "name-dummy%03d".format(runner),
@@ -117,7 +130,7 @@ object TestDataLoader {
) )
def loadTestData(repository: UserRepository): IO[List[User]] = def loadTestData(repository: UserRepository): IO[List[User]] =
(testUser :: okUser :: dummyUser :: listGroupUser :: listZonesUser :: listBatchChangeSummariesUser :: (testUser :: okUser :: dummyUser :: lockedUser :: listGroupUser :: listZonesUser :: listBatchChangeSummariesUser ::
listZeroBatchChangeSummariesUser :: zoneHistoryUser :: listOfDummyUsers).map { user => listZeroBatchChangeSummariesUser :: zoneHistoryUser :: listOfDummyUsers).map { user =>
val encrypted = val encrypted =
if (VinylDNSConfig.encryptUserSecrets) if (VinylDNSConfig.encryptUserSecrets)

View File

@@ -23,7 +23,7 @@ import cats.implicits._
import org.joda.time.DateTime import org.joda.time.DateTime
import org.json4s._ import org.json4s._
import vinyldns.api.domain.membership._ import vinyldns.api.domain.membership._
import vinyldns.core.domain.membership.{Group, GroupChangeType, GroupStatus} import vinyldns.core.domain.membership.{Group, GroupChangeType, GroupStatus, LockStatus}
object MembershipJsonProtocol { object MembershipJsonProtocol {
final case class CreateGroupInput( final case class CreateGroupInput(
@@ -52,6 +52,7 @@ trait MembershipJsonProtocol extends JsonValidation {
GroupChangeInfoSerializer, GroupChangeInfoSerializer,
CreateGroupInputSerializer, CreateGroupInputSerializer,
UpdateGroupInputSerializer, UpdateGroupInputSerializer,
JsonEnumV(LockStatus),
JsonEnumV(GroupStatus), JsonEnumV(GroupStatus),
JsonEnumV(GroupChangeType) JsonEnumV(GroupChangeType)
) )

View File

@@ -23,7 +23,7 @@ import vinyldns.api.domain.membership._
import vinyldns.api.domain.zone.NotAuthorizedError import vinyldns.api.domain.zone.NotAuthorizedError
import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput} import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput}
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.Group import vinyldns.core.domain.membership.{Group, LockStatus}
trait MembershipRoute extends Directives { trait MembershipRoute extends Directives {
this: VinylDNSJsonProtocol with VinylDNSDirectives with JsonValidationRejection => this: VinylDNSJsonProtocol with VinylDNSDirectives with JsonValidationRejection =>
@@ -168,6 +168,18 @@ trait MembershipRoute extends Directives {
} }
} }
} }
} ~
(put & path("users" / Segment / "lock") & monitor("Endpoint.lockUser")) { id =>
execute(membershipService.updateUserLockStatus(id, LockStatus.Locked, authPrincipal)) {
user =>
complete(StatusCodes.OK, UserInfo(user))
}
} ~
(put & path("users" / Segment / "unlock") & monitor("Endpoint.unlockUser")) { id =>
execute(membershipService.updateUserLockStatus(id, LockStatus.Unlocked, authPrincipal)) {
user =>
complete(StatusCodes.OK, UserInfo(user))
}
} }
} }

View File

@@ -17,8 +17,7 @@
package vinyldns.api.route package vinyldns.api.route
import akka.http.scaladsl.model.HttpRequest import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause import akka.http.scaladsl.server.RequestContext
import akka.http.scaladsl.server.{AuthenticationFailedRejection, RequestContext}
import cats.effect._ import cats.effect._
import cats.syntax.all._ import cats.syntax.all._
import vinyldns.api.VinylDNSConfig import vinyldns.api.VinylDNSConfig
@@ -27,12 +26,14 @@ import vinyldns.api.domain.auth.{AuthPrincipalProvider, MembershipAuthPrincipalP
import vinyldns.core.crypto.CryptoAlgebra import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.route.Monitored import vinyldns.core.route.Monitored
import vinyldns.core.domain.membership.LockStatus
import scala.util.matching.Regex import scala.util.matching.Regex
sealed abstract class VinylDNSAuthenticationError(msg: String) extends Throwable(msg) sealed abstract class VinylDNSAuthenticationError(msg: String) extends Throwable(msg)
final case class AuthMissing(msg: String) extends VinylDNSAuthenticationError(msg) final case class AuthMissing(msg: String) extends VinylDNSAuthenticationError(msg)
final case class AuthRejected(reason: String) extends VinylDNSAuthenticationError(reason) final case class AuthRejected(reason: String) extends VinylDNSAuthenticationError(reason)
final case class AccountLocked(reason: String) extends VinylDNSAuthenticationError(reason)
trait VinylDNSAuthentication extends Monitored { trait VinylDNSAuthentication extends Monitored {
val authenticator: Aws4Authenticator val authenticator: Aws4Authenticator
@@ -131,8 +132,14 @@ trait VinylDNSAuthentication extends Monitored {
if (encryptionEnabled) crypto.decrypt(str) else str if (encryptionEnabled) crypto.decrypt(str) else str
def getAuthPrincipal(accessKey: String): IO[AuthPrincipal] = def getAuthPrincipal(accessKey: String): IO[AuthPrincipal] =
authPrincipalProvider.getAuthPrincipal(accessKey).map { authPrincipalProvider.getAuthPrincipal(accessKey).flatMap {
_.getOrElse(throw AuthRejected(s"Account with accessKey $accessKey specified was not found")) case Some(ok) =>
if (ok.signedInUser.lockStatus == LockStatus.Locked) {
IO.raiseError(
AccountLocked(s"Account with username ${ok.signedInUser.userName} is locked"))
} else IO.pure(ok)
case None =>
IO.raiseError(AuthRejected(s"Account with accessKey $accessKey specified was not found"))
} }
} }
@@ -141,16 +148,14 @@ class VinylDNSAuthenticator(
val authPrincipalProvider: AuthPrincipalProvider) val authPrincipalProvider: AuthPrincipalProvider)
extends VinylDNSAuthentication { extends VinylDNSAuthentication {
def apply(ctx: RequestContext, content: String): IO[Either[Cause, AuthPrincipal]] = def apply(
authenticate(ctx, content).attempt.map { ctx: RequestContext,
case Right(ok) => Right(ok) content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
case Left(_: AuthMissing) => // Need to refactor authenticate to be an IO[Either[E, A]] instead of how it is implemented, for the time being...
Left(AuthenticationFailedRejection.CredentialsMissing) authenticate(ctx, content).attempt.flatMap {
case Left(_: AuthRejected) => case Left(e: VinylDNSAuthenticationError) => IO.pure(Left(e))
Left(AuthenticationFailedRejection.CredentialsRejected) case Right(ok) => IO.pure(Right(ok))
case Left(e: Throwable) => case Left(e) => IO.raiseError(e)
// throw here as some unexpected exception occurred
throw e
} }
} }
@@ -159,6 +164,8 @@ object VinylDNSAuthenticator {
lazy val authPrincipalProvider = MembershipAuthPrincipalProvider() lazy val authPrincipalProvider = MembershipAuthPrincipalProvider()
lazy val authenticator = new VinylDNSAuthenticator(aws4Authenticator, authPrincipalProvider) lazy val authenticator = new VinylDNSAuthenticator(aws4Authenticator, authPrincipalProvider)
def apply(ctx: RequestContext, content: String): IO[Either[Cause, AuthPrincipal]] = def apply(
ctx: RequestContext,
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
authenticator.apply(ctx, content) authenticator.apply(ctx, content)
} }

View File

@@ -17,7 +17,6 @@
package vinyldns.api.route package vinyldns.api.route
import akka.http.scaladsl.model.{HttpEntity, HttpResponse, StatusCodes} import akka.http.scaladsl.model.{HttpEntity, HttpResponse, StatusCodes}
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
import akka.http.scaladsl.server.RouteResult.{Complete, Rejected} import akka.http.scaladsl.server.RouteResult.{Complete, Rejected}
import akka.http.scaladsl.server._ import akka.http.scaladsl.server._
import akka.http.scaladsl.server.directives.BasicDirectives import akka.http.scaladsl.server.directives.BasicDirectives
@@ -42,7 +41,7 @@ trait VinylDNSDirectives extends Directives {
*/ */
def vinyldnsAuthenticator( def vinyldnsAuthenticator(
ctx: RequestContext, ctx: RequestContext,
content: String): IO[Either[Cause, AuthPrincipal]] = content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
VinylDNSAuthenticator(ctx, content) VinylDNSAuthenticator(ctx, content)
def authenticate: Directive1[AuthPrincipal] = def authenticate: Directive1[AuthPrincipal] =
@@ -53,17 +52,25 @@ trait VinylDNSDirectives extends Directives {
.flatMap { .flatMap {
case Right(authPrincipal) case Right(authPrincipal)
provide(authPrincipal) provide(authPrincipal)
case Left(cause) case Left(e)
// we need to finish the result, rejections will proceed and ultimately complete(handleAuthenticateError(e))
// we can fail with a different rejection }
complete( }
}
}
def handleAuthenticateError(error: VinylDNSAuthenticationError): HttpResponse =
error match {
case AccountLocked(err) =>
HttpResponse(
status = StatusCodes.Forbidden,
entity = HttpEntity(s"Authentication Failed: $err")
)
case e =>
HttpResponse( HttpResponse(
status = StatusCodes.Unauthorized, status = StatusCodes.Unauthorized,
entity = HttpEntity(s"Authentication Failed: $cause") entity = HttpEntity(s"Authentication Failed: ${e.getMessage}")
)) )
}
}
}
} }
/* Adds monitoring to an Endpoint. The name will be surfaced in JMX */ /* Adds monitoring to an Endpoint. The name will be surfaced in JMX */

View File

@@ -30,8 +30,9 @@ trait GroupTestData { this: Matchers =>
val okUser: User = TestDataLoader.okUser val okUser: User = TestDataLoader.okUser
val dummyUser: User = TestDataLoader.dummyUser val dummyUser: User = TestDataLoader.dummyUser
val listOfDummyUsers: List[User] = TestDataLoader.listOfDummyUsers val lockedUser: User = TestDataLoader.lockedUser
val listOfDummyUsers: List[User] = TestDataLoader.listOfDummyUsers
val okUserInfo: UserInfo = UserInfo(okUser) val okUserInfo: UserInfo = UserInfo(okUser)
val dummyUserInfo: UserInfo = UserInfo(dummyUser) val dummyUserInfo: UserInfo = UserInfo(dummyUser)
@@ -91,6 +92,7 @@ trait GroupTestData { this: Matchers =>
val noGroupsUserAuth: AuthPrincipal = AuthPrincipal(okUser, Seq()) val noGroupsUserAuth: AuthPrincipal = AuthPrincipal(okUser, Seq())
val deletedGroupAuth: AuthPrincipal = AuthPrincipal(okUser, Seq(deletedGroup.id)) val deletedGroupAuth: AuthPrincipal = AuthPrincipal(okUser, Seq(deletedGroup.id))
val dummyUserAuth: AuthPrincipal = AuthPrincipal(dummyUser, Seq(dummyGroup.id)) val dummyUserAuth: AuthPrincipal = AuthPrincipal(dummyUser, Seq(dummyGroup.id))
val lockedUserAuth: AuthPrincipal = AuthPrincipal(lockedUser, Seq())
val listOfDummyGroupsAuth: AuthPrincipal = AuthPrincipal(dummyUser, listOfDummyGroups.map(_.id)) val listOfDummyGroupsAuth: AuthPrincipal = AuthPrincipal(dummyUser, listOfDummyGroups.map(_.id))
val memberOkZoneAuthorized: Zone = Zone( val memberOkZoneAuthorized: Zone = Zone(

View File

@@ -39,6 +39,7 @@ trait VinylDNSTestData {
created = DateTime.now.secondOfDay().roundFloorCopy()) created = DateTime.now.secondOfDay().roundFloorCopy())
val okAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.okUser, Seq(grp.id)) val okAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.okUser, Seq(grp.id))
val notAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.dummyUser, Seq.empty) val notAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.dummyUser, Seq.empty)
val lockedAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.lockedUser, Seq.empty)
val testConnection: Option[ZoneConnection] = Some( val testConnection: Option[ZoneConnection] = Some(
ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1")) ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1"))

View File

@@ -23,7 +23,7 @@ import org.mockito.Mockito._
import org.scalatest.mockito.MockitoSugar import org.scalatest.mockito.MockitoSugar
import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec} import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec}
import vinyldns.api.Interfaces._ import vinyldns.api.Interfaces._
import vinyldns.api.{GroupTestData, ResultHelpers} import vinyldns.api.{GroupTestData, ResultHelpers, VinylDNSTestData}
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.zone.{ZoneRepository, _} import vinyldns.core.domain.zone.{ZoneRepository, _}
import cats.effect._ import cats.effect._
@@ -37,6 +37,7 @@ class MembershipServiceSpec
with BeforeAndAfterEach with BeforeAndAfterEach
with ResultHelpers with ResultHelpers
with GroupTestData with GroupTestData
with VinylDNSTestData
with EitherMatchers { with EitherMatchers {
private val mockGroupRepo = mock[GroupRepository] private val mockGroupRepo = mock[GroupRepository]
@@ -750,5 +751,72 @@ class MembershipServiceSpec
error shouldBe a[InvalidGroupRequestError] error shouldBe a[InvalidGroupRequestError]
} }
} }
"updateUserLockStatus" should {
"save the update and lock the user account" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(IO.pure(Some(okUser))).when(mockUserRepo).getUser(okUser.id)
doReturn(IO.pure(okUser)).when(mockUserRepo).save(any[User])
underTest
.updateUserLockStatus(okUser.id, LockStatus.Locked, superUserAuth)
.value
.unsafeRunSync()
val userCaptor = ArgumentCaptor.forClass(classOf[User])
verify(mockUserRepo).save(userCaptor.capture())
val savedUser = userCaptor.getValue
savedUser.lockStatus shouldBe LockStatus.Locked
savedUser.id shouldBe okUser.id
}
"save the update and unlock the user account" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(IO.pure(Some(lockedUser))).when(mockUserRepo).getUser(lockedUser.id)
doReturn(IO.pure(okUser)).when(mockUserRepo).save(any[User])
underTest
.updateUserLockStatus(lockedUser.id, LockStatus.Unlocked, superUserAuth)
.value
.unsafeRunSync()
val userCaptor = ArgumentCaptor.forClass(classOf[User])
verify(mockUserRepo).save(userCaptor.capture())
val savedUser = userCaptor.getValue
savedUser.lockStatus shouldBe LockStatus.Unlocked
savedUser.id shouldBe lockedUser.id
}
"return an error if the signed in user is not a super user" in {
val error = leftResultOf(
underTest
.updateUserLockStatus(okUser.id, LockStatus.Locked, dummyUserAuth)
.value)
error shouldBe a[NotAuthorizedError]
}
"return an error if the requested user is not found" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(IO.pure(None)).when(mockUserRepo).getUser(okUser.id)
val error = leftResultOf(
underTest
.updateUserLockStatus(okUser.id, LockStatus.Locked, superUserAuth)
.value)
error shouldBe a[UserNotFoundError]
}
}
} }
} }

View File

@@ -56,17 +56,17 @@ class MembershipValidationsSpec
"isAdmin" should { "isAdmin" should {
"return true when the user is in admin group" in { "return true when the user is in admin group" in {
isAdmin(okGroup, okUserAuth) should be(right) isGroupAdmin(okGroup, okUserAuth) should be(right)
} }
"return true when the user is a super user" in { "return true when the user is a super user" in {
val user = User("some", "new", "user", isSuper = true) val user = User("some", "new", "user", isSuper = true)
val superAuth = AuthPrincipal(user, Seq()) val superAuth = AuthPrincipal(user, Seq())
isAdmin(okGroup, superAuth) should be(right) isGroupAdmin(okGroup, superAuth) should be(right)
} }
"return an error when the user has no access and is not super" in { "return an error when the user has no access and is not super" in {
val user = User("some", "new", "user") val user = User("some", "new", "user")
val nonSuperAuth = AuthPrincipal(user, Seq()) val nonSuperAuth = AuthPrincipal(user, Seq())
val error = leftValue(isAdmin(okGroup, nonSuperAuth)) val error = leftValue(isGroupAdmin(okGroup, nonSuperAuth))
error shouldBe an[NotAuthorizedError] error shouldBe an[NotAuthorizedError]
} }
} }

View File

@@ -30,10 +30,11 @@ import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec}
import vinyldns.api.Interfaces._ import vinyldns.api.Interfaces._
import vinyldns.api.domain.membership._ import vinyldns.api.domain.membership._
import vinyldns.core.domain.auth.AuthPrincipal import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.Group import vinyldns.core.domain.membership.{Group, LockStatus}
import vinyldns.api.domain.zone.NotAuthorizedError import vinyldns.api.domain.zone.NotAuthorizedError
import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput} import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput}
import vinyldns.api.{GroupTestData, VinylDNSTestData} import vinyldns.api.{GroupTestData, VinylDNSTestData}
import vinyldns.core.domain.membership.LockStatus.LockStatus
class MembershipRoutingSpec class MembershipRoutingSpec
extends WordSpec extends WordSpec
@@ -671,4 +672,64 @@ class MembershipRoutingSpec
} }
} }
} }
"PUT update user lock status" should {
"return a 200 response with the user locked" in {
val updatedUser = okUser.copy(lockStatus = LockStatus.Locked)
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(result(updatedUser))
.when(membershipService)
.updateUserLockStatus("ok", LockStatus.Locked, superUserAuth)
Put("/users/ok/lock") ~> membershipRoute(superUserAuth) ~> check {
status shouldBe StatusCodes.OK
val result = responseAs[UserInfo]
result.id shouldBe okUser.id
result.lockStatus shouldBe LockStatus.Locked
}
}
"return a 200 response with the user unlocked" in {
val updatedUser = lockedUser.copy(lockStatus = LockStatus.Unlocked)
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(result(updatedUser))
.when(membershipService)
.updateUserLockStatus("locked", LockStatus.Unlocked, superUserAuth)
Put("/users/locked/unlock") ~> membershipRoute(superUserAuth) ~> check {
status shouldBe StatusCodes.OK
val result = responseAs[UserInfo]
result.id shouldBe lockedUser.id
result.lockStatus shouldBe LockStatus.Unlocked
}
}
"return a 404 Not Found when the user is not found" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(result(UserNotFoundError("fail")))
.when(membershipService)
.updateUserLockStatus(anyString, any[LockStatus], any[AuthPrincipal])
Put("/users/notFound/lock") ~> membershipRoute(superUserAuth) ~> check {
status shouldBe StatusCodes.NotFound
}
}
"return a 403 Forbidden when not authorized" in {
doReturn(result(NotAuthorizedError("fail")))
.when(membershipService)
.updateUserLockStatus(anyString, any[LockStatus], any[AuthPrincipal])
Put("/users/forbidden/lock") ~> membershipRoute(okGroupAuth) ~> check {
status shouldBe StatusCodes.Forbidden
}
}
}
} }

View File

@@ -17,7 +17,6 @@
package vinyldns.api.route package vinyldns.api.route
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest, StatusCodes} import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest, StatusCodes}
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
import akka.http.scaladsl.server.{Directives, RequestContext, Route} import akka.http.scaladsl.server.{Directives, RequestContext, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest import akka.http.scaladsl.testkit.ScalatestRouteTest
import cats.effect._ import cats.effect._
@@ -485,7 +484,7 @@ class RecordSetRoutingSpec
override def vinyldnsAuthenticator( override def vinyldnsAuthenticator(
ctx: RequestContext, ctx: RequestContext,
content: String): IO[Either[Cause, AuthPrincipal]] = content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
IO.pure(Right(okAuth)) IO.pure(Right(okAuth))
private def rsJson(recordSet: RecordSet): String = private def rsJson(recordSet: RecordSet): String =

View File

@@ -17,11 +17,6 @@
package vinyldns.api.route package vinyldns.api.route
import akka.http.scaladsl.model.{HttpHeader, HttpRequest} import akka.http.scaladsl.model.{HttpHeader, HttpRequest}
import akka.http.scaladsl.server.AuthenticationFailedRejection.{
Cause,
CredentialsMissing,
CredentialsRejected
}
import akka.http.scaladsl.server.RequestContext import akka.http.scaladsl.server.RequestContext
import cats.effect._ import cats.effect._
import org.mockito.Matchers._ import org.mockito.Matchers._
@@ -29,15 +24,13 @@ import org.mockito.Mockito._
import org.scalatest.mockito.MockitoSugar import org.scalatest.mockito.MockitoSugar
import org.scalatest.{Matchers, WordSpec} import org.scalatest.{Matchers, WordSpec}
import vinyldns.api.domain.auth.AuthPrincipalProvider import vinyldns.api.domain.auth.AuthPrincipalProvider
import vinyldns.api.{GroupTestData, ResultHelpers} import vinyldns.api.{GroupTestData}
import vinyldns.core.crypto.CryptoAlgebra import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.auth.AuthPrincipal
class VinylDNSAuthenticatorSpec class VinylDNSAuthenticatorSpec
extends WordSpec extends WordSpec
with Matchers with Matchers
with MockitoSugar with MockitoSugar
with ResultHelpers
with GroupTestData { with GroupTestData {
private val mockAuthenticator = mock[Aws4Authenticator] private val mockAuthenticator = mock[Aws4Authenticator]
private val mockAuthPrincipalProvider = mock[AuthPrincipalProvider] private val mockAuthPrincipalProvider = mock[AuthPrincipalProvider]
@@ -87,8 +80,7 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator) .when(mockAuthenticator)
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String]) .authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
val result = val result = underTest.apply(context, "").unsafeRunSync()
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
result shouldBe Right(okUserAuth) result shouldBe Right(okUserAuth)
} }
"fail if missing Authorization header" in { "fail if missing Authorization header" in {
@@ -109,9 +101,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator) .when(mockAuthenticator)
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String]) .authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
val result = val result = underTest.apply(context, "").unsafeRunSync()
await[Either[Cause, AuthPrincipal]](underTest.apply(context, "")) result shouldBe Left(AuthMissing("Authorization header not found"))
result shouldBe Left(CredentialsMissing)
} }
"fail if Authorization header can not be parsed" in { "fail if Authorization header can not be parsed" in {
val fakeHttpHeader = mock[HttpHeader] val fakeHttpHeader = mock[HttpHeader]
@@ -125,9 +116,8 @@ class VinylDNSAuthenticatorSpec
val context: RequestContext = mock[RequestContext] val context: RequestContext = mock[RequestContext]
doReturn(httpRequest).when(context).request doReturn(httpRequest).when(context).request
val result = val result = underTest.apply(context, "").unsafeRunSync()
await[Either[Cause, AuthPrincipal]](underTest.apply(context, "")) result shouldBe Left(AuthRejected("Authorization header could not be parsed"))
result shouldBe Left(CredentialsRejected)
} }
"fail if the access key is missing" in { "fail if the access key is missing" in {
val fakeHttpHeader = mock[HttpHeader] val fakeHttpHeader = mock[HttpHeader]
@@ -149,9 +139,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator) .when(mockAuthenticator)
.extractAccessKey(any[String]) .extractAccessKey(any[String])
val result = val result = underTest.apply(context, "").unsafeRunSync()
await[Either[Cause, AuthPrincipal]](underTest.apply(context, "")) result shouldBe Left(AuthMissing("accessKey not found"))
result shouldBe Left(CredentialsMissing)
} }
"fail if the access key can not be retrieved" in { "fail if the access key can not be retrieved" in {
val fakeHttpHeader = mock[HttpHeader] val fakeHttpHeader = mock[HttpHeader]
@@ -173,9 +162,34 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator) .when(mockAuthenticator)
.extractAccessKey(any[String]) .extractAccessKey(any[String])
val result = val result = underTest.apply(context, "").unsafeRunSync()
await[Either[Cause, AuthPrincipal]](underTest.apply(context, "")) result shouldBe Left(AuthRejected("Invalid authorization header"))
result shouldBe Left(CredentialsRejected) }
"fail if the user is locked" in {
val fakeHttpHeader = mock[HttpHeader]
doReturn("Authorization").when(fakeHttpHeader).name
val header = "AWS4-HMAC-SHA256" +
" Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request," +
" SignedHeaders=host;range;x-amz-date," +
" Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
doReturn(header).when(fakeHttpHeader).value
val httpRequest: HttpRequest = HttpRequest().withHeaders(List(fakeHttpHeader))
val context: RequestContext = mock[RequestContext]
doReturn(httpRequest).when(context).request
doReturn(lockedUser.accessKey)
.when(mockAuthenticator)
.extractAccessKey(any[String])
doReturn(IO.pure(Some(lockedUserAuth)))
.when(mockAuthPrincipalProvider)
.getAuthPrincipal(any[String])
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AccountLocked("Account with username locked is locked"))
} }
"fail if the user can not be found" in { "fail if the user can not be found" in {
val fakeHttpHeader = mock[HttpHeader] val fakeHttpHeader = mock[HttpHeader]
@@ -192,7 +206,7 @@ class VinylDNSAuthenticatorSpec
val context: RequestContext = mock[RequestContext] val context: RequestContext = mock[RequestContext]
doReturn(httpRequest).when(context).request doReturn(httpRequest).when(context).request
doReturn(okUser.accessKey) doReturn("fakeKey")
.when(mockAuthenticator) .when(mockAuthenticator)
.extractAccessKey(any[String]) .extractAccessKey(any[String])
@@ -201,9 +215,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthPrincipalProvider) .when(mockAuthPrincipalProvider)
.getAuthPrincipal(any[String]) .getAuthPrincipal(any[String])
val result = val result = underTest.apply(context, "").unsafeRunSync()
await[Either[Cause, AuthPrincipal]](underTest.apply(context, "")) result shouldBe Left(AuthRejected("Account with accessKey fakeKey specified was not found"))
result shouldBe Left(CredentialsRejected)
} }
"fail if signatures can not be validated" in { "fail if signatures can not be validated" in {
val fakeHttpHeader = mock[HttpHeader] val fakeHttpHeader = mock[HttpHeader]
@@ -233,9 +246,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator) .when(mockAuthenticator)
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String]) .authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
val result = val result = underTest.apply(context, "").unsafeRunSync()
await[Either[Cause, AuthPrincipal]](underTest.apply(context, "")) result shouldBe Left(AuthRejected("Request signature could not be validated"))
result shouldBe Left(CredentialsRejected)
} }
} }
} }

View File

@@ -18,7 +18,7 @@ package vinyldns.api.route
import java.io.IOException import java.io.IOException
import akka.http.scaladsl.model.{HttpResponse, StatusCodes} import akka.http.scaladsl.model.{HttpEntity, HttpResponse, StatusCodes}
import akka.http.scaladsl.server.{Directives, Route} import akka.http.scaladsl.server.{Directives, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest import akka.http.scaladsl.testkit.ScalatestRouteTest
import nl.grons.metrics.scala.{Histogram, Meter} import nl.grons.metrics.scala.{Histogram, Meter}
@@ -66,6 +66,26 @@ class VinylDNSDirectivesSpec
override def beforeEach(): Unit = override def beforeEach(): Unit =
reset(mockLatency, mockErrors) reset(mockLatency, mockErrors)
".handleAuthenticateError" should {
"respond with Forbidden status if account is locked" in {
val trythis = handleAuthenticateError(AccountLocked("error"))
trythis shouldBe HttpResponse(
status = StatusCodes.Forbidden,
entity = HttpEntity(s"Authentication Failed: error")
)
}
"respond with Unauthorized status for other authentication errors" in {
val trythis = handleAuthenticateError(AuthRejected("error"))
trythis shouldBe HttpResponse(
status = StatusCodes.Unauthorized,
entity = HttpEntity(s"Authentication Failed: error")
)
}
}
"The monitor directive" should { "The monitor directive" should {
"record when completing an HttpResponse normally" in { "record when completing an HttpResponse normally" in {
Get("/test") ~> testRoute ~> check { Get("/test") ~> testRoute ~> check {

View File

@@ -19,7 +19,6 @@ package vinyldns.api.route
import akka.actor.ActorSystem import akka.actor.ActorSystem
import akka.http.scaladsl.model.StatusCodes._ import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest} import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest}
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
import akka.http.scaladsl.server.{Directives, RequestContext, Route} import akka.http.scaladsl.server.{Directives, RequestContext, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest import akka.http.scaladsl.testkit.ScalatestRouteTest
import cats.effect._ import cats.effect._
@@ -332,7 +331,7 @@ class ZoneRoutingSpec
override def vinyldnsAuthenticator( override def vinyldnsAuthenticator(
ctx: RequestContext, ctx: RequestContext,
content: String): IO[Either[Cause, AuthPrincipal]] = content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
IO.pure(Right(okAuth)) IO.pure(Right(okAuth))
def zoneJson(name: String, email: String): String = def zoneJson(name: String, email: String): String =

View File

@@ -19,6 +19,12 @@ package vinyldns.core.domain.membership
import java.util.UUID import java.util.UUID
import org.joda.time.DateTime import org.joda.time.DateTime
import vinyldns.core.domain.membership.LockStatus.LockStatus
object LockStatus extends Enumeration {
type LockStatus = Value
val Locked, Unlocked = Value
}
case class User( case class User(
userName: String, userName: String,
@@ -29,5 +35,10 @@ case class User(
email: Option[String] = None, email: Option[String] = None,
created: DateTime = DateTime.now, created: DateTime = DateTime.now,
id: String = UUID.randomUUID().toString, id: String = UUID.randomUUID().toString,
isSuper: Boolean = false isSuper: Boolean = false,
) lockStatus: LockStatus = LockStatus.Unlocked
) {
def updateUserLockStatus(lockStatus: LockStatus): User =
this.copy(lockStatus = lockStatus)
}

View File

@@ -20,8 +20,7 @@ import cats.implicits._
import com.amazonaws.services.dynamodbv2.model.DeleteTableRequest import com.amazonaws.services.dynamodbv2.model.DeleteTableRequest
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
import vinyldns.core.crypto.NoOpCrypto import vinyldns.core.crypto.NoOpCrypto
import vinyldns.core.domain.membership.User import vinyldns.core.domain.membership.{User, LockStatus}
import scala.concurrent.duration._ import scala.concurrent.duration._
class DynamoDBUserRepositoryIntegrationSpec extends DynamoDBIntegrationSpec { class DynamoDBUserRepositoryIntegrationSpec extends DynamoDBIntegrationSpec {
@@ -135,5 +134,24 @@ class DynamoDBUserRepositoryIntegrationSpec extends DynamoDBIntegrationSpec {
result shouldBe Some(testUser) result shouldBe Some(testUser)
result.get.isSuper shouldBe false result.get.isSuper shouldBe false
} }
"returns the locked flag when true" in {
val testUser = User(
userName = "testSuper",
accessKey = "testSuper",
secretKey = "testUser",
lockStatus = LockStatus.Locked)
val saved = repo.save(testUser).unsafeRunSync()
val result = repo.getUser(saved.id).unsafeRunSync()
result shouldBe Some(testUser)
result.get.lockStatus shouldBe LockStatus.Locked
}
"returns the locked flag when false" in {
val f = repo.getUserByAccessKey(users.head.accessKey).unsafeRunSync()
f shouldBe Some(users.head)
f.get.lockStatus shouldBe LockStatus.Unlocked
}
} }
} }

View File

@@ -26,10 +26,12 @@ import com.amazonaws.services.dynamodbv2.model._
import org.joda.time.DateTime import org.joda.time.DateTime
import org.slf4j.{Logger, LoggerFactory} import org.slf4j.{Logger, LoggerFactory}
import vinyldns.core.crypto.CryptoAlgebra import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.membership.{ListUsersResults, User, UserRepository} import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.membership.{ListUsersResults, LockStatus, User, UserRepository}
import vinyldns.core.route.Monitored import vinyldns.core.route.Monitored
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.util.Try
object DynamoDBUserRepository { object DynamoDBUserRepository {
@@ -42,6 +44,7 @@ object DynamoDBUserRepository {
private[repository] val ACCESS_KEY = "accesskey" private[repository] val ACCESS_KEY = "accesskey"
private[repository] val SECRET_KEY = "secretkey" private[repository] val SECRET_KEY = "secretkey"
private[repository] val IS_SUPER = "super" private[repository] val IS_SUPER = "super"
private[repository] val LOCK_STATUS = "lockstatus"
private[repository] val USER_NAME_INDEX_NAME = "username_index" private[repository] val USER_NAME_INDEX_NAME = "username_index"
private[repository] val ACCESS_KEY_INDEX_NAME = "access_key_index" private[repository] val ACCESS_KEY_INDEX_NAME = "access_key_index"
@@ -97,6 +100,7 @@ object DynamoDBUserRepository {
item.put(ACCESS_KEY, new AttributeValue(user.accessKey)) item.put(ACCESS_KEY, new AttributeValue(user.accessKey))
item.put(SECRET_KEY, new AttributeValue(crypto.encrypt(user.secretKey))) item.put(SECRET_KEY, new AttributeValue(crypto.encrypt(user.secretKey)))
item.put(IS_SUPER, new AttributeValue().withBOOL(user.isSuper)) item.put(IS_SUPER, new AttributeValue().withBOOL(user.isSuper))
item.put(LOCK_STATUS, new AttributeValue(user.lockStatus.toString))
val firstName = val firstName =
user.firstName.map(new AttributeValue(_)).getOrElse(new AttributeValue().withNULL(true)) user.firstName.map(new AttributeValue(_)).getOrElse(new AttributeValue().withNULL(true))
@@ -110,6 +114,12 @@ object DynamoDBUserRepository {
} }
def fromItem(item: java.util.Map[String, AttributeValue]): IO[User] = IO { def fromItem(item: java.util.Map[String, AttributeValue]): IO[User] = IO {
def userStatus(str: String): LockStatus = Try(LockStatus.withName(str)).getOrElse {
val log: Logger = LoggerFactory.getLogger(classOf[DynamoDBUserRepository])
log.error(s"Invalid locked status value '$str'; defaulting to unlocked")
LockStatus.Unlocked
}
User( User(
id = item.get(USER_ID).getS, id = item.get(USER_ID).getS,
userName = item.get(USER_NAME).getS, userName = item.get(USER_NAME).getS,
@@ -119,7 +129,8 @@ object DynamoDBUserRepository {
firstName = Option(item.get(FIRST_NAME)).flatMap(fn => Option(fn.getS)), firstName = Option(item.get(FIRST_NAME)).flatMap(fn => Option(fn.getS)),
lastName = Option(item.get(LAST_NAME)).flatMap(ln => Option(ln.getS)), lastName = Option(item.get(LAST_NAME)).flatMap(ln => Option(ln.getS)),
email = Option(item.get(EMAIL)).flatMap(e => Option(e.getS)), email = Option(item.get(EMAIL)).flatMap(e => Option(e.getS)),
isSuper = if (item.get(IS_SUPER) == null) false else item.get(IS_SUPER).getBOOL isSuper = if (item.get(IS_SUPER) == null) false else item.get(IS_SUPER).getBOOL,
lockStatus = userStatus(item.get(LOCK_STATUS).getS)
) )
} }
} }

View File

@@ -31,6 +31,7 @@ import scala.collection.JavaConverters._
import cats.effect._ import cats.effect._
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
import vinyldns.core.crypto.{CryptoAlgebra, NoOpCrypto} import vinyldns.core.crypto.{CryptoAlgebra, NoOpCrypto}
import vinyldns.core.domain.membership.LockStatus
import vinyldns.dynamodb.DynamoTestConfig import vinyldns.dynamodb.DynamoTestConfig
class DynamoDBUserRepositorySpec class DynamoDBUserRepositorySpec
@@ -72,6 +73,7 @@ class DynamoDBUserRepositorySpec
items.get(LAST_NAME).getS shouldBe okUser.lastName.get items.get(LAST_NAME).getS shouldBe okUser.lastName.get
items.get(EMAIL).getS shouldBe okUser.email.get items.get(EMAIL).getS shouldBe okUser.email.get
items.get(CREATED).getN shouldBe okUser.created.getMillis.toString items.get(CREATED).getN shouldBe okUser.created.getMillis.toString
items.get(LOCK_STATUS).getS shouldBe okUser.lockStatus.toString
} }
"set the first name to null if it is not present" in { "set the first name to null if it is not present" in {
val emptyFirstName = okUser.copy(firstName = None) val emptyFirstName = okUser.copy(firstName = None)
@@ -131,6 +133,7 @@ class DynamoDBUserRepositorySpec
item.put(CREATED, new AttributeValue().withN("0")) item.put(CREATED, new AttributeValue().withN("0"))
item.put(ACCESS_KEY, new AttributeValue("accessKey")) item.put(ACCESS_KEY, new AttributeValue("accessKey"))
item.put(SECRET_KEY, new AttributeValue("secretkey")) item.put(SECRET_KEY, new AttributeValue("secretkey"))
item.put(LOCK_STATUS, new AttributeValue("lockstatus"))
val user = fromItem(item).unsafeRunSync() val user = fromItem(item).unsafeRunSync()
user.firstName shouldBe None user.firstName shouldBe None
@@ -151,10 +154,24 @@ class DynamoDBUserRepositorySpec
item.put(CREATED, new AttributeValue().withN("0")) item.put(CREATED, new AttributeValue().withN("0"))
item.put(ACCESS_KEY, new AttributeValue("accesskey")) item.put(ACCESS_KEY, new AttributeValue("accesskey"))
item.put(SECRET_KEY, new AttributeValue("secretkey")) item.put(SECRET_KEY, new AttributeValue("secretkey"))
item.put(LOCK_STATUS, new AttributeValue("Locked"))
val user = fromItem(item).unsafeRunSync() val user = fromItem(item).unsafeRunSync()
user.isSuper shouldBe false user.isSuper shouldBe false
} }
"sets the lockStatus to Unlocked if the given value is invalid" in {
val item = new java.util.HashMap[String, AttributeValue]()
item.put(USER_ID, new AttributeValue("ok"))
item.put(USER_NAME, new AttributeValue("ok"))
item.put(CREATED, new AttributeValue().withN("0"))
item.put(ACCESS_KEY, new AttributeValue("accesskey"))
item.put(SECRET_KEY, new AttributeValue("secretkey"))
item.put(LOCK_STATUS, new AttributeValue("lock_status"))
val user = fromItem(item).unsafeRunSync()
user.lockStatus shouldBe LockStatus.Unlocked
}
} }
"DynamoDBUserRepository.getUser" should { "DynamoDBUserRepository.getUser" should {