package sbt
import scala.collection.Set
import scala.reflect.Manifest
object TrapExit
{
/** Runs `execute` on a dedicated "run-main" thread while a `TrapExitSecurityManager` is installed,
* waits for the non-daemon threads started by the run to finish, and returns an exit code.
*
* @param execute the code to run inside the sandbox
* @param log receives debug output about the run and a warning if the run is canceled
* @return the first exit code recorded during the run, `0` if none was recorded, or `1` if the
*         calling thread was interrupted while waiting (the run is then canceled)
*
* The security manager that was installed before the call is always restored on exit. */
def apply(execute: => Unit, log: Logger): Int =
{
	log.debug("Starting sandboxed run...")
	// Snapshot of the threads alive before the run; they are never waited on or interrupted.
	val originalThreads = allThreads
	val code = new ExitCode
	// Records exit code 1 for any failure other than a trapped exit, then lets the
	// exception propagate to the thread group's uncaught-exception handling.
	def executeMain =
		try { execute }
		catch
		{
			case e: TrapExitSecurityException => throw e
			case x =>
				code.set(1)
				throw x
		}
	val customThreadGroup = new ExitThreadGroup(new ExitHandler(originalThreads, code, log))
	val executionThread = new Thread(customThreadGroup, "run-main") { override def run() { executeMain } }
	val originalSecurityManager = System.getSecurityManager
	try
	{
		val newSecurityManager = new TrapExitSecurityManager(originalSecurityManager, customThreadGroup)
		System.setSecurityManager(newSecurityManager)
		executionThread.start()
		log.debug("Waiting for threads to exit or System.exit to be called.")
		waitForExit(originalThreads, log)
		log.debug("Interrupting remaining threads (should be all daemons).")
		interruptAll(originalThreads)
		log.debug("Sandboxed run complete..")
		code.value.getOrElse(0)
	}
	// Pass the pre-run snapshot, not a fresh `allThreads`: the set is used as the list of
	// threads to leave alone, so a fresh snapshot would exclude every thread the run
	// started and nothing would be interrupted on cancel.
	catch { case e: InterruptedException => cancel(executionThread, originalThreads, log) }
	finally System.setSecurityManager(originalSecurityManager)
}
/** Cancels a run: logs a warning, interrupts `executionThread`, then calls `stopAll` with
* `originalThreads`.
*
* @return always `1` */
private[this] def cancel(executionThread: Thread, originalThreads: Set[Thread], log: Logger): Int =
{
	log.warn("Run canceled.")
	executionThread.interrupt()
	stopAll(originalThreads)
	val canceledCode = 1
	canceledCode
}
/** Blocks until a pass over the threads not in `originalThreads` finds no non-daemon thread.
* Each non-daemon thread found in a pass is joined via `waitOnThread`; another pass follows
* because joined threads may have started further threads in the meantime. */
private def waitForExit(originalThreads: Set[Thread], log: Logger)
{
	var sawNonDaemon = true
	while(sawNonDaemon)
	{
		sawNonDaemon = false
		processThreads(originalThreads, candidate =>
			if(!candidate.isDaemon)
			{
				sawNonDaemon = true
				waitOnThread(candidate, log)
			}
		)
	}
}
/** Joins `thread`, writing a debug message before the wait starts and after it ends. */
private def waitOnThread(thread: Thread, log: Logger): Unit =
{
	val before = "Waiting for thread " + thread.getName + " to exit"
	log.debug(before)
	thread.join()
	// The name is read again after the join, exactly as it is read before it.
	val after = "\tThread " + thread.getName + " exited."
	log.debug(after)
}
/** Returns the exit code carried by the first `TrapExitSecurityException` found in the cause
* chain of `e`. If there is none, the throwable that `withCause` hands to its fallback is rethrown. */
private def exitCode(e: Throwable) =
	withCause[TrapExitSecurityException, Int](e)(trapped => trapped.exitCode)(unrelated => throw unrelated)
/** Walks the cause chain starting at `e` looking for the first throwable that is an instance
* of `CauseType`.
*
* @param e the throwable at which the search starts
* @param withType applied to the first throwable in the chain that is a `CauseType`
* @param notType applied to the last throwable in the chain (the one whose cause is `null`)
*        when no `CauseType` was found
* @param mf supplies the runtime class used for the instance check
* @return the result of whichever function was applied */
private def withCause[CauseType <: Throwable, T](e: Throwable)(withType: CauseType => T)(notType: Throwable => T)(implicit mf: Manifest[CauseType]): T =
{
	val wanted = mf.erasure
	// The instance check always precedes getCause, so getCause is never called on a match.
	def search(current: Throwable): T =
		if(wanted.isInstance(current))
			withType(current.asInstanceOf[CauseType])
		else
			current.getCause match
			{
				case null => notType(current)
				case deeper => search(deeper)
			}
	search(e)
}
/** Returns the threads currently reported by `Thread.getAllStackTraces`, leaving out those
* that `isSystemThread` classifies as system threads. */
private def allThreads: Set[Thread] =
{
	import collection.JavaConversions._
	val reported = Thread.getAllStackTraces.keySet
	reported.filter(candidate => !isSystemThread(candidate))
}
/** Decides whether `t` counts as a system thread.
*
* A thread whose name starts with "AWT-" is a system thread unless the name starts with
* "AWT-EventQueue" or "AWT-Shutdown". Any other thread is a system thread exactly when it
* has a thread group and that group is named "system". */
private def isSystemThread(t: Thread) =
{
	val threadName = t.getName
	def isExemptAwtThread = threadName.startsWith("AWT-EventQueue") || threadName.startsWith("AWT-Shutdown")
	def isInSystemGroup =
	{
		val owner = t.getThreadGroup
		(owner != null) && (owner.getName == "system")
	}
	if(threadName.startsWith("AWT-")) !isExemptAwtThread else isInSystemGroup
}
/** Applies `process` to every thread returned by `allThreads` that is not in `ignoreThreads`. */
private def processThreads(ignoreThreads: Set[Thread], process: Thread => Unit)
{
	val candidates = allThreads
	val selected = candidates.filter(candidate => !ignoreThreads.contains(candidate))
	selected.foreach(process)
}
/** Disposes all AWT frames, then interrupts every thread that is not in `originalThreads`. */
private def stopAll(originalThreads: Set[Thread]): Unit =
{
	disposeAllFrames()
	interruptAll(originalThreads)
}
/** Disposes every frame returned by `java.awt.Frame.getFrames`. When at least one frame was
* present, sleeps for two seconds afterwards; when there were none, returns immediately. */
private def disposeAllFrames(): Unit =
{
	val frames = java.awt.Frame.getFrames
	if(frames.length != 0)
	{
		for(frame <- frames)
			frame.dispose()
		Thread.sleep(2000)
	}
}
/** Calls `safeInterrupt` on every thread that is not in `originalThreads`. */
private def interruptAll(originalThreads: Set[Thread]): Unit =
	processThreads(originalThreads, thread => safeInterrupt(thread))
/** Interrupts `thread` unless its name starts with "AWT-". Before interrupting, the thread's
* current uncaught-exception handler is wrapped in a `TrapInterrupt`. */
private def safeInterrupt(thread: Thread): Unit =
{
	val isAwtThread = thread.getName.startsWith("AWT-")
	if(!isAwtThread)
	{
		val previousHandler = thread.getUncaughtExceptionHandler
		thread.setUncaughtExceptionHandler(new TrapInterrupt(previousHandler))
		thread.interrupt()
	}
}
/** Uncaught-exception handler that ignores a throwable whose cause chain contains an
* `InterruptedException` and forwards any other throwable to `originalHandler`.
* After handling, it reinstalls `originalHandler` on the thread. */
private final class TrapInterrupt(originalHandler: Thread.UncaughtExceptionHandler) extends Thread.UncaughtExceptionHandler
{
	def uncaughtException(thread: Thread, e: Throwable)
	{
		def ignore(interrupted: InterruptedException): Unit = ()
		// The original throwable `e` is forwarded, not the innermost cause passed in here.
		def forward(other: Throwable): Unit = originalHandler.uncaughtException(thread, e)
		withCause[InterruptedException, Unit](e)(ignore)(forward)
		thread.setUncaughtExceptionHandler(originalHandler)
	}
}
/** Uncaught-exception handler for the sandboxed run.
*
* When the throwable carries a trapped exit (see `exitCode`), the code is stored in
* `codeHolder` and `stopAll` is called with `originalThreads`. When `exitCode` or `stopAll`
* throws, the original throwable is logged as an error together with the thread name, and
* its trace is passed to the logger. */
private final class ExitHandler(originalThreads: Set[Thread], codeHolder: ExitCode, log: Logger) extends Thread.UncaughtExceptionHandler
{
	def uncaughtException(t: Thread, e: Throwable)
	{
		try
		{
			val requestedCode = exitCode(e)
			codeHolder.set(requestedCode)
			stopAll(originalThreads)
		}
		catch
		{
			// Whatever was thrown above is discarded; the original `e` is what gets reported.
			case _ =>
				val summary = "(" + t.getName + ") " + e.toString
				log.error(summary)
				log.trace(e)
		}
	}
}
/** Thread group named "trap.exit" that hands every uncaught exception to `handler`. */
private final class ExitThreadGroup(handler: Thread.UncaughtExceptionHandler) extends ThreadGroup("trap.exit")
{
	override def uncaughtException(t: Thread, e: Throwable): Unit =
	{
		handler.uncaughtException(t, e)
	}
}
}
/** Synchronized holder for an optional exit code. The first code set is kept; later calls
* to `set` leave it unchanged. */
private final class ExitCode extends NotNull
{
	private var code: Option[Int] = None
	/** Stores `c` if no code has been stored yet; otherwise does nothing. */
	def set(c: Int): Unit = synchronized
	{
		if(code.isEmpty)
			code = Some(c)
	}
	/** The stored code, or `None` if `set` has not been called. */
	def value: Option[Int] = synchronized { code }
}
/** Security manager that turns a real exit request into a `TrapExitSecurityException`.
*
* Permission checks are forwarded to `delegateManager` when it is non-null and are otherwise
* allowed. `getThreadGroup` always returns `group`.
*
* @param delegateManager the manager to forward permission checks to; may be null
* @param group the thread group reported by `getThreadGroup` */
private final class TrapExitSecurityManager(delegateManager: SecurityManager, group: ThreadGroup) extends SecurityManager
{
	import java.security.Permission
	/** Throws a `TrapExitSecurityException` carrying `status` when the current stack trace is
	* null or contains a frame for `java.lang.Runtime.exit`; otherwise returns normally. */
	override def checkExit(status: Int)
	{
		val frames = Thread.currentThread.getStackTrace
		val exitRequested = frames == null || frames.exists(isRealExit)
		if(exitRequested)
			throw new TrapExitSecurityException(status)
	}
	/** True when `element` is a frame of the method `exit` in class `java.lang.Runtime`. */
	private def isRealExit(element: StackTraceElement): Boolean =
	{
		val inRuntime = element.getClassName == "java.lang.Runtime"
		inRuntime && element.getMethodName == "exit"
	}
	override def checkPermission(perm: Permission)
	{
		if(delegateManager ne null)
			delegateManager.checkPermission(perm)
	}
	override def checkPermission(perm: Permission, context: AnyRef)
	{
		if(delegateManager ne null)
			delegateManager.checkPermission(perm, context)
	}
	override def getThreadGroup = group
}
/** Security exception that carries the exit status of a trapped exit request.
*
* Until `allowAccess` has been called, each of the overridden accessors below
* (`printStackTrace`, `toString`, `getCause`, `getMessage`, `fillInStackTrace`,
* `getLocalizedMessage`) throws this same exception instance instead of returning.
* No call to `allowAccess` is visible in this file.
*
* NOTE(review): `java.lang.Throwable`'s constructors invoke `fillInStackTrace()`, so with
* access still denied the instance presumably throws itself during construction -- confirm
* before changing the order or set of overrides.
*
* @param exitCode the status that was passed to the exit request */
private final class TrapExitSecurityException(val exitCode: Int) extends SecurityException
{
// Gate for the overridden accessors; false until allowAccess is called.
private var accessAllowed = false
/** Permits the overridden accessors to delegate to the superclass from now on. */
def allowAccess
{
accessAllowed = true
}
// Each accessor delegates to the superclass only when access has been allowed.
override def printStackTrace = ifAccessAllowed(super.printStackTrace)
override def toString = ifAccessAllowed(super.toString)
override def getCause = ifAccessAllowed(super.getCause)
override def getMessage = ifAccessAllowed(super.getMessage)
override def fillInStackTrace = ifAccessAllowed(super.fillInStackTrace)
override def getLocalizedMessage = ifAccessAllowed(super.getLocalizedMessage)
/** Evaluates `f` when access has been allowed; otherwise throws this exception. */
private def ifAccessAllowed[T](f: => T): T =
{
if(accessAllowed)
f
else
throw this
}
}