diff --git a/modules/core/src/main/scala/vinyldns/core/task/TaskRepository.scala b/modules/core/src/main/scala/vinyldns/core/task/TaskRepository.scala index 1782256b3..e07ed63df 100644 --- a/modules/core/src/main/scala/vinyldns/core/task/TaskRepository.scala +++ b/modules/core/src/main/scala/vinyldns/core/task/TaskRepository.scala @@ -23,7 +23,9 @@ import scala.concurrent.duration.FiniteDuration trait TaskRepository extends Repository { - def claimTask(name: String, pollingInterval: FiniteDuration): IO[Boolean] + def claimTask(name: String, taskTimeout: FiniteDuration): IO[Boolean] def releaseTask(name: String): IO[Unit] + + def saveTask(name: String): IO[Unit] } diff --git a/modules/core/src/main/scala/vinyldns/core/task/TaskScheduler.scala b/modules/core/src/main/scala/vinyldns/core/task/TaskScheduler.scala new file mode 100644 index 000000000..d0b7cfbfd --- /dev/null +++ b/modules/core/src/main/scala/vinyldns/core/task/TaskScheduler.scala @@ -0,0 +1,103 @@ +/* + * Copyright 2018 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package vinyldns.core.task +import cats.effect._ +import cats.implicits._ +import fs2._ +import org.slf4j.LoggerFactory +import vinyldns.core.route.Monitored + +import scala.concurrent.duration.FiniteDuration + +// Interface for all Tasks that need to be run +trait Task { + // The name of the task, should be unique / constant and should not change + def name: String + + // The amount of time this task is running before it can be reclaimed / considered failed + def timeout: FiniteDuration + + // How often to attempt to run the task + def runEvery: FiniteDuration + + // Runs the task + def run(): IO[Unit] +} + +object TaskScheduler extends Monitored { + private val logger = LoggerFactory.getLogger("TaskScheduler") + + /** + * Schedules a task to be run. Will insert the task into the TaskRepository if it is not + * already present. The result is a Stream that the caller has to manage the lifecycle for. + * + * @example + * val scheduledStream = TaskScheduler.schedule(...) + * val handle = scheduledStream.compile.drain.start.unsafeRunSync() + * ... + * // once everything is done you can cancel it via the handle + * handle.cancel.unsafeRunSync() + * @return a Stream that when run will awake on the interval defined on the Task provided, and + * run the task.run() + */ + def schedule(task: Task, taskRepository: TaskRepository)( + implicit t: Timer[IO], + cs: ContextShift[IO]): Stream[IO, Unit] = { + + def claimTask(): IO[Option[Task]] = + taskRepository.claimTask(task.name, task.timeout).map { + case true => + logger.info(s"""Successfully found and claimed task; taskName="${task.name}" """) + Some(task) + case false => + logger.info(s"""No task claimed; taskName="${task.name}" """) + None + } + + // Note: IO.suspend is needed due to bug in cats effect 1.0.0 #421 + def releaseTask(maybeTask: Option[Task]): IO[Unit] = IO.suspend { + maybeTask + .map( + t => + taskRepository + .releaseTask(t.name) + .as(logger.info(s"""Released task; taskName="${task.name}" """))) + .getOrElse(IO.unit) + } + + def runTask(maybeTask: Option[Task]): IO[Unit] = maybeTask.map(_.run()).getOrElse(IO.unit) + + // Acquires a task, runs it, and makes sure it is cleaned up, swallows the error via a log + def runOnceSafely(task: Task): IO[Unit] = + monitor(s"task.${task.name}") { + claimTask().bracket(runTask)(releaseTask).handleError { error => + logger.error(s"""Unexpected error running task; taskName="${task.name}" """, error) + } + } + + // We must first schedule the task in the repository and then create our stream + // Semantics of repo.scheduleTask are idempotent, if the task already exists we are ok + // Then we run our scheduled task with awakeEvery + Stream + .eval(taskRepository.saveTask(task.name)) + .flatMap { _ => + Stream + .awakeEvery[IO](task.runEvery) + .evalMap(_ => runOnceSafely(task)) + } + } +} diff --git a/modules/core/src/test/scala/vinyldns/core/task/TaskSchedulerSpec.scala b/modules/core/src/test/scala/vinyldns/core/task/TaskSchedulerSpec.scala new file mode 100644 index 000000000..85e60c16e --- /dev/null +++ b/modules/core/src/test/scala/vinyldns/core/task/TaskSchedulerSpec.scala @@ -0,0 +1,85 @@ +/* + * Copyright 2018 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package vinyldns.core.task +import cats.effect.{ContextShift, IO, Timer} +import org.mockito.Mockito +import org.mockito.Mockito._ +import org.scalatest.mockito.MockitoSugar +import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec} + +import scala.concurrent.duration._ + +class TaskSchedulerSpec extends WordSpec with Matchers with MockitoSugar with BeforeAndAfterEach { + + private implicit val cs: ContextShift[IO] = + IO.contextShift(scala.concurrent.ExecutionContext.global) + private implicit val timer: Timer[IO] = IO.timer(scala.concurrent.ExecutionContext.global) + + private val mockRepo = mock[TaskRepository] + + class TestTask( + val name: String, + val timeout: FiniteDuration, + val runEvery: FiniteDuration, + testResult: IO[Unit] = IO.unit) + extends Task { + def run(): IO[Unit] = testResult + } + + override def beforeEach() = Mockito.reset(mockRepo) + + "TaskScheduler" should { + "run a scheduled task" in { + val task = new TestTask("test", 5.seconds, 500.millis) + val spied = spy(task) + doReturn(IO.unit).when(mockRepo).saveTask(task.name) + doReturn(IO.pure(true)).when(mockRepo).claimTask(task.name, task.timeout) + doReturn(IO.unit).when(mockRepo).releaseTask(task.name) + + TaskScheduler.schedule(spied, mockRepo).take(1).compile.drain.unsafeRunSync() + + verify(spied).run() + verify(mockRepo).claimTask(task.name, task.timeout) + verify(mockRepo).releaseTask(task.name) + } + + "release the task even on error" in { + val task = + new TestTask("test", 5.seconds, 500.millis, IO.raiseError(new RuntimeException("fail"))) + doReturn(IO.unit).when(mockRepo).saveTask(task.name) + doReturn(IO.pure(true)).when(mockRepo).claimTask(task.name, task.timeout) + doReturn(IO.unit).when(mockRepo).releaseTask(task.name) + + TaskScheduler.schedule(task, mockRepo).take(1).compile.drain.unsafeRunSync() + verify(mockRepo).releaseTask(task.name) + } + + "fail to start if the task cannot be saved" in { + val task = new TestTask("test", 5.seconds, 500.millis) + val spied = spy(task) + doReturn(IO.raiseError(new RuntimeException("fail"))).when(mockRepo).saveTask(task.name) + + a[RuntimeException] should be thrownBy TaskScheduler + .schedule(task, mockRepo) + .take(1) + .compile + .drain + .unsafeRunSync() + verify(spied, never()).run() + } + } +} diff --git a/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlTaskRepositoryIntegrationSpec.scala b/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlTaskRepositoryIntegrationSpec.scala index f7d566b3f..6a0756afe 100644 --- a/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlTaskRepositoryIntegrationSpec.scala +++ b/modules/mysql/src/it/scala/vinyldns/mysql/repository/MySqlTaskRepositoryIntegrationSpec.scala @@ -16,25 +16,20 @@ package vinyldns.mysql.repository +import java.time.Instant + import cats.effect.IO -import org.joda.time.DateTime import org.scalatest._ -import scalikejdbc.DB +import scalikejdbc.{DB, _} import vinyldns.mysql.TestMySqlInstance import scala.concurrent.duration._ -import scalikejdbc._ class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll with BeforeAndAfterEach with Matchers { private val repo = TestMySqlInstance.taskRepository.asInstanceOf[MySqlTaskRepository] private val TASK_NAME = "task_name" - private val INSERT_STATEMENT = - sql""" - |INSERT INTO task (name, in_flight, created, updated) - | VALUES ({task_name}, {in_flight}, {created}, {updated}) - """.stripMargin - private val startDateTime = DateTime.now + case class TaskInfo(inFlight: Boolean, updated: Option[Instant]) override protected def beforeEach(): Unit = clear().unsafeRunSync() @@ -46,32 +41,29 @@ class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll } } - def insertTask(inFlight: Int, created: DateTime, updated: DateTime): IO[Unit] = IO { + def ageTaskBySeconds(seconds: Long): IO[Int] = IO { DB.localTx { implicit s => - INSERT_STATEMENT - .bindByName('task_name -> TASK_NAME, - 'in_flight -> inFlight, - 'created -> created, - 'updated -> updated - ) + sql"UPDATE task SET updated = DATE_SUB(NOW(),INTERVAL {ageSeconds} SECOND)" + .bindByName('ageSeconds -> seconds) .update() .apply() } } - def getTaskInfo: IO[Option[(Boolean, DateTime)]] = IO { + def getTaskInfo(name: String): IO[TaskInfo] = IO { DB.readOnly { implicit s => - sql"SELECT in_flight, updated from task FOR UPDATE" - .map(rs => (rs.boolean(1), new DateTime(rs.timestamp(2)))) + sql"SELECT in_flight, updated from task WHERE name = {taskName}" + .bindByName('taskName -> name) + .map(rs => TaskInfo(rs.boolean(1), rs.timestampOpt(2).map(_.toInstant))) .first() - .apply() + .apply().getOrElse(throw new RuntimeException(s"TASK $name NOT FOUND")) } } "claimTask" should { - "return true if non-in-flight task exists and updated time is null" in { + "return true if non-in-flight task exists task is new" in { val f = for { - _ <- insertTask(0, startDateTime, startDateTime) + _ <- repo.saveTask(TASK_NAME) unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour) } yield unclaimedTaskExists @@ -79,15 +71,18 @@ class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll } "return true if non-in-flight task exists and expiration time has elapsed" in { val f = for { - _ <- insertTask(0, startDateTime, startDateTime.minusHours(2)) - unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour) + _ <- repo.saveTask(TASK_NAME) + _ <- ageTaskBySeconds(100) // Age the task by 100 seconds + unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.second) } yield unclaimedTaskExists f.unsafeRunSync() shouldBe true } "return false if in-flight task exists and expiration time has not elapsed" in { val f = for { - _ <- insertTask(1, startDateTime, startDateTime) + _ <- repo.saveTask(TASK_NAME) + _ <- repo.claimTask(TASK_NAME, 1.hour) + _ <- ageTaskBySeconds(5) // Age the task by only 5 seconds unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour) } yield unclaimedTaskExists @@ -105,16 +100,55 @@ class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll "release task" should { "unset in-flight flag for task and update time" in { val f = for { - _ <- insertTask(1, startDateTime, startDateTime) + _ <- repo.saveTask(TASK_NAME) + _ <- repo.claimTask(TASK_NAME, 1.hour) + _ <- ageTaskBySeconds(2) + oldTaskInfo <- getTaskInfo(TASK_NAME) _ <- repo.releaseTask(TASK_NAME) - taskInfo <- getTaskInfo - } yield taskInfo + newTaskInfo <- getTaskInfo(TASK_NAME) + } yield (oldTaskInfo, newTaskInfo) - f.unsafeRunSync().foreach { tuple => - val (inFlight, updateTime) = tuple - inFlight shouldBe false - updateTime should not be startDateTime + val (oldTaskInfo, newTaskInfo) = f.unsafeRunSync() + + // make sure the in_flight is unset + newTaskInfo.inFlight shouldBe false + + // make sure that the updated time is later than the claimed time + oldTaskInfo.updated shouldBe defined + newTaskInfo.updated shouldBe defined + oldTaskInfo.updated.zip(newTaskInfo.updated).foreach { + case (claimTime, releaseTime) => + releaseTime should be > claimTime } } } + + "save task" should { + "insert a new task" in { + val f = for { + _ <- repo.saveTask(TASK_NAME) + taskInfo <- getTaskInfo(TASK_NAME) + } yield taskInfo + + val taskInfo = f.unsafeRunSync() + taskInfo.inFlight shouldBe false + taskInfo.updated shouldBe empty + } + + "not replace a task that is already present" in { + // schedule a task and claim it, then try to reschedule it while it is claimed (bad) + // the result should be that the task is still claimed / in_flight + val f = for { + _ <- repo.saveTask("repeat") + _ <- repo.claimTask("repeat", 5.seconds) + firstTaskInfo <- getTaskInfo("repeat") + _ <- repo.saveTask("repeat") + secondTaskInfo <- getTaskInfo("repeat") + _ <- repo.releaseTask("repeat") + } yield (firstTaskInfo, secondTaskInfo) + + val (first, second) = f.unsafeRunSync() + first shouldBe second + } + } } diff --git a/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlTaskRepository.scala b/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlTaskRepository.scala index 61f753a49..d6d2f6068 100644 --- a/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlTaskRepository.scala +++ b/modules/mysql/src/main/scala/vinyldns/mysql/repository/MySqlTaskRepository.scala @@ -17,7 +17,6 @@ package vinyldns.mysql.repository import cats.effect.IO -import org.joda.time.DateTime import scalikejdbc._ import vinyldns.core.task.TaskRepository @@ -32,34 +31,40 @@ class MySqlTaskRepository extends TaskRepository { * * `updated IS NULL` case is for the first run where the seeded data does not have an updated time set */ - private val CLAIM_UNCLAIMED_TASK = sql""" |UPDATE task - | SET in_flight = 1, updated = {currentTime} + | SET in_flight = 1, updated = NOW() | WHERE (in_flight = 0 | OR updated IS NULL - | OR updated < {updatedTimeComparison}) + | OR updated < DATE_SUB(NOW(),INTERVAL {timeoutSeconds} SECOND)) | AND name = {taskName}; """.stripMargin private val UNCLAIM_TASK = sql""" |UPDATE task - | SET in_flight = 0, updated = {currentTime} - | WHERE name = {name} + | SET in_flight = 0, updated = NOW() + | WHERE name = {taskName} """.stripMargin - def claimTask(name: String, pollingInterval: FiniteDuration): IO[Boolean] = + // In case multiple nodes attempt to insert task at the same time, do not overwrite + private val PUT_TASK = + sql""" + |INSERT IGNORE INTO task(name, in_flight, created, updated) + |VALUES ({taskName}, 0, NOW(), NULL) + """.stripMargin + + /** + * Note - the column in MySQL is datetime with no fractions, so the best we can do is seconds + * If taskTimeout is less than one second, this will never claim as + * FiniteDuration.toSeconds results in ZERO OL for something like 500.millis + */ + def claimTask(name: String, taskTimeout: FiniteDuration): IO[Boolean] = IO { - val pollingExpirationHours = pollingInterval.toHours * 2 - val currentTime = DateTime.now DB.localTx { implicit s => val updateResult = CLAIM_UNCLAIMED_TASK - .bindByName( - 'updatedTimeComparison -> currentTime.minusHours(pollingExpirationHours.toInt), - 'taskName -> name, - 'currentTime -> currentTime) + .bindByName('timeoutSeconds -> taskTimeout.toSeconds, 'taskName -> name) .first() .update() .apply() @@ -70,7 +75,15 @@ class MySqlTaskRepository extends TaskRepository { def releaseTask(name: String): IO[Unit] = IO { DB.localTx { implicit s => - UNCLAIM_TASK.bindByName('currentTime -> DateTime.now, 'name -> name).update().apply() + UNCLAIM_TASK.bindByName('taskName -> name).update().apply() + } + } + + // Save the task, do not overwrite if it is already there + def saveTask(name: String): IO[Unit] = IO { + DB.localTx { implicit s => + PUT_TASK.bindByName('taskName -> name).update().apply() + () } } } diff --git a/project/Dependencies.scala b/project/Dependencies.scala index a809951c6..4630e6b6f 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -67,7 +67,8 @@ object Dependencies { "com.sun.xml.bind" % "jaxb-core" % jaxbV, "com.sun.xml.bind" % "jaxb-impl" % jaxbV, "ch.qos.logback" % "logback-classic" % "1.0.7", - "io.dropwizard.metrics" % "metrics-jvm" % "3.2.2" + "io.dropwizard.metrics" % "metrics-jvm" % "3.2.2", + "co.fs2" %% "fs2-core" % "1.0.0" ) lazy val dynamoDBDependencies = Seq(