package sbt

/** Restrictions on which tasks are permitted to run at the same time.
*
* An implementation summarizes a set of tasks as an opaque state value of type `G`,
* built up from `empty` with `add` and torn down with `remove`, and decides with `valid`
* whether the summarized set of tasks may execute concurrently.
*
* @tparam A the type of a task
*/
trait ConcurrentRestrictions[A] {
	/** Opaque state summarizing a set of tasks. */
	type G

	/** The state that summarizes no tasks at all. */
	def empty: G

	/** Returns the state `g` extended with the task `a`. */
	def add(g: G, a: A): G

	/** Returns the state `g` with the previously added task `a` taken back out. */
	def remove(g: G, a: A): G

	/** Returns true when the tasks summarized by `g` may execute concurrently.
	*
	* Together with `empty`, `add` and `remove`, this method must satisfy these laws:
	*
	* 1. forall g: G, a: A; valid(g) => valid(remove(g,a))
	* 2. forall a: A; valid(add(empty, a))
	* 3. forall g: G, a: A; valid(g) <=> valid(remove(add(g, a), a))
	* 4. (implied by 1,2,3) valid(empty)
	* 5. forall g: G, a: A, b: A; !valid(add(g,a)) => !valid(add(add(g,b), a))
	*/
	def valid(g: G): Boolean
}

	import java.util.{LinkedList,Queue}
	import java.util.concurrent.{Executor, Executors, ExecutorCompletionService}
	import annotation.tailrec

object ConcurrentRestrictions
{
	/** A ConcurrentRestrictions instance that places no restrictions on concurrently executing tasks.
	* Its state carries no information (`Unit`) and every state is valid.
	* @tparam A the type of a task */
	def unrestricted[A]: ConcurrentRestrictions[A] =
		new ConcurrentRestrictions[A]
		{
			type G = Unit
			def empty: G = ()
			def add(g: G, a: A): G = ()
			def remove(g: G, a: A): G = ()
			def valid(g: G): Boolean = true
		}

	/** A ConcurrentRestrictions instance that allows at most `i` tasks to execute concurrently.
	* The state is the number of tasks currently described.
	* @tparam A the type of a task
	* @param i the maximum number of concurrently executing tasks; must be at least 1 */
	def limitTotal[A](i: Int): ConcurrentRestrictions[A] =
	{
		// NOTE(review): `assert` may be elided by the compiler; it is kept (rather than `require`)
		//   so callers continue to see the same exception type on an invalid argument.
		assert(i >= 1, "Maximum must be at least 1 (was " + i + ")")
		new ConcurrentRestrictions[A]
		{
			type G = Int
			def empty: Int = 0
			def add(g: Int, a: A): Int = g + 1
			def remove(g: Int, a: A): Int = g - 1
			def valid(g: Int): Boolean = g <= i
		}
	}

	/** A key object used for associating information with a task.*/
	final case class Tag(name: String)

	/** The attribute key holding a task's tags, which restrict its concurrent execution. */
	val tagsKey = AttributeKey[TagMap]("tags", "Attributes restricting concurrent execution of tasks.")

	/** A standard tag describing the number of tasks that do not otherwise have any tags.*/
	val Untagged = Tag("untagged")

	/** A standard tag describing the total number of tasks. */
	val All = Tag("all")

	/** Associates tags with integer amounts; the state of a set of tasks sums the amounts of its members. */
	type TagMap = Map[Tag, Int]

	/** Implements concurrency restrictions on tasks based on Tags.
	* Every task additionally counts 1 towards [[All]], and a task without tags of its own counts 1 towards [[Untagged]].
	* @tparam A type of a task
	* @param get extracts tags from a task
	* @param validF defines whether a set of tasks are allowed to execute concurrently based on their merged tags*/
	def tagged[A](get: A => TagMap, validF: TagMap => Boolean): ConcurrentRestrictions[A] =
		new ConcurrentRestrictions[A]
		{
			type G = TagMap
			def empty: G = Map.empty
			def add(g: TagMap, a: A): TagMap = merge(g, a, get)(_ + _)
			def remove(g: TagMap, a: A): TagMap = merge(g, a, get)(_ - _)
			def valid(g: TagMap): Boolean = validF(g)
		}

	/** Combines the tags of task `a` into the tag amounts `m` using `f` (`+` when adding a task, `-` when removing it).
	* The task also contributes 1 to [[All]], and 1 to [[Untagged]] when it has no tags of its own. */
	private[this] def merge[A](m: TagMap, a: A, get: A => TagMap)(f: (Int,Int) => Int): TagMap =
	{
		val aTags = get(a)
		val base = mergeMaps(m, aTags)(f)
		// Decide "untagged" from the task's own tags, not from the merged state: the state is
		//   non-empty as soon as any task has been added (it holds `All`), so testing it would
		//   count only the first untagged task and never decrement Untagged on removal.
		val un = if(aTags.isEmpty) update(base, Untagged, 1)(f) else base
		update(un, All, 1)(f)
	}

	/** Returns `m` with the value for key `a` set to `f(old, b)` when `a` is present, or to `b` otherwise. */
	private[this] def update[A,B](m: Map[A,B], a: A, b: B)(f: (B,B) => B): Map[A,B] =
		m.updated(a, m.get(a).map(bv => f(bv, b)).getOrElse(b))

	/** Folds every entry of `n` into `m`, combining the values of keys present in both with `f`. */
	private[this] def mergeMaps[A,B](m: Map[A,B], n: Map[A,B])(f: (B,B) => B): Map[A,B] =
		n.foldLeft(m) { case (acc, (a, b)) => update(acc, a, b)(f) }

	/** Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution.
	* Work is executed on a newly created cached thread pool.
	* @param tags the restrictions on concurrently executing tasks
	* @param warn receives a message whenever `tags` is observed to violate the ConcurrentRestrictions laws
	* @return a pair, with _1 being the CompletionService and _2 a function that shuts down the service by calling `shutdownNow` on its thread pool.
	* @tparam A the task type
	* @tparam R the type of data that will be computed by the CompletionService. */
	def completionService[A,R](tags: ConcurrentRestrictions[A], warn: String => Unit): (CompletionService[A,R], () => Unit) =
	{
		val pool = Executors.newCachedThreadPool()
		(completionService[A,R](pool, tags, warn), () => pool.shutdownNow() )
	}

	/** Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution
	* and using the provided Executor to manage execution on threads.
	* @param backing the Executor that runs work once the restrictions allow it
	* @param tags the restrictions on concurrently executing tasks
	* @param warn receives a message whenever `tags` is observed to violate the ConcurrentRestrictions laws
	* @tparam A the task type
	* @tparam R the type of data that will be computed by the CompletionService. */
	def completionService[A,R](backing: Executor, tags: ConcurrentRestrictions[A], warn: String => Unit): CompletionService[A,R] =
	{
		/** Represents submitted work for a task.*/
		final class Enqueue(val node: A, val work: () => R)

		new CompletionService[A,R]
		{
			/** Backing service used to manage execution on threads once all constraints are satisfied. */
			private[this] val jservice = new ExecutorCompletionService[R](backing)
			/** The description of the currently running tasks, used by `tags` to manage restrictions.*/
			private[this] var tagState = tags.empty
			/** The number of running tasks. */
			private[this] var running = 0
			/** Tasks that cannot be run yet because they cannot execute concurrently with the currently running tasks.*/
			private[this] val pending = new LinkedList[Enqueue]

			/** Starts `work` for `node` immediately if the restrictions allow it alongside the running tasks;
			* otherwise queues it until a running task completes. */
			def submit(node: A,  work: () => R): Unit = synchronized
			{
				val newState = tags.add(tagState, node)
					// if the new task is allowed to run concurrently with the currently running tasks,
					//   submit it to be run by the backing j.u.c.CompletionService
				if(tags.valid(newState))
				{
					tagState = newState
					submitValid( node, work )
				}
				else
				{
					// NOTE(review): when nothing is running, no completion will ever call `submitPending`,
					//   so this work may never start; only a warning is emitted.
					if(running == 0) errorAddingToIdle()
					pending.add( new Enqueue(node, work) )
				}
			}

			/** Hands already-admitted work to the backing service, arranging for `cleanup` to run when it finishes. */
			private[this] def submitValid(node: A, work: () => R) =
			{
				running += 1
				// cleanup runs whether the work returns normally or throws
				val wrappedWork = () => try work() finally cleanup(node)
				CompletionService.submit(wrappedWork, jservice)
			}

			/** Removes the finished `node` from the state and starts any pending work that is now allowed. */
			private[this] def cleanup(node: A): Unit = synchronized
			{
				running -= 1
				tagState = tags.remove(tagState, node)
				if(!tags.valid(tagState)) warn("Invalid restriction: removing a completed node from a valid system must result in a valid system.")
				submitPending(new LinkedList[Enqueue])
			}

			private[this] def errorAddingToIdle() = warn("Invalid restriction: adding a node to an idle system must be allowed.")

			/** Submits pending tasks that are now allowed to execute.
			* Each pending task is tried once; those still not allowed are collected in `tried` and put back in `pending`. */
			@tailrec private[this] def submitPending(tried: Queue[Enqueue]): Unit =
				if(pending.isEmpty)
				{
					if(!tried.isEmpty)
					{
						if(running == 0) errorAddingToIdle()
						pending.addAll(tried)
					}
				}
				else
				{
					val next = pending.remove()
					val newState = tags.add(tagState, next.node)
					if(tags.valid(newState))
					{
						tagState = newState
						submitValid(next.node, next.work)
					}
					else
						tried.add(next)
					submitPending(tried)
				}

			/** Blocks until some submitted work completes and returns its result.
			* A failure in the work is rethrown by `Future.get`, wrapped in an ExecutionException. */
			def take(): R = jservice.take().get()
		}
	}
}