/* sbt -- Simple Build Tool
 * Copyright 2008 Mark Harrah
 *
 * Partially based on exit trapping in Nailgun by Pete Kirkham,
 * copyright 2004, Martian Software, Inc
 * licensed under Apache 2.0 License.
 */
package sbt

import scala.collection.Set
import scala.reflect.Manifest

/** Provides functionality to catch System.exit calls and prevent the JVM from terminating.
* This is useful for executing user code that may call System.exit when actually exiting is
* undesirable.  A call to exit is handled by disposing all top-level windows and interrupting
* all user-started threads.  Threads are not stopped and shutdown hooks are not run.  It is
* therefore inappropriate to use this with code that requires shutdown hooks or creates threads that
* do not terminate.  Such code should only be run in a forked JVM. */
object TrapExit
{
	/** Executes the given thunk on a new thread in a context where System.exit(code) throws
	* a custom SecurityException, which is then caught and the exit code returned.
	* If exit is never called, 0 is returned, or 1 if the thunk itself ended with an exception that
	* was not caused by System.exit.  If the calling thread is interrupted while waiting for the run
	* to finish, the run is canceled and 1 is returned.  No other exceptions are handled by this method.
	*
	* @param execute the code to run in the sandbox
	* @param log receives progress messages and errors from user threads
	* @return the exit code of the run */
	def apply(execute: => Unit, log: Logger): Int =
	{
		log.debug("Starting sandboxed run...")

		// Take a snapshot of the threads that existed before execution in order to determine
		// the threads that were created by 'execute'.
		val originalThreads = allThreads
		val code = new ExitCode
		def executeMain =
			try { execute }
			catch
			{
				case x =>
					// An exception caused by System.exit carries its own exit code, which ExitHandler records.
					// It may arrive wrapped (for example in an InvocationTargetException if 'execute' invokes
					// a main method reflectively), so the whole cause chain is checked.  Any other exception
					// in the main thread makes the exit code 1.
					if(!causedByExit(x))
						code.set(1)
					throw x
			}
		val customThreadGroup = new ExitThreadGroup(new ExitHandler(originalThreads, code, log))
		val executionThread = new Thread(customThreadGroup, "run-main") { override def run() { executeMain } }

		val originalSecurityManager = System.getSecurityManager
		try
		{
			val newSecurityManager = new TrapExitSecurityManager(originalSecurityManager, customThreadGroup)
			System.setSecurityManager(newSecurityManager)

			executionThread.start()

			log.debug("Waiting for threads to exit or System.exit to be called.")
			waitForExit(originalThreads, log)
			log.debug("Interrupting remaining threads (should be all daemons).")
			interruptAll(originalThreads) // should only be daemon threads left now
			log.debug("Sandboxed run complete..")
			code.value.getOrElse(0)
		}
		// Compare against the pre-run snapshot so that the threads started by the run are the ones interrupted.
		catch { case e: InterruptedException => cancel(executionThread, originalThreads, log) }
		finally System.setSecurityManager(originalSecurityManager)
	}
	/** Cancels a run whose waiting thread was interrupted.  It interrupts the execution thread, disposes
	* all frames, and interrupts every thread that is not in `originalThreads`.
	* @return 1, the exit code reported for a canceled run */
	private[this] def cancel(executionThread: Thread, originalThreads: Set[Thread], log: Logger): Int =
	{
		log.warn("Run canceled.")
		executionThread.interrupt()
		stopAll(originalThreads)
		1
	}

	/** Blocks until every non-daemon thread that is not in `originalThreads` has terminated.
	* The scan is repeated whenever a thread had to be waited on, because threads created in the
	* meantime (presumably by the joined threads) would otherwise be missed. */
	private def waitForExit(originalThreads: Set[Thread], log: Logger)
	{
		var daemonsOnly = true
		processThreads(originalThreads, thread =>
			if(!thread.isDaemon)
			{
				daemonsOnly = false
				waitOnThread(thread, log)
			}
		)
		if(!daemonsOnly)
			waitForExit(originalThreads, log)
	}
	/** Waits for the given thread to exit, logging before and after. */
	private def waitOnThread(thread: Thread, log: Logger)
	{
		log.debug("Waiting for thread " + thread.getName + " to exit")
		thread.join()
		log.debug("\tThread " + thread.getName + " exited.")
	}
	/** Returns the exit code of the System.exit call that caused `e`, found anywhere in its cause chain.
	* If `e` was not caused by System.exit, the root cause of `e` is thrown instead. */
	private def exitCode(e: Throwable) =
		withCause[TrapExitSecurityException, Int](e)
			{exited => exited.exitCode}
			{other => throw other}
	/** True if `e` or any exception in its cause chain is the exception thrown by a trapped System.exit. */
	private def causedByExit(e: Throwable): Boolean =
		withCause[TrapExitSecurityException, Boolean](e)(_ => true)(_ => false)
	/** Recurses into the causes of the given exception looking for a cause of type CauseType.
	* If one is found, `withType` is called with that cause.  If not, `notType` is called with the root
	* cause.  A TrapExitSecurityException is treated as a root cause, because its getCause rethrows the
	* exception itself instead of returning. */
	private def withCause[CauseType <: Throwable, T](e: Throwable)(withType: CauseType => T)(notType: Throwable => T)(implicit mf: Manifest[CauseType]): T =
	{
		val clazz = mf.erasure
		if(clazz.isInstance(e))
			withType(e.asInstanceOf[CauseType])
		else
		{
			val cause: Throwable = e match
			{
				case _: TrapExitSecurityException => null // do not call its throwing getCause
				case _ => e.getCause
			}
			if(cause == null)
				notType(e)
			else
				withCause(cause)(withType)(notType)(mf)
		}
	}

	/** Returns the threads reported by Thread.getAllStackTraces, excluding those that isSystemThread
	* identifies (threads in the 'system' group and AWT implementation threads such as AWT-XAWT and
	* AWT-Windows).  AWT-EventQueue-* and AWT-Shutdown threads are included. */
	private def allThreads: Set[Thread] =
	{
		import collection.JavaConversions._
		Thread.getAllStackTraces.keySet.filter(thread => !isSystemThread(thread))
	}
	/** Returns true for an AWT thread (name starting with "AWT-") other than AWT-EventQueue* or
	* AWT-Shutdown*, or for a non-AWT thread whose thread group is named "system". */
	private def isSystemThread(t: Thread) =
	{
		val name = t.getName
		if(name.startsWith("AWT-"))
			!(name.startsWith("AWT-EventQueue") || name.startsWith("AWT-Shutdown"))
		else
		{
			val group = t.getThreadGroup
			(group != null) && (group.getName == "system")
		}
	}
	/** Calls the provided function for each thread returned by allThreads, except those in ignoreThreads. */
	private def processThreads(ignoreThreads: Set[Thread], process: Thread => Unit)
	{
		allThreads.filter(thread => !ignoreThreads.contains(thread)).foreach(process)
	}
	/** Handles System.exit by disposing all frames and calling interrupt on all user threads
	* (those not in `originalThreads`). */
	private def stopAll(originalThreads: Set[Thread])
	{
		disposeAllFrames()
		interruptAll(originalThreads)
	}
	/** Disposes all top-level frames and, if there were any, sleeps 2 seconds to let the AWT threads wind down. */
	private def disposeAllFrames()
	{
		val allFrames = java.awt.Frame.getFrames
		if(allFrames.length > 0)
		{
			allFrames.foreach(_.dispose) // dispose all top-level windows, which will cause the AWT-EventQueue-* threads to exit
			Thread.sleep(2000) // AWT Thread doesn't exit immediately, so wait to interrupt it
		}
	}
	/** Interrupts all threads that appear to have been started by the user (not in `originalThreads`). */
	private def interruptAll(originalThreads: Set[Thread]): Unit =
		processThreads(originalThreads, safeInterrupt)
	/** Interrupts the given non-AWT thread.  It first wraps the thread's uncaught exception handler so
	* that a resulting InterruptedException is not printed. */
	private def safeInterrupt(thread: Thread)
	{
		if(!thread.getName.startsWith("AWT-"))
		{
			thread.setUncaughtExceptionHandler(new TrapInterrupt(thread.getUncaughtExceptionHandler))
			thread.interrupt()
		}
	}
	/** An uncaught exception handler that swallows exceptions caused by an InterruptedException and
	* otherwise defers to originalHandler.  It then reinstalls originalHandler on the thread. */
	private final class TrapInterrupt(originalHandler: Thread.UncaughtExceptionHandler) extends Thread.UncaughtExceptionHandler
	{
		def uncaughtException(thread: Thread, e: Throwable)
		{
			withCause[InterruptedException, Unit](e)
				{interrupted => ()}
				{other => originalHandler.uncaughtException(thread, e) }
			thread.setUncaughtExceptionHandler(originalHandler)
		}
	}
	/** An uncaught exception handler for threads in the trap group.  If the exception was caused by a
	* call to System.exit, it records the exit code and stops all user threads.  Otherwise it logs the
	* exception as an error. */
	private final class ExitHandler(originalThreads: Set[Thread], codeHolder: ExitCode, log: Logger) extends Thread.UncaughtExceptionHandler
	{
		def uncaughtException(t: Thread, e: Throwable)
		{
			try
			{
				codeHolder.set(exitCode(e)) // throws if e was not caused by a call to System.exit
				stopAll(originalThreads)
			}
			catch
			{
				case _ =>
					log.error("(" + t.getName + ") " + e.toString)
					log.trace(e)
			}
		}
	}
	/** A thread group named "trap.exit" that routes uncaught exceptions of its threads to `handler`. */
	private final class ExitThreadGroup(handler: Thread.UncaughtExceptionHandler) extends ThreadGroup("trap.exit")
	{
		override def uncaughtException(t: Thread, e: Throwable) = handler.uncaughtException(t, e)
	}
}
/** Thread-safe holder for the exit code of a sandboxed run.
* Only the first code passed to `set` is kept; later calls are ignored. */
private final class ExitCode extends NotNull
{
	// None until the first call to set.
	private var first: Option[Int] = None
	/** Records `c` if no code has been recorded yet. */
	def set(c: Int): Unit = synchronized {
		if(first.isEmpty)
			first = Some(c)
	}
	/** The recorded code, or None if `set` has never been called. */
	def value: Option[Int] = synchronized { first }
}
///////  These two classes are based on similar classes in Nailgun
/** A SecurityManager that turns calls to System.exit into a TrapExitSecurityException.
* All other permission checks are forwarded to `delegateManager`, which may be null; when it is
* null, everything is permitted.  `getThreadGroup` reports `group` (presumably so that threads
* created while this manager is installed default to the trapping group). */
private final class TrapExitSecurityManager(delegateManager: SecurityManager, group: ThreadGroup) extends SecurityManager
{
	import java.security.Permission

	/** Throws a TrapExitSecurityException carrying `status` when this check is made from within an actual
	* java.lang.Runtime.exit call, or when no stack trace is available.  Checks made merely to find out
	* whether exit would be allowed pass silently. */
	override def checkExit(status: Int)
	{
		val frames = Thread.currentThread.getStackTrace
		val exiting = (frames == null) || frames.exists(frame => isRealExit(frame))
		if(exiting)
			throw new TrapExitSecurityException(status)
	}
	/** True for the stack frame of java.lang.Runtime.exit. */
	private def isRealExit(element: StackTraceElement): Boolean =
		element.getMethodName == "exit" && element.getClassName == "java.lang.Runtime"

	// The permission checks deliberately avoid closures and other indirection: they are plain null-guarded forwards.
	/** Forwards to the delegate manager; permits everything when there is none. */
	override def checkPermission(perm: Permission): Unit =
		if(delegateManager ne null) delegateManager.checkPermission(perm)
	/** Forwards to the delegate manager; permits everything when there is none. */
	override def checkPermission(perm: Permission, context: AnyRef): Unit =
		if(delegateManager ne null) delegateManager.checkPermission(perm, context)

	override def getThreadGroup: ThreadGroup = group
}
/** A custom SecurityException that tries not to be caught.
* An instance is thrown by TrapExitSecurityManager.checkExit in place of actually exiting.  Until
* `allowAccess` is called, each overridden Throwable method below rethrows this exception instead of
* returning.  User code that catches it and then inspects it (for example by logging it) therefore
* ends up propagating it again.
*
* @param exitCode the status that was passed to System.exit */
private final class TrapExitSecurityException(val exitCode: Int) extends SecurityException
{
	// While false, the overridden methods below throw this exception instead of returning.
	private var accessAllowed = false
	/** Makes the overridden Throwable methods behave normally instead of rethrowing this exception. */
	def allowAccess
	{
		accessAllowed = true
	}
	// NOTE(review): java.lang.Throwable's constructor calls fillInStackTrace().  accessAllowed is false at that
	// point, so construction itself appears to reach `throw this`, and the new instance is thrown directly
	// from `new`.  Confirm this is relied upon before changing the fillInStackTrace override or field setup.
	override def printStackTrace = ifAccessAllowed(super.printStackTrace)
	override def toString = ifAccessAllowed(super.toString)
	override def getCause = ifAccessAllowed(super.getCause)
	override def getMessage = ifAccessAllowed(super.getMessage)
	override def fillInStackTrace = ifAccessAllowed(super.fillInStackTrace)
	override def getLocalizedMessage = ifAccessAllowed(super.getLocalizedMessage)
	/** Evaluates `f` if access has been allowed; otherwise throws this exception. */
	private def ifAccessAllowed[T](f: => T): T =
	{
		if(accessAllowed)
			f
		else
			throw this
	}
}