2
0
mirror of https://github.com/VinylDNS/vinyldns synced 2025-08-22 10:10:12 +00:00

Merge pull request #1166 from Aravindh-Raju/aravindhr/add-auth-for-status

Add auth on POST /status
This commit is contained in:
Nicholas Spadaccino 2022-11-08 12:48:40 -05:00 committed by GitHub
commit 1c2c7b3424
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 167 additions and 43 deletions

View File

@ -17,10 +17,15 @@
package vinyldns.api.route package vinyldns.api.route
import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.Directives import akka.http.scaladsl.server.Route
import akka.util.Timeout import akka.util.Timeout
import cats.effect.IO import cats.effect.IO
import fs2.concurrent.SignallingRef import fs2.concurrent.SignallingRef
import org.slf4j.{Logger, LoggerFactory}
import vinyldns.api.Interfaces.{EitherImprovements, Result, ensuring}
import vinyldns.api.config.ServerConfig
import vinyldns.api.domain.zone.NotAuthorizedError
import vinyldns.core.domain.auth.AuthPrincipal
import scala.concurrent.duration._ import scala.concurrent.duration._
@ -31,39 +36,63 @@ final case class CurrentStatus(
version: String version: String
) )
trait StatusRoute extends Directives { class StatusRoute(
this: VinylDNSJsonProtocol => serverConfig: ServerConfig,
val vinylDNSAuthenticator: VinylDNSAuthenticator,
val processingDisabled: SignallingRef[IO, Boolean]
) extends VinylDNSJsonProtocol
with VinylDNSDirectives[Throwable] {
implicit val timeout = Timeout(10.seconds) def getRoutes: Route = statusRoute
def processingDisabled: SignallingRef[IO, Boolean] implicit val timeout: Timeout = Timeout(10.seconds)
def statusRoute(color: String, version: String, keyName: String) = def logger: Logger = LoggerFactory.getLogger(classOf[StatusRoute])
def handleErrors(e: Throwable): PartialFunction[Throwable, Route] = {
case NotAuthorizedError(msg) => complete(StatusCodes.Forbidden, msg)
}
def postStatus(isProcessingDisabled: Boolean, authPrincipal: AuthPrincipal): Result[Boolean] = {
for {
_ <- isAdmin(authPrincipal).toResult
isDisabled = isProcessingDisabled
} yield isDisabled
}
def isAdmin(authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError(s"Not authorized. User '${authPrincipal.signedInUser.userName}' cannot make the requested change.")) {
authPrincipal.isSystemAdmin
}
val statusRoute: Route =
(get & path("status")) { (get & path("status")) {
onSuccess(processingDisabled.get.unsafeToFuture()) { isProcessingDisabled => onSuccess(processingDisabled.get.unsafeToFuture()) { isProcessingDisabled =>
complete( complete(
StatusCodes.OK, StatusCodes.OK,
CurrentStatus( CurrentStatus(
isProcessingDisabled, isProcessingDisabled,
color, serverConfig.color,
keyName, serverConfig.keyName,
version serverConfig.version
) )
) )
} }
} ~ } ~
(post & path("status")) { (post & path("status")) {
parameters("processingDisabled".as[Boolean]) { isProcessingDisabled => parameters("processingDisabled".as[Boolean]) { isProcessingDisabled =>
onSuccess(processingDisabled.set(isProcessingDisabled).unsafeToFuture()) { authenticateAndExecute(postStatus(isProcessingDisabled, _)){ isProcessingDisabled =>
complete( onSuccess(processingDisabled.set(isProcessingDisabled).unsafeToFuture()) {
StatusCodes.OK, complete(
CurrentStatus( StatusCodes.OK,
isProcessingDisabled, CurrentStatus(
color, isProcessingDisabled,
keyName, serverConfig.color,
version serverConfig.keyName,
serverConfig.version
)
) )
) }
} }
} }
} }

View File

@ -69,7 +69,6 @@ class VinylDNSService(
) extends PingRoute ) extends PingRoute
with HealthCheckRoute with HealthCheckRoute
with BlueGreenRoute with BlueGreenRoute
with StatusRoute
with PrometheusRoute with PrometheusRoute
with VinylDNSJsonProtocol with VinylDNSJsonProtocol
with RequestLogging { with RequestLogging {
@ -97,27 +96,29 @@ class VinylDNSService(
vinylDNSAuthenticator, vinylDNSAuthenticator,
vinyldnsConfig.manualReviewConfig vinyldnsConfig.manualReviewConfig
).getRoutes ).getRoutes
val statusRoute: Route =
new StatusRoute(
vinyldnsConfig.serverConfig,
vinylDNSAuthenticator,
processingDisabled
).getRoutes
val unloggedUris = Seq( val unloggedUris = Seq(
Uri.Path("/health"), Uri.Path("/health"),
Uri.Path("/color"), Uri.Path("/color"),
Uri.Path("/ping"), Uri.Path("/ping"),
Uri.Path("/status"),
Uri.Path("/metrics/prometheus") Uri.Path("/metrics/prometheus")
) )
val unloggedRoutes: Route = healthCheckRoute ~ pingRoute ~ colorRoute( val unloggedRoutes: Route = healthCheckRoute ~ pingRoute ~ colorRoute(
vinyldnsConfig.serverConfig.color vinyldnsConfig.serverConfig.color
) ~ statusRoute(
vinyldnsConfig.serverConfig.color,
vinyldnsConfig.serverConfig.version,
vinyldnsConfig.serverConfig.keyName
) ~ prometheusRoute ) ~ prometheusRoute
val allRoutes: Route = unloggedRoutes ~ val allRoutes: Route = unloggedRoutes ~
batchChangeRoute ~ batchChangeRoute ~
zoneRoute ~ zoneRoute ~
recordSetRoute ~ recordSetRoute ~
membershipRoute membershipRoute ~
statusRoute
val vinyldnsRoutes: Route = logRequestResult(requestLogger(unloggedUris))(allRoutes) val vinyldnsRoutes: Route = logRequestResult(requestLogger(unloggedUris))(allRoutes)

View File

@ -18,6 +18,43 @@ def test_get_status_success(shared_zone_test_context):
assert_that(result["version"], not_none()) assert_that(result["version"], not_none())
def test_post_status_fails_for_non_admin_users(shared_zone_test_context):
"""
Tests that the post request to status endpoint fails for non-admin users
"""
client = shared_zone_test_context.ok_vinyldns_client
result = client.post_status(True)
assert_that(result, is_("Not authorized. User 'ok' cannot make the requested change."))
def test_post_status_fails_for_non_users(shared_zone_test_context):
"""
Tests that the post request to status endpoint fails for non-users with fake access and secret key
"""
client = shared_zone_test_context.non_user_client
result = client.post_status(True)
assert_that(result, is_("Authentication Failed: Account with accessKey not-exist-key specified was not found"))
def test_post_status_pass_for_admin_users(shared_zone_test_context):
"""
Tests that the post request to status endpoint pass for admin users
"""
client = shared_zone_test_context.support_user_client
client.post_status(True)
status = client.get_status()
assert_that(status["processingDisabled"], is_(True))
client.post_status(False)
status = client.get_status()
assert_that(status["processingDisabled"], is_(False))
@pytest.mark.serial @pytest.mark.serial
@pytest.mark.skip_production @pytest.mark.skip_production
def test_toggle_processing(shared_zone_test_context): def test_toggle_processing(shared_zone_test_context):
@ -25,15 +62,16 @@ def test_toggle_processing(shared_zone_test_context):
Test that updating a zone when processing is disabled does not happen Test that updating a zone when processing is disabled does not happen
""" """
client = shared_zone_test_context.ok_vinyldns_client client = shared_zone_test_context.ok_vinyldns_client
admin_client = shared_zone_test_context.support_user_client
ok_zone = copy.deepcopy(shared_zone_test_context.ok_zone) ok_zone = copy.deepcopy(shared_zone_test_context.ok_zone)
# disable processing # disable processing
client.post_status(True) admin_client.post_status(True)
status = client.get_status() status = client.get_status()
assert_that(status["processingDisabled"], is_(True)) assert_that(status["processingDisabled"], is_(True))
client.post_status(False) admin_client.post_status(False)
status = client.get_status() status = client.get_status()
assert_that(status["processingDisabled"], is_(False)) assert_that(status["processingDisabled"], is_(False))

View File

@ -31,8 +31,9 @@ class SharedZoneTestContext(object):
self.unassociated_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "listGroupAccessKey", "listGroupSecretKey") self.unassociated_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "listGroupAccessKey", "listGroupSecretKey")
self.test_user_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "testUserAccessKey", "testUserSecretKey") self.test_user_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "testUserAccessKey", "testUserSecretKey")
self.history_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "history-key", "history-secret") self.history_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "history-key", "history-secret")
self.non_user_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "not-exist-key", "not-exist-secret")
self.clients = [self.ok_vinyldns_client, self.dummy_vinyldns_client, self.shared_zone_vinyldns_client, self.support_user_client, self.clients = [self.ok_vinyldns_client, self.dummy_vinyldns_client, self.shared_zone_vinyldns_client, self.support_user_client,
self.unassociated_client, self.test_user_client, self.history_client] self.unassociated_client, self.test_user_client, self.history_client, self.non_user_client]
self.list_zones = ListZonesTestContext(partition_id) self.list_zones = ListZonesTestContext(partition_id)
self.list_zones_client = self.list_zones.client self.list_zones_client = self.list_zones.client
self.list_records_context = ListRecordSetsTestContext(partition_id) self.list_records_context = ListRecordSetsTestContext(partition_id)

View File

@ -16,9 +16,11 @@
package vinyldns.api package vinyldns.api
import cats.effect.{ContextShift, IO}
import com.comcast.ip4s.IpAddress import com.comcast.ip4s.IpAddress
import fs2.concurrent.SignallingRef
import org.joda.time.DateTime import org.joda.time.DateTime
import vinyldns.api.config.{ZoneAuthConfigs, BatchChangeConfig, DottedHostsConfig, HighValueDomainConfig, LimitsConfig, ManualReviewConfig, ScheduledChangesConfig} import vinyldns.api.config.{ZoneAuthConfigs, DottedHostsConfig, BatchChangeConfig, HighValueDomainConfig, LimitsConfig, ManualReviewConfig, ScheduledChangesConfig, ServerConfig}
import vinyldns.api.domain.batch.V6DiscoveryNibbleBoundaries import vinyldns.api.domain.batch.V6DiscoveryNibbleBoundaries
import vinyldns.core.domain.record._ import vinyldns.core.domain.record._
import vinyldns.core.domain.zone._ import vinyldns.core.domain.zone._
@ -27,6 +29,12 @@ import scala.util.matching.Regex
trait VinylDNSTestHelpers { trait VinylDNSTestHelpers {
private implicit val cs: ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.global)
val processingDisabled: SignallingRef[IO, Boolean] =
fs2.concurrent.SignallingRef[IO, Boolean](false).unsafeRunSync()
val highValueDomainRegexList: List[Regex] = List(new Regex("high-value-domain.*")) val highValueDomainRegexList: List[Regex] = List(new Regex("high-value-domain.*"))
val highValueDomainIpList: List[IpAddress] = val highValueDomainIpList: List[IpAddress] =
(IpAddress("192.0.2.252") ++ IpAddress("192.0.2.253") ++ IpAddress( (IpAddress("192.0.2.252") ++ IpAddress("192.0.2.253") ++ IpAddress(
@ -75,6 +83,9 @@ trait VinylDNSTestHelpers {
val testLimitConfig: LimitsConfig = val testLimitConfig: LimitsConfig =
LimitsConfig(100,100,1000,1500,100,100,100) LimitsConfig(100,100,1000,1500,100,100,100)
val testServerConfig: ServerConfig =
ServerConfig(100, 100, 100, 100, true, approvedNameServers, "blue", "unset", "vinyldns.", false, true, true)
val batchChangeConfig: BatchChangeConfig = val batchChangeConfig: BatchChangeConfig =
BatchChangeConfig(batchChangeLimit, sharedApprovedTypes, v6DiscoveryNibbleBoundaries) BatchChangeConfig(batchChangeLimit, sharedApprovedTypes, v6DiscoveryNibbleBoundaries)

View File

@ -18,35 +18,49 @@ package vinyldns.api.route
import akka.actor.ActorSystem import akka.actor.ActorSystem
import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.testkit.ScalatestRouteTest import akka.http.scaladsl.testkit.ScalatestRouteTest
import cats.effect.{ContextShift, IO}
import fs2.concurrent.SignallingRef
import org.scalatest._ import org.scalatest._
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec import org.scalatest.wordspec.AnyWordSpec
import org.scalatestplus.mockito.MockitoSugar import org.scalatestplus.mockito.MockitoSugar
import vinyldns.api.VinylDNSTestHelpers
import vinyldns.api.VinylDNSTestHelpers.processingDisabled
import vinyldns.core.TestMembershipData.{notAuth, okAuth, superUserAuth}
class StatusRoutingSpec class StatusRoutingSpec
extends AnyWordSpec extends AnyWordSpec
with ScalatestRouteTest with ScalatestRouteTest
with StatusRoute
with OneInstancePerTest with OneInstancePerTest
with VinylDNSJsonProtocol with VinylDNSJsonProtocol
with BeforeAndAfterEach with BeforeAndAfterEach
with MockitoSugar with MockitoSugar
with Matchers { with Matchers {
val statusRoute: Route =
new StatusRoute(
VinylDNSTestHelpers.testServerConfig,
new TestVinylDNSAuthenticator(okAuth),
VinylDNSTestHelpers.processingDisabled
).getRoutes
val notAuthRoute: Route =
new StatusRoute(
VinylDNSTestHelpers.testServerConfig,
new TestVinylDNSAuthenticator(notAuth),
VinylDNSTestHelpers.processingDisabled
).getRoutes
val adminUserRoute: Route =
new StatusRoute(
VinylDNSTestHelpers.testServerConfig,
new TestVinylDNSAuthenticator(superUserAuth),
VinylDNSTestHelpers.processingDisabled
).getRoutes
def actorRefFactory: ActorSystem = system def actorRefFactory: ActorSystem = system
private implicit val cs: ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.global)
val processingDisabled: SignallingRef[IO, Boolean] =
fs2.concurrent.SignallingRef[IO, Boolean](false).unsafeRunSync()
"GET /status" should { "GET /status" should {
"return the current status of true" in { "return the current status of true" in {
Get("/status") ~> statusRoute("blue", "unset", "vinyldns.") ~> check { Get("/status") ~> statusRoute ~> check {
response.status shouldBe StatusCodes.OK response.status shouldBe StatusCodes.OK
val resultStatus = responseAs[CurrentStatus] val resultStatus = responseAs[CurrentStatus]
resultStatus.processingDisabled shouldBe false resultStatus.processingDisabled shouldBe false
@ -58,16 +72,16 @@ class StatusRoutingSpec
} }
"POST /status" should { "POST /status" should {
"disable processing" in { "disable processing when it's requested by admin user" in {
Post("/status?processingDisabled=true") ~> statusRoute("blue", "unset", "vinyldns.") ~> check { Post("/status?processingDisabled=true") ~> adminUserRoute ~> check {
response.status shouldBe StatusCodes.OK response.status shouldBe StatusCodes.OK
val resultStatus = responseAs[CurrentStatus] val resultStatus = responseAs[CurrentStatus]
resultStatus.processingDisabled shouldBe true resultStatus.processingDisabled shouldBe true
} }
} }
"enable processing" in { "enable processing when it's requested by admin user" in {
Post("/status?processingDisabled=false") ~> statusRoute("blue", "unset", "vinyldns.") ~> check { Post("/status?processingDisabled=false") ~> adminUserRoute ~> check {
response.status shouldBe StatusCodes.OK response.status shouldBe StatusCodes.OK
val resultStatus = responseAs[CurrentStatus] val resultStatus = responseAs[CurrentStatus]
resultStatus.processingDisabled shouldBe false resultStatus.processingDisabled shouldBe false
@ -76,5 +90,35 @@ class StatusRoutingSpec
processingDisabled.get.unsafeRunSync() shouldBe false processingDisabled.get.unsafeRunSync() shouldBe false
} }
} }
"not disable processing when it's requested by non-admin user" in {
Post("/status?processingDisabled=true") ~> statusRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
}
}
"not enable processing when it's requested by non-admin user" in {
Post("/status?processingDisabled=false") ~> statusRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
// remember, the signal is the opposite of intent
processingDisabled.get.unsafeRunSync() shouldBe false
}
}
"not disable processing when it's requested by a non-user" in {
Post("/status?processingDisabled=true") ~> notAuthRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
}
}
"not enable processing when it's requested by a non-user" in {
Post("/status?processingDisabled=false") ~> notAuthRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
// remember, the signal is the opposite of intent
processingDisabled.get.unsafeRunSync() shouldBe false
}
}
} }
} }