2
0
mirror of https://github.com/VinylDNS/vinyldns synced 2025-08-22 10:10:12 +00:00

Merge pull request #1166 from Aravindh-Raju/aravindhr/add-auth-for-status

Add auth on POST /status
This commit is contained in:
Nicholas Spadaccino 2022-11-08 12:48:40 -05:00 committed by GitHub
commit 1c2c7b3424
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 167 additions and 43 deletions

View File

@ -17,10 +17,15 @@
package vinyldns.api.route
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.Directives
import akka.http.scaladsl.server.Route
import akka.util.Timeout
import cats.effect.IO
import fs2.concurrent.SignallingRef
import org.slf4j.{Logger, LoggerFactory}
import vinyldns.api.Interfaces.{EitherImprovements, Result, ensuring}
import vinyldns.api.config.ServerConfig
import vinyldns.api.domain.zone.NotAuthorizedError
import vinyldns.core.domain.auth.AuthPrincipal
import scala.concurrent.duration._
@ -31,40 +36,64 @@ final case class CurrentStatus(
version: String
)
trait StatusRoute extends Directives {
this: VinylDNSJsonProtocol =>
class StatusRoute(
serverConfig: ServerConfig,
val vinylDNSAuthenticator: VinylDNSAuthenticator,
val processingDisabled: SignallingRef[IO, Boolean]
) extends VinylDNSJsonProtocol
with VinylDNSDirectives[Throwable] {
implicit val timeout = Timeout(10.seconds)
def getRoutes: Route = statusRoute
def processingDisabled: SignallingRef[IO, Boolean]
implicit val timeout: Timeout = Timeout(10.seconds)
def statusRoute(color: String, version: String, keyName: String) =
def logger: Logger = LoggerFactory.getLogger(classOf[StatusRoute])
def handleErrors(e: Throwable): PartialFunction[Throwable, Route] = {
case NotAuthorizedError(msg) => complete(StatusCodes.Forbidden, msg)
}
def postStatus(isProcessingDisabled: Boolean, authPrincipal: AuthPrincipal): Result[Boolean] = {
for {
_ <- isAdmin(authPrincipal).toResult
isDisabled = isProcessingDisabled
} yield isDisabled
}
def isAdmin(authPrincipal: AuthPrincipal): Either[Throwable, Unit] =
ensuring(NotAuthorizedError(s"Not authorized. User '${authPrincipal.signedInUser.userName}' cannot make the requested change.")) {
authPrincipal.isSystemAdmin
}
val statusRoute: Route =
(get & path("status")) {
onSuccess(processingDisabled.get.unsafeToFuture()) { isProcessingDisabled =>
complete(
StatusCodes.OK,
CurrentStatus(
isProcessingDisabled,
color,
keyName,
version
serverConfig.color,
serverConfig.keyName,
serverConfig.version
)
)
}
} ~
(post & path("status")) {
parameters("processingDisabled".as[Boolean]) { isProcessingDisabled =>
authenticateAndExecute(postStatus(isProcessingDisabled, _)){ isProcessingDisabled =>
onSuccess(processingDisabled.set(isProcessingDisabled).unsafeToFuture()) {
complete(
StatusCodes.OK,
CurrentStatus(
isProcessingDisabled,
color,
keyName,
version
serverConfig.color,
serverConfig.keyName,
serverConfig.version
)
)
}
}
}
}
}

View File

@ -69,7 +69,6 @@ class VinylDNSService(
) extends PingRoute
with HealthCheckRoute
with BlueGreenRoute
with StatusRoute
with PrometheusRoute
with VinylDNSJsonProtocol
with RequestLogging {
@ -97,27 +96,29 @@ class VinylDNSService(
vinylDNSAuthenticator,
vinyldnsConfig.manualReviewConfig
).getRoutes
val statusRoute: Route =
new StatusRoute(
vinyldnsConfig.serverConfig,
vinylDNSAuthenticator,
processingDisabled
).getRoutes
val unloggedUris = Seq(
Uri.Path("/health"),
Uri.Path("/color"),
Uri.Path("/ping"),
Uri.Path("/status"),
Uri.Path("/metrics/prometheus")
)
val unloggedRoutes: Route = healthCheckRoute ~ pingRoute ~ colorRoute(
vinyldnsConfig.serverConfig.color
) ~ statusRoute(
vinyldnsConfig.serverConfig.color,
vinyldnsConfig.serverConfig.version,
vinyldnsConfig.serverConfig.keyName
) ~ prometheusRoute
val allRoutes: Route = unloggedRoutes ~
batchChangeRoute ~
zoneRoute ~
recordSetRoute ~
membershipRoute
membershipRoute ~
statusRoute
val vinyldnsRoutes: Route = logRequestResult(requestLogger(unloggedUris))(allRoutes)

View File

@ -18,6 +18,43 @@ def test_get_status_success(shared_zone_test_context):
assert_that(result["version"], not_none())
def test_post_status_fails_for_non_admin_users(shared_zone_test_context):
"""
Tests that the post request to status endpoint fails for non-admin users
"""
client = shared_zone_test_context.ok_vinyldns_client
result = client.post_status(True)
assert_that(result, is_("Not authorized. User 'ok' cannot make the requested change."))
def test_post_status_fails_for_non_users(shared_zone_test_context):
"""
Tests that the post request to status endpoint fails for non-users with fake access and secret key
"""
client = shared_zone_test_context.non_user_client
result = client.post_status(True)
assert_that(result, is_("Authentication Failed: Account with accessKey not-exist-key specified was not found"))
def test_post_status_pass_for_admin_users(shared_zone_test_context):
"""
Tests that the post request to status endpoint pass for admin users
"""
client = shared_zone_test_context.support_user_client
client.post_status(True)
status = client.get_status()
assert_that(status["processingDisabled"], is_(True))
client.post_status(False)
status = client.get_status()
assert_that(status["processingDisabled"], is_(False))
@pytest.mark.serial
@pytest.mark.skip_production
def test_toggle_processing(shared_zone_test_context):
@ -25,15 +62,16 @@ def test_toggle_processing(shared_zone_test_context):
Test that updating a zone when processing is disabled does not happen
"""
client = shared_zone_test_context.ok_vinyldns_client
admin_client = shared_zone_test_context.support_user_client
ok_zone = copy.deepcopy(shared_zone_test_context.ok_zone)
# disable processing
client.post_status(True)
admin_client.post_status(True)
status = client.get_status()
assert_that(status["processingDisabled"], is_(True))
client.post_status(False)
admin_client.post_status(False)
status = client.get_status()
assert_that(status["processingDisabled"], is_(False))

View File

@ -31,8 +31,9 @@ class SharedZoneTestContext(object):
self.unassociated_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "listGroupAccessKey", "listGroupSecretKey")
self.test_user_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "testUserAccessKey", "testUserSecretKey")
self.history_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "history-key", "history-secret")
self.non_user_client = VinylDNSClient(VinylDNSTestContext.vinyldns_url, "not-exist-key", "not-exist-secret")
self.clients = [self.ok_vinyldns_client, self.dummy_vinyldns_client, self.shared_zone_vinyldns_client, self.support_user_client,
self.unassociated_client, self.test_user_client, self.history_client]
self.unassociated_client, self.test_user_client, self.history_client, self.non_user_client]
self.list_zones = ListZonesTestContext(partition_id)
self.list_zones_client = self.list_zones.client
self.list_records_context = ListRecordSetsTestContext(partition_id)

View File

@ -16,9 +16,11 @@
package vinyldns.api
import cats.effect.{ContextShift, IO}
import com.comcast.ip4s.IpAddress
import fs2.concurrent.SignallingRef
import org.joda.time.DateTime
import vinyldns.api.config.{ZoneAuthConfigs, BatchChangeConfig, DottedHostsConfig, HighValueDomainConfig, LimitsConfig, ManualReviewConfig, ScheduledChangesConfig}
import vinyldns.api.config.{ZoneAuthConfigs, DottedHostsConfig, BatchChangeConfig, HighValueDomainConfig, LimitsConfig, ManualReviewConfig, ScheduledChangesConfig, ServerConfig}
import vinyldns.api.domain.batch.V6DiscoveryNibbleBoundaries
import vinyldns.core.domain.record._
import vinyldns.core.domain.zone._
@ -27,6 +29,12 @@ import scala.util.matching.Regex
trait VinylDNSTestHelpers {
private implicit val cs: ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.global)
val processingDisabled: SignallingRef[IO, Boolean] =
fs2.concurrent.SignallingRef[IO, Boolean](false).unsafeRunSync()
val highValueDomainRegexList: List[Regex] = List(new Regex("high-value-domain.*"))
val highValueDomainIpList: List[IpAddress] =
(IpAddress("192.0.2.252") ++ IpAddress("192.0.2.253") ++ IpAddress(
@ -75,6 +83,9 @@ trait VinylDNSTestHelpers {
val testLimitConfig: LimitsConfig =
LimitsConfig(100,100,1000,1500,100,100,100)
val testServerConfig: ServerConfig =
ServerConfig(100, 100, 100, 100, true, approvedNameServers, "blue", "unset", "vinyldns.", false, true, true)
val batchChangeConfig: BatchChangeConfig =
BatchChangeConfig(batchChangeLimit, sharedApprovedTypes, v6DiscoveryNibbleBoundaries)

View File

@ -18,35 +18,49 @@ package vinyldns.api.route
import akka.actor.ActorSystem
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.testkit.ScalatestRouteTest
import cats.effect.{ContextShift, IO}
import fs2.concurrent.SignallingRef
import org.scalatest._
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import org.scalatestplus.mockito.MockitoSugar
import vinyldns.api.VinylDNSTestHelpers
import vinyldns.api.VinylDNSTestHelpers.processingDisabled
import vinyldns.core.TestMembershipData.{notAuth, okAuth, superUserAuth}
class StatusRoutingSpec
extends AnyWordSpec
with ScalatestRouteTest
with StatusRoute
with OneInstancePerTest
with VinylDNSJsonProtocol
with BeforeAndAfterEach
with MockitoSugar
with Matchers {
val statusRoute: Route =
new StatusRoute(
VinylDNSTestHelpers.testServerConfig,
new TestVinylDNSAuthenticator(okAuth),
VinylDNSTestHelpers.processingDisabled
).getRoutes
val notAuthRoute: Route =
new StatusRoute(
VinylDNSTestHelpers.testServerConfig,
new TestVinylDNSAuthenticator(notAuth),
VinylDNSTestHelpers.processingDisabled
).getRoutes
val adminUserRoute: Route =
new StatusRoute(
VinylDNSTestHelpers.testServerConfig,
new TestVinylDNSAuthenticator(superUserAuth),
VinylDNSTestHelpers.processingDisabled
).getRoutes
def actorRefFactory: ActorSystem = system
private implicit val cs: ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.global)
val processingDisabled: SignallingRef[IO, Boolean] =
fs2.concurrent.SignallingRef[IO, Boolean](false).unsafeRunSync()
"GET /status" should {
"return the current status of true" in {
Get("/status") ~> statusRoute("blue", "unset", "vinyldns.") ~> check {
Get("/status") ~> statusRoute ~> check {
response.status shouldBe StatusCodes.OK
val resultStatus = responseAs[CurrentStatus]
resultStatus.processingDisabled shouldBe false
@ -58,16 +72,16 @@ class StatusRoutingSpec
}
"POST /status" should {
"disable processing" in {
Post("/status?processingDisabled=true") ~> statusRoute("blue", "unset", "vinyldns.") ~> check {
"disable processing when it's requested by admin user" in {
Post("/status?processingDisabled=true") ~> adminUserRoute ~> check {
response.status shouldBe StatusCodes.OK
val resultStatus = responseAs[CurrentStatus]
resultStatus.processingDisabled shouldBe true
}
}
"enable processing" in {
Post("/status?processingDisabled=false") ~> statusRoute("blue", "unset", "vinyldns.") ~> check {
"enable processing when it's requested by admin user" in {
Post("/status?processingDisabled=false") ~> adminUserRoute ~> check {
response.status shouldBe StatusCodes.OK
val resultStatus = responseAs[CurrentStatus]
resultStatus.processingDisabled shouldBe false
@ -76,5 +90,35 @@ class StatusRoutingSpec
processingDisabled.get.unsafeRunSync() shouldBe false
}
}
"not disable processing when it's requested by non-admin user" in {
Post("/status?processingDisabled=true") ~> statusRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
}
}
"not enable processing when it's requested by non-admin user" in {
Post("/status?processingDisabled=false") ~> statusRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
// remember, the signal is the opposite of intent
processingDisabled.get.unsafeRunSync() shouldBe false
}
}
"not disable processing when it's requested by a non-user" in {
Post("/status?processingDisabled=true") ~> notAuthRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
}
}
"not enable processing when it's requested by a non-user" in {
Post("/status?processingDisabled=false") ~> notAuthRoute ~> check {
response.status shouldBe StatusCodes.Forbidden
// remember, the signal is the opposite of intent
processingDisabled.get.unsafeRunSync() shouldBe false
}
}
}
}