diff --git a/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneService.scala b/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneService.scala index 16e64c8d8..8d4776786 100644 --- a/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneService.scala +++ b/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneService.scala @@ -34,7 +34,6 @@ import com.cronutils.parser.CronParser import com.cronutils.model.CronType import org.slf4j.LoggerFactory import vinyldns.api.domain.membership.MembershipService -import vinyldns.core.Messages import org.json4s._ import org.json4s.jackson.JsonMethods._ import com.fasterxml.jackson.databind.ObjectMapper @@ -45,8 +44,6 @@ import scala.util.Try import scala.jdk.CollectionConverters._ import java.net.URLEncoder import java.nio.charset.StandardCharsets - -import java.io.{ByteArrayInputStream, InputStream, OutputStream} import java.net.{HttpURLConnection, URL} import scala.io.Source @@ -169,6 +166,35 @@ class ZoneService( new URL(apiUrl).openConnection().asInstanceOf[HttpURLConnection] } + private def schemaValidationResult( + providerConfig: DnsProviderConfig, + operation: String, + params: Map[String, JValue] + ): Result[Unit] = providerConfig.schemas.get(operation) match { + case Some(schema) => JsonSchemaValidator.validate(schema, params).toResult + case None => result(()) + } + + private def existenceCheck(operation: String, zoneName: String): Result[Unit] = operation match { + case "create-zone" => generateZoneDoesNotExist(zoneName).toResult + case "delete-zone" => generateZoneExists(zoneName).toResult + case "update-zone" => generateZoneExists(zoneName).toResult + case _ => Left(InvalidRequest(s"Unsupported operation: $operation")).toResult + } + + private def executeRepoAction( + operation: String, + request: ZoneGenerationInput, + response: ZoneGenerationResponse + ): Result[Unit] = { + val zoneToGenerate = GenerateZone(request) + operation match { + case "delete-zone" => generateZoneRepository.deleteTx(zoneToGenerate).toResult + case "create-zone" => generateZoneRepository.save(zoneToGenerate.copy(response = Some(response))).toResult + case "update-zone" => generateZoneRepository.save(zoneToGenerate.copy(response = Some(response))).toResult + case _ => Left(InvalidRequest(s"Unsupported operation: $operation")).toResult + } + } def handleGenerateZoneRequest( operation: String, request: ZoneGenerationInput, @@ -178,45 +204,32 @@ class ZoneService( // Validate input providerConfig <- validateProvider(request.provider, dnsProviderApiConnection.providers).toResult _ <- validateZoneName(request.zoneName).toResult - - // JSON Schema validation for providerParams - _ <- JsonSchemaValidator - .validate( - providerConfig.schemas(operation), - request.providerParams - ).toResult + _ <- schemaValidationResult(providerConfig, operation, request.providerParams) _ <- logger.info(s"Request providerParams: ${request.providerParams}").toResult - // Build JSON request - generateZoneRequestJson = buildGenerateZoneRequestJson( - providerConfig.requestTemplates(operation), - request - ) - - // Authorization checks - _ <- canChangeZone(auth, request.zoneName, request.groupId).toResult - _ <- generateZoneDoesNotExist(request.zoneName).toResult - + // Build request and endpoint endpoint = buildGenerateZoneEndpoint(providerConfig.endpoints(operation), request) + requestJsonOpt = buildGenerateZoneRequestJson(providerConfig.requestTemplates.get(operation), request) - _ = logger.info(s"Request: provider=${request.provider}, path=${endpoint}, request=$generateZoneRequestJson").toResult + // Authorization and existence checks + _ <- canChangeZone(auth, request.zoneName, request.groupId).toResult + _ <- existenceCheck(operation, request.zoneName) + + // Send request + _ <- logger.info(s"Request: provider=${request.provider}, path=$endpoint, request=$requestJsonOpt").toResult dnsProviderConn <- createConnection(endpoint).toResult - dnsConnResponse <- createDnsZoneService( - endpoint, - providerConfig.apiKey, - generateZoneRequestJson, - dnsProviderConn - ).toResult + dnsConnResponse <- createDnsZoneService(providerConfig.apiKey, operation, requestJsonOpt, dnsProviderConn).toResult // Process response responseCode = dnsConnResponse.getResponseCode + _ <- logger.info(s"response code: $responseCode").toResult inputStream = if (responseCode >= 400) dnsConnResponse.getErrorStream else dnsConnResponse.getInputStream responseMessage: String = Source.fromInputStream(inputStream, "UTF-8").mkString _ <- isValidGenerateZoneConn(responseCode, responseMessage).toResult - // Parse response - responseJson = parse(responseMessage) + // Only parse JSON if the response is non-empty + responseJson = if (responseMessage.nonEmpty) parse(responseMessage) else JNothing // Create response object zoneGenerateResponse = ZoneGenerationResponse( @@ -225,32 +238,28 @@ class ZoneService( status = dnsConnResponse.getResponseMessage, message = responseJson ) + _ <- logger.info(s"response: $zoneGenerateResponse").toResult - // Save to repository - zoneToGenerate = GenerateZone(request) - _ <- generateZoneRepository.save(zoneToGenerate.copy(response = Some(zoneGenerateResponse))).toResult[GenerateZone] + _ <- executeRepoAction(operation, request, zoneGenerateResponse) - } yield { - // Cleanup resources - Option(inputStream).foreach(_.close()) - Option(dnsConnResponse).foreach(_.disconnect()) - zoneGenerateResponse - } + } yield zoneGenerateResponse } // Build a Generate Zone JSON request using template engine private def buildGenerateZoneRequestJson( - requestTemplate: String, - zoneGenerationInput: ZoneGenerationInput - ): String = { + maybeRequestTemplate: Option[String], + zoneGenerationInput: ZoneGenerationInput + ): Option[String] = { val baseParams = Map( "zoneName" -> JString(zoneGenerationInput.zoneName), "provider" -> JString(zoneGenerationInput.provider), - "groupId" -> JString(zoneGenerationInput.groupId), - "email" -> JString(zoneGenerationInput.email) + "groupId" -> JString(zoneGenerationInput.groupId), + "email" -> JString(zoneGenerationInput.email) ) - TemplateEngine.renderTemplate(requestTemplate, baseParams ++ zoneGenerationInput.providerParams) + maybeRequestTemplate.map { requestTemplate => + TemplateEngine.renderTemplate(requestTemplate, baseParams ++ zoneGenerationInput.providerParams) + } } private def buildGenerateZoneEndpoint( @@ -354,36 +363,41 @@ class ZoneService( } } - def createDnsZoneService(dnsApiUrl: String, dnsApiKey: String, request: String, connection: HttpURLConnection): Either[Throwable, HttpURLConnection] = - { + def createDnsZoneService( + dnsApiKey: String, + operation: String, + request: Option[String], + connection: HttpURLConnection + ): Either[Throwable, HttpURLConnection] = { try { - //val connection = new URL(dnsApiUrl).openConnection().asInstanceOf[HttpURLConnection] - connection.setRequestMethod("POST") + // Map operation to HTTP method + val method = operation match { + case "create-zone" => "POST" + case "update-zone" => "PUT" + case "delete-zone" => "DELETE" + case other => throw new IllegalArgumentException(s"Unsupported operation: $other") + } + + connection.setRequestMethod(method) connection.setRequestProperty("Content-Type", "application/json") connection.setRequestProperty("X-API-Key", dnsApiKey) - connection.setDoOutput(true) - - val outputStream: OutputStream = connection.getOutputStream - outputStream.write(request.getBytes("UTF-8")) - outputStream.close() + // Only send a body if the HTTP method and request are appropriate + val methodsWithBody = Set("POST", "PUT", "PATCH") + if (methodsWithBody.contains(method) && request.isDefined) { + connection.setDoOutput(true) + val outputStream = connection.getOutputStream + try { + outputStream.write(request.get.getBytes("UTF-8")) + } finally { + outputStream.close() + } + } Right(connection) } catch { case e: Exception => - val errorConnection = new HttpURLConnection(new URL(dnsApiUrl)) { - private val errorJson = Messages.dnsProviderConnRefusedMessage(e, dnsApiUrl) - private val errorBytes = errorJson.getBytes("UTF-8") - private val errorByteStream = new ByteArrayInputStream(errorBytes) - - override def disconnect(): Unit = {} - override def usingProxy(): Boolean = false - override def connect(): Unit = {} - - override def getResponseCode: Int = 500 - override def getErrorStream: InputStream = errorByteStream - } - Right(errorConnection) + Left(e) } } @@ -757,6 +771,20 @@ class ZoneService( } } + private def generateZoneExists(zoneName: String): Either[Throwable, Unit] = { + val existingZoneOpt: Option[GenerateZone] = + generateZoneRepository.getGenerateZoneByName(zoneName).unsafeRunSync() + + existingZoneOpt match { + case Some(_) => + Right(()) + case None => + Left(ZoneNotFoundError( + s"Zone with name $zoneName does not exist." + )) + } + } + def canScheduleZoneSync(auth: AuthPrincipal): Either[Throwable, Unit] = ensuring( NotAuthorizedError(s"User '${auth.signedInUser.userName}' is not authorized to schedule zone sync in this zone.") diff --git a/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneValidations.scala b/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneValidations.scala index d39ffce90..eb3597470 100644 --- a/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneValidations.scala +++ b/modules/api/src/main/scala/vinyldns/api/domain/zone/ZoneValidations.scala @@ -24,7 +24,6 @@ import vinyldns.api.Interfaces.ensuring import vinyldns.core.domain.membership.User import vinyldns.core.domain.record.RecordType import vinyldns.core.domain.zone.{ACLRule, Zone, ZoneACL, DnsProviderConfig} -import org.json4s._ import scala.util.{Failure, Success, Try} @@ -105,17 +104,6 @@ class ZoneValidations(syncDelayMillis: Int) { case None => Left(InvalidRequest(s"Unsupported DNS provider: $provider")) } - def validateRequiredFields( - requiredFields: List[String], - providedFields: Map[String, JValue], - providerName: String - ): Either[Throwable, Unit] = { - val missing = requiredFields.filterNot(providedFields.contains) - ensuring(InvalidRequest( - s"Missing required fields for $providerName: ${missing.mkString(", ")}" - ))(missing.isEmpty) - } - def validateZoneName(zoneName: String): Either[Throwable, Unit] = ensuring(InvalidRequest(s"Invalid zone name: $zoneName")) { zoneName.matches("""^[a-zA-Z0-9.-]+\.$""") diff --git a/modules/core/src/main/scala/vinyldns/core/domain/zone/GenerateZoneRepository.scala b/modules/core/src/main/scala/vinyldns/core/domain/zone/GenerateZoneRepository.scala index 7ee745d79..dcb34fd11 100644 --- a/modules/core/src/main/scala/vinyldns/core/domain/zone/GenerateZoneRepository.scala +++ b/modules/core/src/main/scala/vinyldns/core/domain/zone/GenerateZoneRepository.scala @@ -26,6 +26,8 @@ trait GenerateZoneRepository extends Repository { def getGenerateZoneByName(zoneName: String): IO[Option[GenerateZone]] + def deleteTx(generateZone: GenerateZone): IO[Unit] + def listGenerateZones( authPrincipal: AuthPrincipal, zoneNameFilter: Option[String] = None, diff --git a/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlGenerateZoneRepository.scala b/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlGenerateZoneRepository.scala index a3b83e937..f2976d52f 100644 --- a/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlGenerateZoneRepository.scala +++ b/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlGenerateZoneRepository.scala @@ -82,8 +82,7 @@ class MySqlGenerateZoneRepository extends GenerateZoneRepository with ProtobufCo .update() .apply() } - generateZone - + generateZone } }} @@ -97,7 +96,7 @@ class MySqlGenerateZoneRepository extends GenerateZoneRepository with ProtobufCo fromPB(VinylDNSProto.GenerateZone.parseFrom(res.bytes(columnIndex))) } - def deleteTx(generateZone: GenerateZone): IO[GenerateZone] = + def deleteTx(generateZone: GenerateZone): IO[Unit] = monitor("repo.ZoneJDBC.generateZoneDelete") { IO { DB.localTx { implicit s =>