2
0
mirror of https://github.com/VinylDNS/vinyldns synced 2025-08-31 14:25:30 +00:00

add locked user account with auth check to api (#172)

add locked user account with auth check to api
This commit is contained in:
Britney Wright
2018-09-17 12:07:18 -04:00
committed by GitHub
parent 82f292a442
commit d8876d369f
25 changed files with 407 additions and 87 deletions

View File

@@ -0,0 +1,29 @@
from utils import *
from hamcrest import *
from vinyldns_python import VinylDNSClient
from dns.resolver import *
from vinyldns_context import VinylDNSTestContext
def test_request_fails_when_user_account_is_locked():
"""
Test request fails with Forbidden (403) when user account is locked
"""
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'lockedAccessKey', 'lockedSecretKey')
client.list_batch_change_summaries(status=403)
def test_request_fails_when_user_is_not_found():
"""
Test request fails with Unauthorized (401) when user account is not found
"""
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'unknownAccessKey', 'anyAccessSecretKey')
client.list_batch_change_summaries(status=401)
def test_request_succeeds_when_user_is_found_and_not_locked():
"""
Test request success with Success (200) when user account is found and not locked
"""
client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, 'okAccessKey', 'okSecretKey')
client.list_batch_change_summaries(status=200)

View File

@@ -93,7 +93,7 @@ def test_list_group_members_start_from(shared_zone_test_context):
# members has one more because admins are added as members
assert_that(result['members'], has_length(len(members) + 1))
assert_that(result['members'], has_item({ 'id': 'ok'}))
assert_that(result['members'], has_item({ 'lockStatus': 'Unlocked', 'id': 'ok'}))
result_member_ids = map(lambda member: member['id'], result['members'])
for user in members:
assert_that(result_member_ids, has_item(user['id']))

View File

@@ -101,7 +101,7 @@ class JdbcZoneRepositoryIntegrationSpec
if (num == 1) z.addACLRule(dummyAclRule) else z
}
private val superUserAuth = AuthPrincipal(dummyUser.copy(isSuper = true), Seq())
private val jdbcSuperUserAuth = AuthPrincipal(dummyUser.copy(isSuper = true), Seq())
private def testZone(name: String, adminGroupId: String = testZoneAdminGroupId) =
okZone.copy(name = name, id = UUID.randomUUID().toString, adminGroupId = adminGroupId)
@@ -410,7 +410,7 @@ class JdbcZoneRepositoryIntegrationSpec
val f =
for {
_ <- saveZones(testZones)
retrieved <- repo.listZones(superUserAuth)
retrieved <- repo.listZones(jdbcSuperUserAuth)
} yield retrieved
whenReady(f.unsafeToFuture(), timeout) { retrieved =>
@@ -431,7 +431,7 @@ class JdbcZoneRepositoryIntegrationSpec
val f =
for {
_ <- saveZones(testZones)
retrieved <- repo.listZones(superUserAuth, zoneNameFilter = Some("system"))
retrieved <- repo.listZones(jdbcSuperUserAuth, zoneNameFilter = Some("system"))
} yield retrieved
whenReady(f.unsafeToFuture(), timeout) { retrieved =>
@@ -471,19 +471,19 @@ class JdbcZoneRepositoryIntegrationSpec
whenReady(saveZones(testZones).unsafeToFuture(), timeout) { _ =>
whenReady(
repo.listZones(superUserAuth, offset = None, pageSize = 4).unsafeToFuture(),
repo.listZones(jdbcSuperUserAuth, offset = None, pageSize = 4).unsafeToFuture(),
timeout) { firstPage =>
(firstPage should contain).theSameElementsInOrderAs(expectedFirstPage)
}
whenReady(
repo.listZones(superUserAuth, offset = Some(4), pageSize = 4).unsafeToFuture(),
repo.listZones(jdbcSuperUserAuth, offset = Some(4), pageSize = 4).unsafeToFuture(),
timeout) { secondPage =>
(secondPage should contain).theSameElementsInOrderAs(expectedSecondPage)
}
whenReady(
repo.listZones(superUserAuth, offset = Some(8), pageSize = 4).unsafeToFuture(),
repo.listZones(jdbcSuperUserAuth, offset = Some(8), pageSize = 4).unsafeToFuture(),
timeout) { thirdPage =>
(thirdPage should contain).theSameElementsInOrderAs(expectedThirdPage)
}

View File

@@ -21,6 +21,7 @@ import java.util.UUID
import org.joda.time.DateTime
import vinyldns.core.domain.membership.GroupChangeType.GroupChangeType
import vinyldns.core.domain.membership.GroupStatus.GroupStatus
import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.membership._
/* This is the new View model for Groups, do not surface the Group model directly any more */
@@ -73,7 +74,8 @@ case class UserInfo(
firstName: Option[String] = None,
lastName: Option[String] = None,
email: Option[String] = None,
created: Option[DateTime] = None
created: Option[DateTime] = None,
lockStatus: LockStatus = LockStatus.Unlocked
)
object UserInfo {
def apply(user: User): UserInfo =
@@ -83,7 +85,8 @@ object UserInfo {
firstName = user.firstName,
lastName = user.lastName,
email = user.email,
created = Some(user.created)
created = Some(user.created),
lockStatus = user.lockStatus
)
}

View File

@@ -19,6 +19,7 @@ package vinyldns.api.domain.membership
import cats.implicits._
import vinyldns.api.Interfaces._
import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.zone.ZoneRepository
import vinyldns.core.domain.membership._
@@ -55,7 +56,7 @@ class MembershipService(
for {
existingGroup <- getExistingGroup(groupId)
newGroup = existingGroup.withUpdates(name, email, description, memberIds, adminUserIds)
_ <- isAdmin(existingGroup, authPrincipal).toResult
_ <- isGroupAdmin(existingGroup, authPrincipal).toResult
addedMembers = newGroup.memberIds.diff(existingGroup.memberIds)
removedMembers = existingGroup.memberIds.diff(newGroup.memberIds)
_ <- hasMembersAndAdmins(newGroup).toResult
@@ -72,7 +73,7 @@ class MembershipService(
def deleteGroup(groupId: String, authPrincipal: AuthPrincipal): Result[Group] =
for {
existingGroup <- getExistingGroup(groupId)
_ <- isAdmin(existingGroup, authPrincipal).toResult
_ <- isGroupAdmin(existingGroup, authPrincipal).toResult
_ <- groupCanBeDeleted(existingGroup)
_ <- groupChangeRepo
.save(GroupChange.forDelete(existingGroup, authPrincipal))
@@ -174,6 +175,12 @@ class MembershipService(
.getUsers(userIds, startFrom, pageSize)
.toResult[ListUsersResults]
def getExistingUser(userId: String): Result[User] =
userRepo
.getUser(userId)
.orFail(UserNotFoundError(s"User with ID $userId was not found"))
.toResult[User]
def getExistingGroup(groupId: String): Result[Group] =
groupRepo
.getGroup(groupId)
@@ -222,4 +229,15 @@ class MembershipService(
}
}
.toResult
def updateUserLockStatus(
userId: String,
lockStatus: LockStatus,
authPrincipal: AuthPrincipal): Result[User] =
for {
_ <- isSuperAdmin(authPrincipal).toResult
existingUser <- getExistingUser(userId)
newUser = existingUser.updateUserLockStatus(lockStatus)
_ <- userRepo.save(newUser).toResult[User]
} yield newUser
}

View File

@@ -18,6 +18,7 @@ package vinyldns.api.domain.membership
import vinyldns.api.Interfaces.Result
import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.membership._
trait MembershipServiceAlgebra {
@@ -56,4 +57,9 @@ trait MembershipServiceAlgebra {
startFrom: Option[String],
maxItems: Int,
authPrincipal: AuthPrincipal): Result[ListGroupChangesResponse]
def updateUserLockStatus(
userId: String,
lockStatus: LockStatus,
authPrincipal: AuthPrincipal): Result[User]
}

View File

@@ -28,11 +28,16 @@ object MembershipValidations {
group.memberIds.nonEmpty && group.adminUserIds.nonEmpty
}
def isAdmin(group: Group, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
def isGroupAdmin(group: Group, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError("Not authorized")) {
group.adminUserIds.contains(authPrincipal.userId) || authPrincipal.signedInUser.isSuper
}
def isSuperAdmin(authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError("Not authorized")) {
authPrincipal.signedInUser.isSuper
}
def canSeeGroup(groupId: String, authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError("Not authorized")) {
authPrincipal.isAuthorized(groupId)

View File

@@ -51,7 +51,20 @@ object TestDataLoader {
id = "dummy",
created = DateTime.now.secondOfDay().roundFloorCopy(),
accessKey = "dummyAccessKey",
secretKey = "dummySecretKey")
secretKey = "dummySecretKey"
)
final val lockedUser = User(
userName = "locked",
id = "locked",
created = DateTime.now.secondOfDay().roundFloorCopy(),
accessKey = "lockedAccessKey",
secretKey = "lockedSecretKey",
firstName = Some("Locked"),
lastName = Some("User"),
email = Some("testlocked@test.com"),
isSuper = false,
lockStatus = LockStatus.Locked
)
final val listOfDummyUsers: List[User] = List.range(0, 200).map { runner =>
User(
userName = "name-dummy%03d".format(runner),
@@ -117,7 +130,7 @@ object TestDataLoader {
)
def loadTestData(repository: UserRepository): IO[List[User]] =
(testUser :: okUser :: dummyUser :: listGroupUser :: listZonesUser :: listBatchChangeSummariesUser ::
(testUser :: okUser :: dummyUser :: lockedUser :: listGroupUser :: listZonesUser :: listBatchChangeSummariesUser ::
listZeroBatchChangeSummariesUser :: zoneHistoryUser :: listOfDummyUsers).map { user =>
val encrypted =
if (VinylDNSConfig.encryptUserSecrets)

View File

@@ -23,7 +23,7 @@ import cats.implicits._
import org.joda.time.DateTime
import org.json4s._
import vinyldns.api.domain.membership._
import vinyldns.core.domain.membership.{Group, GroupChangeType, GroupStatus}
import vinyldns.core.domain.membership.{Group, GroupChangeType, GroupStatus, LockStatus}
object MembershipJsonProtocol {
final case class CreateGroupInput(
@@ -52,6 +52,7 @@ trait MembershipJsonProtocol extends JsonValidation {
GroupChangeInfoSerializer,
CreateGroupInputSerializer,
UpdateGroupInputSerializer,
JsonEnumV(LockStatus),
JsonEnumV(GroupStatus),
JsonEnumV(GroupChangeType)
)

View File

@@ -23,7 +23,7 @@ import vinyldns.api.domain.membership._
import vinyldns.api.domain.zone.NotAuthorizedError
import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput}
import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.Group
import vinyldns.core.domain.membership.{Group, LockStatus}
trait MembershipRoute extends Directives {
this: VinylDNSJsonProtocol with VinylDNSDirectives with JsonValidationRejection =>
@@ -168,6 +168,18 @@ trait MembershipRoute extends Directives {
}
}
}
} ~
(put & path("users" / Segment / "lock") & monitor("Endpoint.lockUser")) { id =>
execute(membershipService.updateUserLockStatus(id, LockStatus.Locked, authPrincipal)) {
user =>
complete(StatusCodes.OK, UserInfo(user))
}
} ~
(put & path("users" / Segment / "unlock") & monitor("Endpoint.unlockUser")) { id =>
execute(membershipService.updateUserLockStatus(id, LockStatus.Unlocked, authPrincipal)) {
user =>
complete(StatusCodes.OK, UserInfo(user))
}
}
}

View File

@@ -17,8 +17,7 @@
package vinyldns.api.route
import akka.http.scaladsl.model.HttpRequest
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
import akka.http.scaladsl.server.{AuthenticationFailedRejection, RequestContext}
import akka.http.scaladsl.server.RequestContext
import cats.effect._
import cats.syntax.all._
import vinyldns.api.VinylDNSConfig
@@ -27,12 +26,14 @@ import vinyldns.api.domain.auth.{AuthPrincipalProvider, MembershipAuthPrincipalP
import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.route.Monitored
import vinyldns.core.domain.membership.LockStatus
import scala.util.matching.Regex
sealed abstract class VinylDNSAuthenticationError(msg: String) extends Throwable(msg)
final case class AuthMissing(msg: String) extends VinylDNSAuthenticationError(msg)
final case class AuthRejected(reason: String) extends VinylDNSAuthenticationError(reason)
final case class AccountLocked(reason: String) extends VinylDNSAuthenticationError(reason)
trait VinylDNSAuthentication extends Monitored {
val authenticator: Aws4Authenticator
@@ -131,8 +132,14 @@ trait VinylDNSAuthentication extends Monitored {
if (encryptionEnabled) crypto.decrypt(str) else str
def getAuthPrincipal(accessKey: String): IO[AuthPrincipal] =
authPrincipalProvider.getAuthPrincipal(accessKey).map {
_.getOrElse(throw AuthRejected(s"Account with accessKey $accessKey specified was not found"))
authPrincipalProvider.getAuthPrincipal(accessKey).flatMap {
case Some(ok) =>
if (ok.signedInUser.lockStatus == LockStatus.Locked) {
IO.raiseError(
AccountLocked(s"Account with username ${ok.signedInUser.userName} is locked"))
} else IO.pure(ok)
case None =>
IO.raiseError(AuthRejected(s"Account with accessKey $accessKey specified was not found"))
}
}
@@ -141,16 +148,14 @@ class VinylDNSAuthenticator(
val authPrincipalProvider: AuthPrincipalProvider)
extends VinylDNSAuthentication {
def apply(ctx: RequestContext, content: String): IO[Either[Cause, AuthPrincipal]] =
authenticate(ctx, content).attempt.map {
case Right(ok) => Right(ok)
case Left(_: AuthMissing) =>
Left(AuthenticationFailedRejection.CredentialsMissing)
case Left(_: AuthRejected) =>
Left(AuthenticationFailedRejection.CredentialsRejected)
case Left(e: Throwable) =>
// throw here as some unexpected exception occurred
throw e
def apply(
ctx: RequestContext,
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
// Need to refactor authenticate to be an IO[Either[E, A]] instead of how it is implemented, for the time being...
authenticate(ctx, content).attempt.flatMap {
case Left(e: VinylDNSAuthenticationError) => IO.pure(Left(e))
case Right(ok) => IO.pure(Right(ok))
case Left(e) => IO.raiseError(e)
}
}
@@ -159,6 +164,8 @@ object VinylDNSAuthenticator {
lazy val authPrincipalProvider = MembershipAuthPrincipalProvider()
lazy val authenticator = new VinylDNSAuthenticator(aws4Authenticator, authPrincipalProvider)
def apply(ctx: RequestContext, content: String): IO[Either[Cause, AuthPrincipal]] =
def apply(
ctx: RequestContext,
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
authenticator.apply(ctx, content)
}

View File

@@ -17,7 +17,6 @@
package vinyldns.api.route
import akka.http.scaladsl.model.{HttpEntity, HttpResponse, StatusCodes}
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
import akka.http.scaladsl.server.RouteResult.{Complete, Rejected}
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.directives.BasicDirectives
@@ -42,7 +41,7 @@ trait VinylDNSDirectives extends Directives {
*/
def vinyldnsAuthenticator(
ctx: RequestContext,
content: String): IO[Either[Cause, AuthPrincipal]] =
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
VinylDNSAuthenticator(ctx, content)
def authenticate: Directive1[AuthPrincipal] =
@@ -53,19 +52,27 @@ trait VinylDNSDirectives extends Directives {
.flatMap {
case Right(authPrincipal)
provide(authPrincipal)
case Left(cause)
// we need to finish the result, rejections will proceed and ultimately
// we can fail with a different rejection
complete(
HttpResponse(
status = StatusCodes.Unauthorized,
entity = HttpEntity(s"Authentication Failed: $cause")
))
case Left(e)
complete(handleAuthenticateError(e))
}
}
}
}
def handleAuthenticateError(error: VinylDNSAuthenticationError): HttpResponse =
error match {
case AccountLocked(err) =>
HttpResponse(
status = StatusCodes.Forbidden,
entity = HttpEntity(s"Authentication Failed: $err")
)
case e =>
HttpResponse(
status = StatusCodes.Unauthorized,
entity = HttpEntity(s"Authentication Failed: ${e.getMessage}")
)
}
/* Adds monitoring to an Endpoint. The name will be surfaced in JMX */
def monitor(name: String): Directive0 =
extractExecutionContext.flatMap { implicit ec

View File

@@ -30,8 +30,9 @@ trait GroupTestData { this: Matchers =>
val okUser: User = TestDataLoader.okUser
val dummyUser: User = TestDataLoader.dummyUser
val listOfDummyUsers: List[User] = TestDataLoader.listOfDummyUsers
val lockedUser: User = TestDataLoader.lockedUser
val listOfDummyUsers: List[User] = TestDataLoader.listOfDummyUsers
val okUserInfo: UserInfo = UserInfo(okUser)
val dummyUserInfo: UserInfo = UserInfo(dummyUser)
@@ -91,6 +92,7 @@ trait GroupTestData { this: Matchers =>
val noGroupsUserAuth: AuthPrincipal = AuthPrincipal(okUser, Seq())
val deletedGroupAuth: AuthPrincipal = AuthPrincipal(okUser, Seq(deletedGroup.id))
val dummyUserAuth: AuthPrincipal = AuthPrincipal(dummyUser, Seq(dummyGroup.id))
val lockedUserAuth: AuthPrincipal = AuthPrincipal(lockedUser, Seq())
val listOfDummyGroupsAuth: AuthPrincipal = AuthPrincipal(dummyUser, listOfDummyGroups.map(_.id))
val memberOkZoneAuthorized: Zone = Zone(

View File

@@ -39,6 +39,7 @@ trait VinylDNSTestData {
created = DateTime.now.secondOfDay().roundFloorCopy())
val okAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.okUser, Seq(grp.id))
val notAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.dummyUser, Seq.empty)
val lockedAuth: AuthPrincipal = AuthPrincipal(TestDataLoader.lockedUser, Seq.empty)
val testConnection: Option[ZoneConnection] = Some(
ZoneConnection("vinyldns.", "vinyldns.", "nzisn+4G2ldMn0q1CV3vsg==", "10.1.1.1"))

View File

@@ -23,7 +23,7 @@ import org.mockito.Mockito._
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec}
import vinyldns.api.Interfaces._
import vinyldns.api.{GroupTestData, ResultHelpers}
import vinyldns.api.{GroupTestData, ResultHelpers, VinylDNSTestData}
import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.zone.{ZoneRepository, _}
import cats.effect._
@@ -37,6 +37,7 @@ class MembershipServiceSpec
with BeforeAndAfterEach
with ResultHelpers
with GroupTestData
with VinylDNSTestData
with EitherMatchers {
private val mockGroupRepo = mock[GroupRepository]
@@ -750,5 +751,72 @@ class MembershipServiceSpec
error shouldBe a[InvalidGroupRequestError]
}
}
"updateUserLockStatus" should {
"save the update and lock the user account" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(IO.pure(Some(okUser))).when(mockUserRepo).getUser(okUser.id)
doReturn(IO.pure(okUser)).when(mockUserRepo).save(any[User])
underTest
.updateUserLockStatus(okUser.id, LockStatus.Locked, superUserAuth)
.value
.unsafeRunSync()
val userCaptor = ArgumentCaptor.forClass(classOf[User])
verify(mockUserRepo).save(userCaptor.capture())
val savedUser = userCaptor.getValue
savedUser.lockStatus shouldBe LockStatus.Locked
savedUser.id shouldBe okUser.id
}
"save the update and unlock the user account" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(IO.pure(Some(lockedUser))).when(mockUserRepo).getUser(lockedUser.id)
doReturn(IO.pure(okUser)).when(mockUserRepo).save(any[User])
underTest
.updateUserLockStatus(lockedUser.id, LockStatus.Unlocked, superUserAuth)
.value
.unsafeRunSync()
val userCaptor = ArgumentCaptor.forClass(classOf[User])
verify(mockUserRepo).save(userCaptor.capture())
val savedUser = userCaptor.getValue
savedUser.lockStatus shouldBe LockStatus.Unlocked
savedUser.id shouldBe lockedUser.id
}
"return an error if the signed in user is not a super user" in {
val error = leftResultOf(
underTest
.updateUserLockStatus(okUser.id, LockStatus.Locked, dummyUserAuth)
.value)
error shouldBe a[NotAuthorizedError]
}
"return an error if the requested user is not found" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(IO.pure(None)).when(mockUserRepo).getUser(okUser.id)
val error = leftResultOf(
underTest
.updateUserLockStatus(okUser.id, LockStatus.Locked, superUserAuth)
.value)
error shouldBe a[UserNotFoundError]
}
}
}
}

View File

@@ -56,17 +56,17 @@ class MembershipValidationsSpec
"isAdmin" should {
"return true when the user is in admin group" in {
isAdmin(okGroup, okUserAuth) should be(right)
isGroupAdmin(okGroup, okUserAuth) should be(right)
}
"return true when the user is a super user" in {
val user = User("some", "new", "user", isSuper = true)
val superAuth = AuthPrincipal(user, Seq())
isAdmin(okGroup, superAuth) should be(right)
isGroupAdmin(okGroup, superAuth) should be(right)
}
"return an error when the user has no access and is not super" in {
val user = User("some", "new", "user")
val nonSuperAuth = AuthPrincipal(user, Seq())
val error = leftValue(isAdmin(okGroup, nonSuperAuth))
val error = leftValue(isGroupAdmin(okGroup, nonSuperAuth))
error shouldBe an[NotAuthorizedError]
}
}

View File

@@ -30,10 +30,11 @@ import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec}
import vinyldns.api.Interfaces._
import vinyldns.api.domain.membership._
import vinyldns.core.domain.auth.AuthPrincipal
import vinyldns.core.domain.membership.Group
import vinyldns.core.domain.membership.{Group, LockStatus}
import vinyldns.api.domain.zone.NotAuthorizedError
import vinyldns.api.route.MembershipJsonProtocol.{CreateGroupInput, UpdateGroupInput}
import vinyldns.api.{GroupTestData, VinylDNSTestData}
import vinyldns.core.domain.membership.LockStatus.LockStatus
class MembershipRoutingSpec
extends WordSpec
@@ -671,4 +672,64 @@ class MembershipRoutingSpec
}
}
}
"PUT update user lock status" should {
"return a 200 response with the user locked" in {
val updatedUser = okUser.copy(lockStatus = LockStatus.Locked)
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(result(updatedUser))
.when(membershipService)
.updateUserLockStatus("ok", LockStatus.Locked, superUserAuth)
Put("/users/ok/lock") ~> membershipRoute(superUserAuth) ~> check {
status shouldBe StatusCodes.OK
val result = responseAs[UserInfo]
result.id shouldBe okUser.id
result.lockStatus shouldBe LockStatus.Locked
}
}
"return a 200 response with the user unlocked" in {
val updatedUser = lockedUser.copy(lockStatus = LockStatus.Unlocked)
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(result(updatedUser))
.when(membershipService)
.updateUserLockStatus("locked", LockStatus.Unlocked, superUserAuth)
Put("/users/locked/unlock") ~> membershipRoute(superUserAuth) ~> check {
status shouldBe StatusCodes.OK
val result = responseAs[UserInfo]
result.id shouldBe lockedUser.id
result.lockStatus shouldBe LockStatus.Unlocked
}
}
"return a 404 Not Found when the user is not found" in {
val superUserAuth = okAuth.copy(
signedInUser = dummyUserAuth.signedInUser.copy(isSuper = true),
memberGroupIds = Seq.empty)
doReturn(result(UserNotFoundError("fail")))
.when(membershipService)
.updateUserLockStatus(anyString, any[LockStatus], any[AuthPrincipal])
Put("/users/notFound/lock") ~> membershipRoute(superUserAuth) ~> check {
status shouldBe StatusCodes.NotFound
}
}
"return a 403 Forbidden when not authorized" in {
doReturn(result(NotAuthorizedError("fail")))
.when(membershipService)
.updateUserLockStatus(anyString, any[LockStatus], any[AuthPrincipal])
Put("/users/forbidden/lock") ~> membershipRoute(okGroupAuth) ~> check {
status shouldBe StatusCodes.Forbidden
}
}
}
}

View File

@@ -17,7 +17,6 @@
package vinyldns.api.route
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest, StatusCodes}
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
import akka.http.scaladsl.server.{Directives, RequestContext, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest
import cats.effect._
@@ -485,7 +484,7 @@ class RecordSetRoutingSpec
override def vinyldnsAuthenticator(
ctx: RequestContext,
content: String): IO[Either[Cause, AuthPrincipal]] =
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
IO.pure(Right(okAuth))
private def rsJson(recordSet: RecordSet): String =

View File

@@ -17,11 +17,6 @@
package vinyldns.api.route
import akka.http.scaladsl.model.{HttpHeader, HttpRequest}
import akka.http.scaladsl.server.AuthenticationFailedRejection.{
Cause,
CredentialsMissing,
CredentialsRejected
}
import akka.http.scaladsl.server.RequestContext
import cats.effect._
import org.mockito.Matchers._
@@ -29,15 +24,13 @@ import org.mockito.Mockito._
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{Matchers, WordSpec}
import vinyldns.api.domain.auth.AuthPrincipalProvider
import vinyldns.api.{GroupTestData, ResultHelpers}
import vinyldns.api.{GroupTestData}
import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.auth.AuthPrincipal
class VinylDNSAuthenticatorSpec
extends WordSpec
with Matchers
with MockitoSugar
with ResultHelpers
with GroupTestData {
private val mockAuthenticator = mock[Aws4Authenticator]
private val mockAuthPrincipalProvider = mock[AuthPrincipalProvider]
@@ -87,8 +80,7 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator)
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
val result =
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Right(okUserAuth)
}
"fail if missing Authorization header" in {
@@ -109,9 +101,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator)
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
val result =
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
result shouldBe Left(CredentialsMissing)
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AuthMissing("Authorization header not found"))
}
"fail if Authorization header can not be parsed" in {
val fakeHttpHeader = mock[HttpHeader]
@@ -125,9 +116,8 @@ class VinylDNSAuthenticatorSpec
val context: RequestContext = mock[RequestContext]
doReturn(httpRequest).when(context).request
val result =
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
result shouldBe Left(CredentialsRejected)
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AuthRejected("Authorization header could not be parsed"))
}
"fail if the access key is missing" in {
val fakeHttpHeader = mock[HttpHeader]
@@ -149,9 +139,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator)
.extractAccessKey(any[String])
val result =
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
result shouldBe Left(CredentialsMissing)
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AuthMissing("accessKey not found"))
}
"fail if the access key can not be retrieved" in {
val fakeHttpHeader = mock[HttpHeader]
@@ -173,9 +162,34 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator)
.extractAccessKey(any[String])
val result =
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
result shouldBe Left(CredentialsRejected)
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AuthRejected("Invalid authorization header"))
}
"fail if the user is locked" in {
val fakeHttpHeader = mock[HttpHeader]
doReturn("Authorization").when(fakeHttpHeader).name
val header = "AWS4-HMAC-SHA256" +
" Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request," +
" SignedHeaders=host;range;x-amz-date," +
" Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
doReturn(header).when(fakeHttpHeader).value
val httpRequest: HttpRequest = HttpRequest().withHeaders(List(fakeHttpHeader))
val context: RequestContext = mock[RequestContext]
doReturn(httpRequest).when(context).request
doReturn(lockedUser.accessKey)
.when(mockAuthenticator)
.extractAccessKey(any[String])
doReturn(IO.pure(Some(lockedUserAuth)))
.when(mockAuthPrincipalProvider)
.getAuthPrincipal(any[String])
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AccountLocked("Account with username locked is locked"))
}
"fail if the user can not be found" in {
val fakeHttpHeader = mock[HttpHeader]
@@ -192,7 +206,7 @@ class VinylDNSAuthenticatorSpec
val context: RequestContext = mock[RequestContext]
doReturn(httpRequest).when(context).request
doReturn(okUser.accessKey)
doReturn("fakeKey")
.when(mockAuthenticator)
.extractAccessKey(any[String])
@@ -201,9 +215,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthPrincipalProvider)
.getAuthPrincipal(any[String])
val result =
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
result shouldBe Left(CredentialsRejected)
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AuthRejected("Account with accessKey fakeKey specified was not found"))
}
"fail if signatures can not be validated" in {
val fakeHttpHeader = mock[HttpHeader]
@@ -233,9 +246,8 @@ class VinylDNSAuthenticatorSpec
.when(mockAuthenticator)
.authenticateReq(any[HttpRequest], any[List[String]], any[String], any[String])
val result =
await[Either[Cause, AuthPrincipal]](underTest.apply(context, ""))
result shouldBe Left(CredentialsRejected)
val result = underTest.apply(context, "").unsafeRunSync()
result shouldBe Left(AuthRejected("Request signature could not be validated"))
}
}
}

View File

@@ -18,7 +18,7 @@ package vinyldns.api.route
import java.io.IOException
import akka.http.scaladsl.model.{HttpResponse, StatusCodes}
import akka.http.scaladsl.model.{HttpEntity, HttpResponse, StatusCodes}
import akka.http.scaladsl.server.{Directives, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest
import nl.grons.metrics.scala.{Histogram, Meter}
@@ -66,6 +66,26 @@ class VinylDNSDirectivesSpec
override def beforeEach(): Unit =
reset(mockLatency, mockErrors)
".handleAuthenticateError" should {
"respond with Forbidden status if account is locked" in {
val trythis = handleAuthenticateError(AccountLocked("error"))
trythis shouldBe HttpResponse(
status = StatusCodes.Forbidden,
entity = HttpEntity(s"Authentication Failed: error")
)
}
"respond with Unauthorized status for other authentication errors" in {
val trythis = handleAuthenticateError(AuthRejected("error"))
trythis shouldBe HttpResponse(
status = StatusCodes.Unauthorized,
entity = HttpEntity(s"Authentication Failed: error")
)
}
}
"The monitor directive" should {
"record when completing an HttpResponse normally" in {
Get("/test") ~> testRoute ~> check {

View File

@@ -19,7 +19,6 @@ package vinyldns.api.route
import akka.actor.ActorSystem
import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model.{ContentTypes, HttpEntity, HttpRequest}
import akka.http.scaladsl.server.AuthenticationFailedRejection.Cause
import akka.http.scaladsl.server.{Directives, RequestContext, Route}
import akka.http.scaladsl.testkit.ScalatestRouteTest
import cats.effect._
@@ -332,7 +331,7 @@ class ZoneRoutingSpec
override def vinyldnsAuthenticator(
ctx: RequestContext,
content: String): IO[Either[Cause, AuthPrincipal]] =
content: String): IO[Either[VinylDNSAuthenticationError, AuthPrincipal]] =
IO.pure(Right(okAuth))
def zoneJson(name: String, email: String): String =

View File

@@ -19,6 +19,12 @@ package vinyldns.core.domain.membership
import java.util.UUID
import org.joda.time.DateTime
import vinyldns.core.domain.membership.LockStatus.LockStatus
object LockStatus extends Enumeration {
type LockStatus = Value
val Locked, Unlocked = Value
}
case class User(
userName: String,
@@ -29,5 +35,10 @@ case class User(
email: Option[String] = None,
created: DateTime = DateTime.now,
id: String = UUID.randomUUID().toString,
isSuper: Boolean = false
)
isSuper: Boolean = false,
lockStatus: LockStatus = LockStatus.Unlocked
) {
def updateUserLockStatus(lockStatus: LockStatus): User =
this.copy(lockStatus = lockStatus)
}

View File

@@ -20,8 +20,7 @@ import cats.implicits._
import com.amazonaws.services.dynamodbv2.model.DeleteTableRequest
import com.typesafe.config.ConfigFactory
import vinyldns.core.crypto.NoOpCrypto
import vinyldns.core.domain.membership.User
import vinyldns.core.domain.membership.{User, LockStatus}
import scala.concurrent.duration._
class DynamoDBUserRepositoryIntegrationSpec extends DynamoDBIntegrationSpec {
@@ -135,5 +134,24 @@ class DynamoDBUserRepositoryIntegrationSpec extends DynamoDBIntegrationSpec {
result shouldBe Some(testUser)
result.get.isSuper shouldBe false
}
"returns the locked flag when true" in {
val testUser = User(
userName = "testSuper",
accessKey = "testSuper",
secretKey = "testUser",
lockStatus = LockStatus.Locked)
val saved = repo.save(testUser).unsafeRunSync()
val result = repo.getUser(saved.id).unsafeRunSync()
result shouldBe Some(testUser)
result.get.lockStatus shouldBe LockStatus.Locked
}
"returns the locked flag when false" in {
val f = repo.getUserByAccessKey(users.head.accessKey).unsafeRunSync()
f shouldBe Some(users.head)
f.get.lockStatus shouldBe LockStatus.Unlocked
}
}
}

View File

@@ -26,10 +26,12 @@ import com.amazonaws.services.dynamodbv2.model._
import org.joda.time.DateTime
import org.slf4j.{Logger, LoggerFactory}
import vinyldns.core.crypto.CryptoAlgebra
import vinyldns.core.domain.membership.{ListUsersResults, User, UserRepository}
import vinyldns.core.domain.membership.LockStatus.LockStatus
import vinyldns.core.domain.membership.{ListUsersResults, LockStatus, User, UserRepository}
import vinyldns.core.route.Monitored
import scala.collection.JavaConverters._
import scala.util.Try
object DynamoDBUserRepository {
@@ -42,6 +44,7 @@ object DynamoDBUserRepository {
private[repository] val ACCESS_KEY = "accesskey"
private[repository] val SECRET_KEY = "secretkey"
private[repository] val IS_SUPER = "super"
private[repository] val LOCK_STATUS = "lockstatus"
private[repository] val USER_NAME_INDEX_NAME = "username_index"
private[repository] val ACCESS_KEY_INDEX_NAME = "access_key_index"
@@ -97,6 +100,7 @@ object DynamoDBUserRepository {
item.put(ACCESS_KEY, new AttributeValue(user.accessKey))
item.put(SECRET_KEY, new AttributeValue(crypto.encrypt(user.secretKey)))
item.put(IS_SUPER, new AttributeValue().withBOOL(user.isSuper))
item.put(LOCK_STATUS, new AttributeValue(user.lockStatus.toString))
val firstName =
user.firstName.map(new AttributeValue(_)).getOrElse(new AttributeValue().withNULL(true))
@@ -110,6 +114,12 @@ object DynamoDBUserRepository {
}
def fromItem(item: java.util.Map[String, AttributeValue]): IO[User] = IO {
def userStatus(str: String): LockStatus = Try(LockStatus.withName(str)).getOrElse {
val log: Logger = LoggerFactory.getLogger(classOf[DynamoDBUserRepository])
log.error(s"Invalid locked status value '$str'; defaulting to unlocked")
LockStatus.Unlocked
}
User(
id = item.get(USER_ID).getS,
userName = item.get(USER_NAME).getS,
@@ -119,7 +129,8 @@ object DynamoDBUserRepository {
firstName = Option(item.get(FIRST_NAME)).flatMap(fn => Option(fn.getS)),
lastName = Option(item.get(LAST_NAME)).flatMap(ln => Option(ln.getS)),
email = Option(item.get(EMAIL)).flatMap(e => Option(e.getS)),
isSuper = if (item.get(IS_SUPER) == null) false else item.get(IS_SUPER).getBOOL
isSuper = if (item.get(IS_SUPER) == null) false else item.get(IS_SUPER).getBOOL,
lockStatus = userStatus(item.get(LOCK_STATUS).getS)
)
}
}

View File

@@ -31,6 +31,7 @@ import scala.collection.JavaConverters._
import cats.effect._
import com.typesafe.config.ConfigFactory
import vinyldns.core.crypto.{CryptoAlgebra, NoOpCrypto}
import vinyldns.core.domain.membership.LockStatus
import vinyldns.dynamodb.DynamoTestConfig
class DynamoDBUserRepositorySpec
@@ -72,6 +73,7 @@ class DynamoDBUserRepositorySpec
items.get(LAST_NAME).getS shouldBe okUser.lastName.get
items.get(EMAIL).getS shouldBe okUser.email.get
items.get(CREATED).getN shouldBe okUser.created.getMillis.toString
items.get(LOCK_STATUS).getS shouldBe okUser.lockStatus.toString
}
"set the first name to null if it is not present" in {
val emptyFirstName = okUser.copy(firstName = None)
@@ -131,6 +133,7 @@ class DynamoDBUserRepositorySpec
item.put(CREATED, new AttributeValue().withN("0"))
item.put(ACCESS_KEY, new AttributeValue("accessKey"))
item.put(SECRET_KEY, new AttributeValue("secretkey"))
item.put(LOCK_STATUS, new AttributeValue("lockstatus"))
val user = fromItem(item).unsafeRunSync()
user.firstName shouldBe None
@@ -151,10 +154,24 @@ class DynamoDBUserRepositorySpec
item.put(CREATED, new AttributeValue().withN("0"))
item.put(ACCESS_KEY, new AttributeValue("accesskey"))
item.put(SECRET_KEY, new AttributeValue("secretkey"))
item.put(LOCK_STATUS, new AttributeValue("Locked"))
val user = fromItem(item).unsafeRunSync()
user.isSuper shouldBe false
}
"sets the lockStatus to Unlocked if the given value is invalid" in {
val item = new java.util.HashMap[String, AttributeValue]()
item.put(USER_ID, new AttributeValue("ok"))
item.put(USER_NAME, new AttributeValue("ok"))
item.put(CREATED, new AttributeValue().withN("0"))
item.put(ACCESS_KEY, new AttributeValue("accesskey"))
item.put(SECRET_KEY, new AttributeValue("secretkey"))
item.put(LOCK_STATUS, new AttributeValue("lock_status"))
val user = fromItem(item).unsafeRunSync()
user.lockStatus shouldBe LockStatus.Unlocked
}
}
"DynamoDBUserRepository.getUser" should {