2
0
mirror of https://github.com/VinylDNS/vinyldns synced 2025-08-22 10:10:12 +00:00

Generic Task Scheduler (#717)

Creates a more general task scheduler.  The existing user sync process had some half generic pieces, and other pieces that were tightly coupled to the user sync process.

This is the first step at making a general purpose task scheduler.  This has been proven out in the implementation of the user sync process in #718 

1. `TaskRepository` - renamed `pollingInterval` to `taskTimeout` as the value is similar to `visbilityTimeout` in SQS

2. `Task` - is an interface that needs to be implemented by future tasks.   `name` is the unique name of the task; `timeout` is how long to wait to consider the last claim expired; `runEvery` is how often to attempt to run the task; `run()` is the function that actually executes the task itself.

3. `TaskScheduler` - this is the logic of scheduling.  It embodies the logic of a) saving the task b) claiming the task c) running the task and d) releasing the task.  It uses `IO.bracket` to make sure the finalizer `releaseTask` is called no matter what the result is of running the task.  It uses `fs2.Stream.awakeEvery` for polling.  The expectation is that the caller will acquire the stream and do an `Stream.compile.drain.start` to kick it off running.  It can be cancelled using the `Fiber` returned from `Stream.compile.drain.start`
This commit is contained in:
Paul Cleary 2019-07-01 13:53:00 -04:00 committed by GitHub
parent fa17f4ceab
commit 933614ed37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 286 additions and 48 deletions

View File

@ -23,7 +23,9 @@ import scala.concurrent.duration.FiniteDuration
trait TaskRepository extends Repository { trait TaskRepository extends Repository {
def claimTask(name: String, pollingInterval: FiniteDuration): IO[Boolean] def claimTask(name: String, taskTimeout: FiniteDuration): IO[Boolean]
def releaseTask(name: String): IO[Unit] def releaseTask(name: String): IO[Unit]
def saveTask(name: String): IO[Unit]
} }

View File

@ -0,0 +1,103 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.task
import cats.effect._
import cats.implicits._
import fs2._
import org.slf4j.LoggerFactory
import vinyldns.core.route.Monitored
import scala.concurrent.duration.FiniteDuration
// Interface for all Tasks that need to be run
trait Task {
// The name of the task, should be unique / constant and should not change
def name: String
// The amount of time this task is running before it can be reclaimed / considered failed
def timeout: FiniteDuration
// How often to attempt to run the task
def runEvery: FiniteDuration
// Runs the task
def run(): IO[Unit]
}
object TaskScheduler extends Monitored {
private val logger = LoggerFactory.getLogger("TaskScheduler")
/**
* Schedules a task to be run. Will insert the task into the TaskRepository if it is not
* already present. The result is a Stream that the caller has to manage the lifecycle for.
*
* @example
* val scheduledStream = TaskScheduler.schedule(...)
* val handle = scheduledStream.compile.drain.start.unsafeRunSync()
* ...
* // once everything is done you can cancel it via the handle
* handle.cancel.unsafeRunSync()
* @return a Stream that when run will awake on the interval defined on the Task provided, and
* run the task.run()
*/
def schedule(task: Task, taskRepository: TaskRepository)(
implicit t: Timer[IO],
cs: ContextShift[IO]): Stream[IO, Unit] = {
def claimTask(): IO[Option[Task]] =
taskRepository.claimTask(task.name, task.timeout).map {
case true =>
logger.info(s"""Successfully found and claimed task; taskName="${task.name}" """)
Some(task)
case false =>
logger.info(s"""No task claimed; taskName="${task.name}" """)
None
}
// Note: IO.suspend is needed due to bug in cats effect 1.0.0 #421
def releaseTask(maybeTask: Option[Task]): IO[Unit] = IO.suspend {
maybeTask
.map(
t =>
taskRepository
.releaseTask(t.name)
.as(logger.info(s"""Released task; taskName="${task.name}" """)))
.getOrElse(IO.unit)
}
def runTask(maybeTask: Option[Task]): IO[Unit] = maybeTask.map(_.run()).getOrElse(IO.unit)
// Acquires a task, runs it, and makes sure it is cleaned up, swallows the error via a log
def runOnceSafely(task: Task): IO[Unit] =
monitor(s"task.${task.name}") {
claimTask().bracket(runTask)(releaseTask).handleError { error =>
logger.error(s"""Unexpected error running task; taskName="${task.name}" """, error)
}
}
// We must first schedule the task in the repository and then create our stream
// Semantics of repo.scheduleTask are idempotent, if the task already exists we are ok
// Then we run our scheduled task with awakeEvery
Stream
.eval(taskRepository.saveTask(task.name))
.flatMap { _ =>
Stream
.awakeEvery[IO](task.runEvery)
.evalMap(_ => runOnceSafely(task))
}
}
}

View File

@ -0,0 +1,85 @@
/*
* Copyright 2018 Comcast Cable Communications Management, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package vinyldns.core.task
import cats.effect.{ContextShift, IO, Timer}
import org.mockito.Mockito
import org.mockito.Mockito._
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{BeforeAndAfterEach, Matchers, WordSpec}
import scala.concurrent.duration._
class TaskSchedulerSpec extends WordSpec with Matchers with MockitoSugar with BeforeAndAfterEach {
private implicit val cs: ContextShift[IO] =
IO.contextShift(scala.concurrent.ExecutionContext.global)
private implicit val timer: Timer[IO] = IO.timer(scala.concurrent.ExecutionContext.global)
private val mockRepo = mock[TaskRepository]
class TestTask(
val name: String,
val timeout: FiniteDuration,
val runEvery: FiniteDuration,
testResult: IO[Unit] = IO.unit)
extends Task {
def run(): IO[Unit] = testResult
}
override def beforeEach() = Mockito.reset(mockRepo)
"TaskScheduler" should {
"run a scheduled task" in {
val task = new TestTask("test", 5.seconds, 500.millis)
val spied = spy(task)
doReturn(IO.unit).when(mockRepo).saveTask(task.name)
doReturn(IO.pure(true)).when(mockRepo).claimTask(task.name, task.timeout)
doReturn(IO.unit).when(mockRepo).releaseTask(task.name)
TaskScheduler.schedule(spied, mockRepo).take(1).compile.drain.unsafeRunSync()
verify(spied).run()
verify(mockRepo).claimTask(task.name, task.timeout)
verify(mockRepo).releaseTask(task.name)
}
"release the task even on error" in {
val task =
new TestTask("test", 5.seconds, 500.millis, IO.raiseError(new RuntimeException("fail")))
doReturn(IO.unit).when(mockRepo).saveTask(task.name)
doReturn(IO.pure(true)).when(mockRepo).claimTask(task.name, task.timeout)
doReturn(IO.unit).when(mockRepo).releaseTask(task.name)
TaskScheduler.schedule(task, mockRepo).take(1).compile.drain.unsafeRunSync()
verify(mockRepo).releaseTask(task.name)
}
"fail to start if the task cannot be saved" in {
val task = new TestTask("test", 5.seconds, 500.millis)
val spied = spy(task)
doReturn(IO.raiseError(new RuntimeException("fail"))).when(mockRepo).saveTask(task.name)
a[RuntimeException] should be thrownBy TaskScheduler
.schedule(task, mockRepo)
.take(1)
.compile
.drain
.unsafeRunSync()
verify(spied, never()).run()
}
}
}

View File

@ -16,25 +16,20 @@
package vinyldns.mysql.repository package vinyldns.mysql.repository
import java.time.Instant
import cats.effect.IO import cats.effect.IO
import org.joda.time.DateTime
import org.scalatest._ import org.scalatest._
import scalikejdbc.DB import scalikejdbc.{DB, _}
import vinyldns.mysql.TestMySqlInstance import vinyldns.mysql.TestMySqlInstance
import scala.concurrent.duration._ import scala.concurrent.duration._
import scalikejdbc._
class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll with BeforeAndAfterEach with Matchers { class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll with BeforeAndAfterEach with Matchers {
private val repo = TestMySqlInstance.taskRepository.asInstanceOf[MySqlTaskRepository] private val repo = TestMySqlInstance.taskRepository.asInstanceOf[MySqlTaskRepository]
private val TASK_NAME = "task_name" private val TASK_NAME = "task_name"
private val INSERT_STATEMENT =
sql"""
|INSERT INTO task (name, in_flight, created, updated)
| VALUES ({task_name}, {in_flight}, {created}, {updated})
""".stripMargin
private val startDateTime = DateTime.now case class TaskInfo(inFlight: Boolean, updated: Option[Instant])
override protected def beforeEach(): Unit = clear().unsafeRunSync() override protected def beforeEach(): Unit = clear().unsafeRunSync()
@ -46,32 +41,29 @@ class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll
} }
} }
def insertTask(inFlight: Int, created: DateTime, updated: DateTime): IO[Unit] = IO { def ageTaskBySeconds(seconds: Long): IO[Int] = IO {
DB.localTx { implicit s => DB.localTx { implicit s =>
INSERT_STATEMENT sql"UPDATE task SET updated = DATE_SUB(NOW(),INTERVAL {ageSeconds} SECOND)"
.bindByName('task_name -> TASK_NAME, .bindByName('ageSeconds -> seconds)
'in_flight -> inFlight,
'created -> created,
'updated -> updated
)
.update() .update()
.apply() .apply()
} }
} }
def getTaskInfo: IO[Option[(Boolean, DateTime)]] = IO { def getTaskInfo(name: String): IO[TaskInfo] = IO {
DB.readOnly { implicit s => DB.readOnly { implicit s =>
sql"SELECT in_flight, updated from task FOR UPDATE" sql"SELECT in_flight, updated from task WHERE name = {taskName}"
.map(rs => (rs.boolean(1), new DateTime(rs.timestamp(2)))) .bindByName('taskName -> name)
.map(rs => TaskInfo(rs.boolean(1), rs.timestampOpt(2).map(_.toInstant)))
.first() .first()
.apply() .apply().getOrElse(throw new RuntimeException(s"TASK $name NOT FOUND"))
} }
} }
"claimTask" should { "claimTask" should {
"return true if non-in-flight task exists and updated time is null" in { "return true if non-in-flight task exists task is new" in {
val f = for { val f = for {
_ <- insertTask(0, startDateTime, startDateTime) _ <- repo.saveTask(TASK_NAME)
unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour) unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour)
} yield unclaimedTaskExists } yield unclaimedTaskExists
@ -79,15 +71,18 @@ class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll
} }
"return true if non-in-flight task exists and expiration time has elapsed" in { "return true if non-in-flight task exists and expiration time has elapsed" in {
val f = for { val f = for {
_ <- insertTask(0, startDateTime, startDateTime.minusHours(2)) _ <- repo.saveTask(TASK_NAME)
unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour) _ <- ageTaskBySeconds(100) // Age the task by 100 seconds
unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.second)
} yield unclaimedTaskExists } yield unclaimedTaskExists
f.unsafeRunSync() shouldBe true f.unsafeRunSync() shouldBe true
} }
"return false if in-flight task exists and expiration time has not elapsed" in { "return false if in-flight task exists and expiration time has not elapsed" in {
val f = for { val f = for {
_ <- insertTask(1, startDateTime, startDateTime) _ <- repo.saveTask(TASK_NAME)
_ <- repo.claimTask(TASK_NAME, 1.hour)
_ <- ageTaskBySeconds(5) // Age the task by only 5 seconds
unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour) unclaimedTaskExists <- repo.claimTask(TASK_NAME, 1.hour)
} yield unclaimedTaskExists } yield unclaimedTaskExists
@ -105,16 +100,55 @@ class MySqlTaskRepositoryIntegrationSpec extends WordSpec with BeforeAndAfterAll
"release task" should { "release task" should {
"unset in-flight flag for task and update time" in { "unset in-flight flag for task and update time" in {
val f = for { val f = for {
_ <- insertTask(1, startDateTime, startDateTime) _ <- repo.saveTask(TASK_NAME)
_ <- repo.claimTask(TASK_NAME, 1.hour)
_ <- ageTaskBySeconds(2)
oldTaskInfo <- getTaskInfo(TASK_NAME)
_ <- repo.releaseTask(TASK_NAME) _ <- repo.releaseTask(TASK_NAME)
taskInfo <- getTaskInfo newTaskInfo <- getTaskInfo(TASK_NAME)
} yield taskInfo } yield (oldTaskInfo, newTaskInfo)
f.unsafeRunSync().foreach { tuple => val (oldTaskInfo, newTaskInfo) = f.unsafeRunSync()
val (inFlight, updateTime) = tuple
inFlight shouldBe false // make sure the in_flight is unset
updateTime should not be startDateTime newTaskInfo.inFlight shouldBe false
// make sure that the updated time is later than the claimed time
oldTaskInfo.updated shouldBe defined
newTaskInfo.updated shouldBe defined
oldTaskInfo.updated.zip(newTaskInfo.updated).foreach {
case (claimTime, releaseTime) =>
releaseTime should be > claimTime
} }
} }
} }
"save task" should {
"insert a new task" in {
val f = for {
_ <- repo.saveTask(TASK_NAME)
taskInfo <- getTaskInfo(TASK_NAME)
} yield taskInfo
val taskInfo = f.unsafeRunSync()
taskInfo.inFlight shouldBe false
taskInfo.updated shouldBe empty
}
"not replace a task that is already present" in {
// schedule a task and claim it, then try to reschedule it while it is claimed (bad)
// the result should be that the task is still claimed / in_flight
val f = for {
_ <- repo.saveTask("repeat")
_ <- repo.claimTask("repeat", 5.seconds)
firstTaskInfo <- getTaskInfo("repeat")
_ <- repo.saveTask("repeat")
secondTaskInfo <- getTaskInfo("repeat")
_ <- repo.releaseTask("repeat")
} yield (firstTaskInfo, secondTaskInfo)
val (first, second) = f.unsafeRunSync()
first shouldBe second
}
}
} }

View File

@ -17,7 +17,6 @@
package vinyldns.mysql.repository package vinyldns.mysql.repository
import cats.effect.IO import cats.effect.IO
import org.joda.time.DateTime
import scalikejdbc._ import scalikejdbc._
import vinyldns.core.task.TaskRepository import vinyldns.core.task.TaskRepository
@ -32,34 +31,40 @@ class MySqlTaskRepository extends TaskRepository {
* *
* `updated IS NULL` case is for the first run where the seeded data does not have an updated time set * `updated IS NULL` case is for the first run where the seeded data does not have an updated time set
*/ */
private val CLAIM_UNCLAIMED_TASK = private val CLAIM_UNCLAIMED_TASK =
sql""" sql"""
|UPDATE task |UPDATE task
| SET in_flight = 1, updated = {currentTime} | SET in_flight = 1, updated = NOW()
| WHERE (in_flight = 0 | WHERE (in_flight = 0
| OR updated IS NULL | OR updated IS NULL
| OR updated < {updatedTimeComparison}) | OR updated < DATE_SUB(NOW(),INTERVAL {timeoutSeconds} SECOND))
| AND name = {taskName}; | AND name = {taskName};
""".stripMargin """.stripMargin
private val UNCLAIM_TASK = private val UNCLAIM_TASK =
sql""" sql"""
|UPDATE task |UPDATE task
| SET in_flight = 0, updated = {currentTime} | SET in_flight = 0, updated = NOW()
| WHERE name = {name} | WHERE name = {taskName}
""".stripMargin """.stripMargin
def claimTask(name: String, pollingInterval: FiniteDuration): IO[Boolean] = // In case multiple nodes attempt to insert task at the same time, do not overwrite
private val PUT_TASK =
sql"""
|INSERT IGNORE INTO task(name, in_flight, created, updated)
|VALUES ({taskName}, 0, NOW(), NULL)
""".stripMargin
/**
* Note - the column in MySQL is datetime with no fractions, so the best we can do is seconds
* If taskTimeout is less than one second, this will never claim as
* FiniteDuration.toSeconds results in ZERO OL for something like 500.millis
*/
def claimTask(name: String, taskTimeout: FiniteDuration): IO[Boolean] =
IO { IO {
val pollingExpirationHours = pollingInterval.toHours * 2
val currentTime = DateTime.now
DB.localTx { implicit s => DB.localTx { implicit s =>
val updateResult = CLAIM_UNCLAIMED_TASK val updateResult = CLAIM_UNCLAIMED_TASK
.bindByName( .bindByName('timeoutSeconds -> taskTimeout.toSeconds, 'taskName -> name)
'updatedTimeComparison -> currentTime.minusHours(pollingExpirationHours.toInt),
'taskName -> name,
'currentTime -> currentTime)
.first() .first()
.update() .update()
.apply() .apply()
@ -70,7 +75,15 @@ class MySqlTaskRepository extends TaskRepository {
def releaseTask(name: String): IO[Unit] = IO { def releaseTask(name: String): IO[Unit] = IO {
DB.localTx { implicit s => DB.localTx { implicit s =>
UNCLAIM_TASK.bindByName('currentTime -> DateTime.now, 'name -> name).update().apply() UNCLAIM_TASK.bindByName('taskName -> name).update().apply()
}
}
// Save the task, do not overwrite if it is already there
def saveTask(name: String): IO[Unit] = IO {
DB.localTx { implicit s =>
PUT_TASK.bindByName('taskName -> name).update().apply()
()
} }
} }
} }

View File

@ -67,7 +67,8 @@ object Dependencies {
"com.sun.xml.bind" % "jaxb-core" % jaxbV, "com.sun.xml.bind" % "jaxb-core" % jaxbV,
"com.sun.xml.bind" % "jaxb-impl" % jaxbV, "com.sun.xml.bind" % "jaxb-impl" % jaxbV,
"ch.qos.logback" % "logback-classic" % "1.0.7", "ch.qos.logback" % "logback-classic" % "1.0.7",
"io.dropwizard.metrics" % "metrics-jvm" % "3.2.2" "io.dropwizard.metrics" % "metrics-jvm" % "3.2.2",
"co.fs2" %% "fs2-core" % "1.0.0"
) )
lazy val dynamoDBDependencies = Seq( lazy val dynamoDBDependencies = Seq(