/* NSC -- new Scala compiler
 * Copyright 2005-2009 LAMP/EPFL
 * @author Burak Emir
 */
// $Id: PatternNodes.scala 18387 2009-07-24 15:28:37Z odersky $

package scala.tools.nsc
package matching

import scala.tools.nsc.util.{Position, NoPosition}

/**
 *  @author Burak Emir
 */
trait PatternNodes extends ast.TreeDSL 
{ 
  self: transform.ExplicitOuter =>
  
  import global.{ typer => _, _ }
  import analyzer.Typer
  import symtab.Flags
  import Types._
  import CODE._
  import definitions.{ ConsClass, ListClass, AnyRefClass, EqualsPatternClass, ListModule }
  
  type TypeComparison = (Type, Type) => Boolean
  
  // Tests on Types
  def isEquals(t: Type)       = cond(t) { case TypeRef(_, EqualsPatternClass, _) => true }
  def isAnyRef(t: Type)       = t <:< AnyRefClass.tpe
  def isCaseClass(t: Type)    = t.typeSymbol hasFlag Flags.CASE
  
  // Comparisons on types
  // def sameSymbols: TypeComparison  = _.typeSymbol eq _.typeSymbol
  // def samePrefix: TypeComparison   = _.prefix =:= _.prefix
  // def isSubErased: TypeComparison  = (t1, t2) => cond((t1, t2)) {
  //   case (_: TypeRef, _: TypeRef) => !t1.isArray && samePrefix(t1,t2) && (sameSymbols(t1,t2) || isSubClass(t1, t2))
  // }
  
  // def isSubClass: TypeComparison   = (t1, t2) => t1.baseClasses exists (_ eq t2.typeSymbol)
  // def isSubType: TypeComparison    = (t1, t2) => isSubClass(t1, t2) && (t1 <:< t2)
  // def isPatMatch: TypeComparison   = (t1, t2) => isSubType(t1, t2)
  
  def decodedEqualsType(tpe: Type) =
    condOpt(tpe) { case TypeRef(_, EqualsPatternClass, List(arg))  => arg } getOrElse (tpe)

  // If we write isSubtypeOf like: 
  //
  //   = t1.baseTypeSeq exists (_ =:= t2)
  //
  // ..then all tests pass except for test/files/run/bug1434.scala, which involves:
  // class A[T], B extends A[Any], C extends B ; ... match { case _: A[_] => .. ; case _: C => .. ; case _: B => .. }
  // and a match error is thrown, which is interesting because either of C or B should match.
  def isSubtypeOf(t1: Type, t2: Type) = t1.baseTypeSeq exists (p => cmpSymbols(p, t2))
  def cmpSymbols(t1: Type, t2: Type)  = t1.typeSymbol eq t2.typeSymbol
  
  case class TypeComp(x: Type, y: Type) {
    def xIsaY = x <:< y
    def yIsaX = y <:< x
    def eqSymbol = cmpSymbols(x, y)
    def eqPrefix = x.prefix =:= y.prefix
    
    object erased {
      import Types._
      /** an approximation of _tp1 <:< tp2 that ignores _ types. this code is wrong,
       *  ideally there is a better way to do it, and ideally defined in Types.scala
       */
      private def cmpErased(t1: Type, t2: Type) = (t1, t2) match {
        case (_: TypeRef, _: TypeRef) => !t1.isArray && eqPrefix && (eqSymbol || isSubtypeOf(t1, t2))
        case _ => false
      }
      
      def xIsaY = cmpErased(x, y)
      def yIsaX = cmpErased(y, x)
    }
  }
  
  object Types {
    import definitions._
    implicit def enrichType(_tpe: Type): RichType = new RichType(_tpe)
    
    class RichType(_tpe: Type) {
      /* equality checks for named constant patterns like "Foo()" are encoded as "_:<equals>[Foo().type]"
       * and later compiled to "if(Foo() == scrutinee) ...". This method extracts type information from
       * such an encoded type, which is used in optimization. If the argument is not an encoded equals
       * test, it is returned as is.
       */   
      private def tpeWRTEquality(t: Type) = t match {
        case TypeRef(_, EqualsPatternClass, List(arg))  => arg
        case _                                          => t
      }
      lazy val tpe = tpeWRTEquality(_tpe)

      // These tests for final classes can inspect the typeSymbol
      private def is(s: Symbol) = tpe.typeSymbol eq s      
      def       isInt = is(IntClass)
      def      isChar = is(CharClass)
      def   isBoolean = is(BooleanClass)
      def   isNothing = is(NothingClass)
      def     isArray = is(ArrayClass)
            
      def cmp(other: Type): TypeComp = TypeComp(tpe, tpeWRTEquality(other))
      
      def coversSym(sym: Symbol) = {
        lazy val lmoc = sym.linkedModuleOfClass
        val symtpe = 
          if ((sym hasFlag Flags.MODULE) && (lmoc ne NoSymbol))
            singleType(sym.tpe.prefix, lmoc)   // e.g. None, Nil
          else sym.tpe

        (tpe.typeSymbol == sym) ||
        (symtpe <:< tpe) ||
        (symtpe.parents exists (x => cmpSymbols(x, tpe))) || // e.g. Some[Int] <: Option[&b]
        ((tpe.prefix memberType sym) <:< tpe)  // outer, see combinator.lexical.Scanner
      }
    }
  
    // used as argument to `EqualsPatternClass'
    case class PseudoType(o: Tree) extends SimpleTypeProxy {
      override def underlying: Type = o.tpe
      override def safeToString: String = "PseudoType("+o+")"
    }
  }

  final def getDummies(i: Int): List[Tree] = List.fill(i)(EmptyTree)

  def makeBind(vs: List[Symbol], pat: Tree): Tree = vs match {
    case Nil      => pat
    case x :: xs  => Bind(x, makeBind(xs, pat)) setType pat.tpe
  }  
    
  private def mkBind(vs: List[Symbol], tpe: Type, arg: Tree) =
    makeBind(vs, Typed(arg, TypeTree(tpe)) setType tpe)

  def mkTypedBind(vs: List[Symbol], tpe: Type)      = mkBind(vs, tpe, WILD(tpe))
  def mkEmptyTreeBind(vs: List[Symbol], tpe: Type)  = mkBind(vs, tpe, EmptyTree)
  def mkEqualsRef(xs: List[Type])                   = typeRef(NoPrefix, EqualsPatternClass, xs)
  
  /** For folding a list into a well-typed x :: y :: etc :: tree. */
  private def listFolder(tpe: Type) = {
    val MethodType(_, TypeRef(pre, sym, _)) = ConsClass.primaryConstructor.tpe
    val consRef                             = typeRef(pre, sym, List(tpe))
    val listRef                             = typeRef(pre, ListClass, List(tpe))
    
    def fold(x: Tree, xs: Tree) = x match {
      case sp @ Strip(_, _: Star) => makeBind(definedVars(sp), WILD(sp.tpe))
      case _                      =>
        val dummyMethod = new TermSymbol(NoSymbol, NoPosition, "matching$dummy")
        val consType    = MethodType(dummyMethod newSyntheticValueParams List(tpe, listRef), consRef)
      
        Apply(TypeTree(consType), List(x, xs)) setType consRef
    }
    
    fold _
  }

  def normalizedListPattern(pats: List[Tree], tptArg: Type): Tree =
    pats.foldRight(gen.mkNil)(listFolder(tptArg))
  
  // An Apply that's a constructor pattern (case class)
  // foo match { case C() => true }
  object Apply_CaseClass {
    def unapply(x: Any) = condOpt(x) {
      case x @ Apply(fn, args) if fn.isType => (x.tpe, args)
    }
  }

  // No-args Apply where fn is not a type - looks like, case object with prefix?
  // 
  // class Pip { 
  //   object opcodes { case object EmptyInstr }
  //   def bop(x: Any) = x match { case opcodes.EmptyInstr => true }
  // }
  object Apply_Value {
    def unapply(x: Any) = condOpt(x) { 
      case x @ Apply(fn, Nil) if !fn.isType => (x.tpe.prefix, x.symbol)
    }
  }
  
  // No-args Apply for all the other cases
  // val Bop = Nil
  // def foo(x: Any) = x match { case Bop => true }
  object Apply_Function {
    def isApplyFunction(t: Apply) = cond(t) {
      case Apply_Value(_, _)        => true
      case x if !isCaseClass(x.tpe) => true
    }
    def unapply(x: Any) = condOpt(x) { 
      case x @ Apply(fn, Nil) if isApplyFunction(x) => fn
    }
  }

  // unapplySeq extractor
  // val List(x,y) = List(1,2)
  object UnapplySeq {
    private object TypeApp {
      def unapply(x: Any) = condOpt(x) { 
        case TypeApply(sel @ Select(stor, nme.unapplySeq), List(tpe)) if stor.symbol eq ListModule => tpe
      }
    }
    def unapply(x: UnApply) = condOpt(x) {
      case UnApply(Apply(TypeApp(tptArg), _), List(ArrayValue(_, xs))) => (tptArg, xs)
    }
  }
    
  // unapply extractor
  // val Pair(_,x) = Pair(1,2)
  object __UnApply {
    private def paramType(fn: Tree) = fn.tpe match { case m: MethodType => m.paramTypes.head }
    def unapply(x: Tree) = condOpt(x) {
      case Strip(vs, UnApply(Apply(fn, _), args)) => (vs, paramType(fn), args)
    }
  }

  // break a pattern down into bound variables and underlying tree.
  object Strip { 
    private def strip(t: Tree, syms: List[Symbol] = Nil): (Tree, List[Symbol]) = t match {
      case b @ Bind(_, pat) => strip(pat, b.symbol :: syms)
      case _                => (t, syms)
    }
    def unapply(x: Tree): Option[(List[Symbol], Tree)] = Some(strip(x).swap)
  }
  
  object Stripped {
    def unapply(x: Tree): Option[Tree] = Some(unbind(x))
  }

  final def definedVars(x: Tree): List[Symbol] = {
    def vars(x: Tree): List[Symbol] = x match {
      case Apply(_, args)     => args flatMap vars
      case b @ Bind(_, p)     => b.symbol :: vars(p)
      case Typed(p, _)        => vars(p)              // otherwise x @ (_:T)
      case UnApply(_, args)   => args flatMap vars
      case ArrayValue(_, xs)  => xs flatMap vars
      case x                  => Nil
    }
    vars(x) reverse
  }

  /** pvar: the symbol of the pattern variable
   *  tvar: the temporary variable that holds the actual value
   */
  case class Binding(pvar: Symbol, tvar: Symbol) {
    override def toString() = "%s: %s @ %s: %s".format(pvar.name, pvar.tpe, tvar.name, tvar.tpe)
  }
  
  case class Bindings(bindings: Binding*) extends Function1[Symbol, Option[Ident]] {
    private def castIfNeeded(pvar: Symbol, tvar: Symbol) =
      if (tvar.tpe <:< pvar.tpe) ID(tvar)
      else ID(tvar) AS_ANY pvar.tpe
    
    def add(vs: Iterable[Symbol], tvar: Symbol): Bindings = {
      def newBinding(v: Symbol) = {
        // see bug #1843 for the consequences of not setting info.
        // there is surely a better way to do this, especially since
        // this looks to be the only usage of containsTp anywhere
        // in the compiler, but it suffices for now.
        if (tvar.info containsTp WildcardType)
          tvar setInfo v.info
        
        Binding(v, tvar)
      }
      val newBindings = vs.toList map newBinding
      Bindings(newBindings ++ bindings: _*)
    }
       
    def apply(v: Symbol): Option[Ident] =
      bindings find (_.pvar eq v) map (x => Ident(x.tvar) setType v.tpe)
    
    override def toString() = 
      if (bindings.isEmpty) "" else bindings.mkString(" Bound(", ", ", ")")

    /** The corresponding list of value definitions. */
    final def targetParams(implicit typer: Typer): List[ValDef] = 
      for (Binding(v, t) <- bindings.toList) yield
        VAL(v) === (typer typed castIfNeeded(v, t))
  }

  val NoBinding: Bindings = Bindings()
}