tailcalls.scala

来自「JAVA 语言的函数式编程扩展」· SCALA 代码 · 共 331 行

SCALA
331
字号
/* NSC -- new scala compiler
 * Copyright 2005-2007 LAMP/EPFL
 * @author Iulian Dragos
 */
// $Id: TailCalls.scala 14116 2008-02-23 05:53:02Z spoon $

package scala.tools.nsc.transform

import scala.tools.nsc.symtab.Flags

/** Perform tail recursive call elimination.
 *
 *  @author Iulian Dragos
 *  @version 1.0
 */
abstract class TailCalls extends Transform
                         /* with JavaLogging() */ {
  // inherits abstract value `global' and class `Phase' from Transform

  import global._                  // the global environment
  import definitions._             // standard classes and methods
  import typer.{typed, atOwner}    // methods to type trees
  import posAssigner.atPos         // for filling in tree positions

  val phaseName: String = "tailcalls"

  def newTransformer(unit: CompilationUnit): Transformer =
    new TailCallElimination(unit)

  /** Create a new phase which applies transformer */
  override def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev)

  /** The phase defined by this transform. Units are left untouched when
   *  settings.debuginfo is "notailcalls".
   */
  class Phase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) {
    def apply(unit: global.CompilationUnit) {
      if (settings.debuginfo.value != "notailcalls")
        newTransformer(unit).transformUnit(unit)
    }
  }

  /**
   * A Tail Call Transformer
   *
   * @author     Erik Stenman, Iulian Dragos
   * @version    1.1
   *
   * What it does:
   * <p>
   *   Finds method calls in tail-position and replaces them with jumps.
   *   A call is in a tail-position if it is the last instruction to be
   *   executed in the body of a method.  This is done by recursing over
   *   the trees that may contain calls in tail-position (trees that can't
   *   contain such calls are not transformed). However, they are not that
   *   many.
   * </p>
   * <p>
   *   Self-recursive calls in tail-position are replaced by jumps to a
   *   label at the beginning of the method. As the JVM provides no way to
   *   jump from a method to another one, non-recursive calls in
   *   tail-position are not optimized.
   * </p>
   * <p>
   *   A method call is self-recursive if it calls the current method and
   *   the method is final (otherwise, it could be a call to an overridden
   *   method in a subclass). Furthermore, if the method has type parameters,
   *   the call must contain these parameters as type arguments. Recursive
   *   calls on a different instance are optimized. Since 'this' is not a
   *   local variable, a dummy local val is added and used as a label
   *   parameter. The backend knows to load the corresponding argument in
   *   the 'this' (local at index 0). This dummy local is never used and
   *   should be cleaned up by dead code elimination (when enabled).
   * </p>
   * <p>
   *   A call whose value is consumed by the enclosing expression is not in
   *   tail position: non-final statements of a block, the scrutinee of a
   *   match, call arguments, both sides of an assignment, the operand of
   *   <code>throw</code>, and the protected block of a <code>try</code>
   *   (a jump out of it would leave the exception handler). The handlers
   *   of a <code>try</code> are in tail position only when there is no
   *   <code>finally</code> clause.
   * </p>
   * <p>
   *   This phase has been moved before pattern matching to catch more
   *   of the common cases of tail recursive functions. This means that
   *   more cases should be taken into account (like nested function, and
   *   pattern cases).
   * </p>
   * <p>
   *   If a method contains self-recursive calls, a label is added at
   *   the beginning of its body and the calls are replaced by jumps to
   *   that label.
   * </p>
   * <p>
   *   Assumes: <code>Uncurry</code> has been run already, and no multiple
   *            parameter lists exist.
   * </p>
   */
  class TailCallElimination(unit: CompilationUnit) extends Transformer {

    /** Mutable state describing the method and position being transformed. */
    class Context {
      /** The current method */
      var currentMethod: Symbol = NoSymbol

      /** The current tail-call label */
      var label: Symbol = NoSymbol

      /** The expected type arguments of self-recursive calls */
      var tparams: List[Symbol] = Nil

      /** Tells whether we are in a (possible) tail position */
      var tailPos = false

      /** Is the label accessed? */
      var accessed = false

      /** Copy constructor: duplicates every field of `that'. */
      def this(that: Context) = {
        this()
        this.currentMethod = that.currentMethod
        this.label         = that.label
        this.tparams       = that.tparams
        this.tailPos       = that.tailPos
        this.accessed      = that.accessed
      }

      /** Create a new label symbol named after the current method, store it
       *  in the `label' field, and reset `accessed'.
       */
      def makeLabel(): Unit = {
        label = currentMethod.newLabel(currentMethod.pos, "_" + currentMethod.name)
        accessed = false
      }

      override def toString(): String = (
        "" + currentMethod.name + " tparams: " + tparams + " tailPos: " + tailPos +
        " accessed: " + accessed + "\nLabel: " + label + "\nLabel type: " + label.info
      )
    }

    /** A copy of `that'.
     *
     *  rewriteTailCall records a use of the label by setting `accessed' on
     *  the *current* context object, and the DefDef case reads it from the
     *  context object it created. Tail positions must therefore be
     *  transformed with that same object, never with a copy: copies are only
     *  for non-tail positions and for entering a new method.
     */
    private def mkContext(that: Context) = new Context(that)

    /** A copy of `that' whose `tailPos' is `tp'. */
    private def mkContext(that: Context, tp: Boolean): Context = {
      val t = mkContext(that)
      t.tailPos = tp
      t
    }

    private var ctx: Context = new Context()

    /** Rewrite this tree to contain no tail recursive calls, using `nctx' as
     *  the current context. The previous context is restored afterwards,
     *  also when the transformation throws.
     */
    def transform(tree: Tree, nctx: Context): Tree = {
      val oldCtx = ctx
      ctx = nctx
      try transform(tree)
      finally ctx = oldCtx
    }

    override def transform(tree: Tree): Tree = {
      tree match {
        case DefDef(mods, name, tparams, vparams, tpt, rhs) =>
          log("Entering DefDef: " + name)
          val newCtx = mkContext(ctx)
          newCtx.currentMethod = tree.symbol
          newCtx.makeLabel()
          // The label takes the receiver first, then the method's parameters.
          newCtx.label.setInfo(MethodType(currentClass.tpe :: tree.symbol.tpe.paramTypes, tree.symbol.tpe.finalResultType))
          newCtx.tailPos = true

          // NOTE(review): methods of modules are prepared here, but the Apply
          // cases below only rewrite calls when currentMethod.isFinal --
          // confirm whether module methods are expected to carry FINAL.
          val t1 = if (newCtx.currentMethod.isFinal ||
                       newCtx.currentMethod.enclClass.hasFlag(Flags.MODULE)) {
            newCtx.tparams = Nil
            log("  Considering " + name + " for tailcalls")
            tree.symbol.tpe match {
              case PolyType(tpes, restpe) =>
                newCtx.tparams = tparams map (_.symbol)
                newCtx.label.setInfo(
                  newCtx.label.tpe.substSym(tpes, tparams map (_.symbol)))
              case _ => ()
            }
            //println("label.tpe: " + newCtx.label.tpe)
            val body = transform(rhs, newCtx)
            val newRHS =
              if (newCtx.accessed) {
                log("Rewrote def " + newCtx.currentMethod)
                // `this' is not a local variable, so a synthetic local stands
                // in for it as the first label parameter.
                val newThis = newCtx.currentMethod.newValue(tree.pos, nme.THIS)
                  .setInfo(currentClass.tpe)
                  .setFlag(Flags.SYNTHETIC)
                typed(atPos(tree.pos)(Block(List(
                  ValDef(newThis, This(currentClass))),
                  LabelDef(newCtx.label,
                           newThis :: (List.flatten(vparams) map (_.symbol)),
                           body))))
              } else
                body
            copy.DefDef(tree, mods, name, tparams, vparams, tpt, newRHS)
          } else {
            copy.DefDef(tree, mods, name, tparams, vparams, tpt, transform(rhs, newCtx))
          }
          log("Leaving DefDef: " + name)
          t1

        case EmptyTree => tree

        case PackageDef(name, stats) =>
          super.transform(tree)

        case ClassDef(_, name, _, _) =>
          log("Entering class " + name)
          val res = super.transform(tree)
          log("Leaving class " + name)
          res

        case ValDef(mods, name, tpt, rhs) => super.transform(tree)
        case LabelDef(name, params, rhs) => super.transform(tree)

        case Template(parents, self, body) =>
          super.transform(tree)

        case Block(stats, expr) =>
          // Only the final expression of a block can be in tail position.
          copy.Block(tree,
                     transformTrees(stats, mkContext(ctx, false)),
                     transform(expr))

        case CaseDef(pat, guard, body) =>
          copy.CaseDef(tree, pat, guard, transform(body))

        case Sequence(_) | Alternative(_) |
             Star(_)     | Bind(_, _) =>
          throw new RuntimeException("We should've never gotten inside a pattern")

        case Function(vparams, body) =>
          tree
          //throw new RuntimeException("Anonymous function should not exist at this point. at: " + unit.position(tree.pos));

        case Assign(lhs, rhs) =>
          // The assigned value is consumed by the assignment, so neither side
          // is in tail position even when the assignment itself is.
          copy.Assign(tree,
                      transform(lhs, mkContext(ctx, false)),
                      transform(rhs, mkContext(ctx, false)))

        case If(cond, thenp, elsep) =>
          copy.If(tree, cond, transform(thenp), transform(elsep))

        case Match(selector, cases) => //super.transform(tree);
          copy.Match(tree, transform(selector, mkContext(ctx, false)), transformTrees(cases).asInstanceOf[List[CaseDef]])

        case Return(expr) => super.transform(tree)

        case Try(block, catches, finalizer) =>
          // The protected block is never in tail position: a jump out of it
          // would run the remaining iterations outside this handler, so the
          // exceptions they raise would no longer be caught here.
          val newBlock = transform(block, mkContext(ctx, false))
          if (finalizer == EmptyTree)
            // Without a finalizer a handler's value is the value of the whole
            // `try', so handlers are transformed with `ctx' itself.
            copy.Try(tree, newBlock,
                     transformTrees(catches).asInstanceOf[List[CaseDef]],
                     finalizer)
          else
            // The finalizer runs after the block and the handlers, and its own
            // value is discarded: nothing here is in tail position.
            copy.Try(tree, newBlock,
                     transformTrees(catches, mkContext(ctx, false)).asInstanceOf[List[CaseDef]],
                     transform(finalizer, mkContext(ctx, false)))

        case Throw(expr) =>
          // The operand is consumed by `throw', so it is not in tail position.
          copy.Throw(tree, transform(expr, mkContext(ctx, false)))

        case New(tpt) => super.transform(tree)
        case Typed(expr, tpt) => super.transform(tree)

        case Apply(tapply @ TypeApply(fun, targs), vargs) =>
          lazy val defaultTree = copy.Apply(tree, tapply, transformTrees(vargs, mkContext(ctx, false)))
          if ( ctx.currentMethod.isFinal &&
               ctx.tailPos &&
               isSameTypes(ctx.tparams, targs map (_.tpe.typeSymbol)) &&
               isRecursiveCall(fun)) {
            fun match {
              // NOTE(review): calls on an explicit receiver are not rewritten
              // for MSIL; the reason is not visible in this file.
              case Select(receiver, _) =>
                if (!forMSIL) rewriteTailCall(fun, receiver :: transformTrees(vargs, mkContext(ctx, false)))
                else defaultTree
              case _ =>
                rewriteTailCall(fun, This(currentClass) :: transformTrees(vargs, mkContext(ctx, false)))
            }
          } else
            defaultTree

        case TypeApply(fun, args) =>
          super.transform(tree)

        case Apply(fun, args) if (fun.symbol == definitions.Boolean_or ||
                                  fun.symbol == definitions.Boolean_and) =>
          // `a || b' and `a && b' may yield `b' directly, so the right operand
          // keeps the current tail position.
          copy.Apply(tree, fun, transformTrees(args))

        case Apply(fun, args) =>
          lazy val defaultTree = copy.Apply(tree, fun, transformTrees(args, mkContext(ctx, false)))
          if (ctx.currentMethod.isFinal &&
              ctx.tailPos &&
              isRecursiveCall(fun)) {
            fun match {
              case Select(receiver, _) =>
                if (!forMSIL) rewriteTailCall(fun, receiver :: transformTrees(args, mkContext(ctx, false)))
                else defaultTree
              case _ =>
                rewriteTailCall(fun, This(currentClass) :: transformTrees(args, mkContext(ctx, false)))
            }
          } else
            defaultTree

        case Super(qual, mix) =>
          tree
        case This(qual) =>
          tree
        case Select(qualifier, selector) =>
          tree
        case Ident(name) =>
          tree
        case Literal(value) =>
          tree
        case TypeTree() =>
          tree

        case _ =>
          tree
      }
    }

    /** Transform every tree of `trees' with `nctx' as the current context. */
    def transformTrees(trees: List[Tree], nctx: Context): List[Tree] =
      trees map ((tree) => transform(tree, nctx))

    /** Replace a self-recursive call by a typed jump to the current label,
     *  passing `args' (receiver first), and mark the label as accessed.
     */
    private def rewriteTailCall(fun: Tree, args: List[Tree]): Tree = {
      log("Rewriting tail recursive method call at: " +
                      (fun.pos))
      ctx.accessed = true
      //println("fun: " + fun + " args: " + args)
      typed(atPos(fun.pos)(
        Apply(Ident(ctx.label), args)))
    }

    /** Do `ts1' and `ts2' hold the same symbols, in the same order?
     *  Lists of different lengths are never considered equal.
     */
    private def isSameTypes(ts1: List[Symbol], ts2: List[Symbol]): Boolean =
      ts1.length == ts2.length && List.forall2(ts1, ts2)((t1, t2) => t1 == t2)

    /** Returns <code>true</code> if the fun tree refers to the same method as
     *  the one saved in <code>ctx</code>.
     *
     *  @param fun the expression that is applied
     *  @return    <code>true</code> if the tree symbol refers to the innermost
     *             enclosing method
     */
    private def isRecursiveCall(fun: Tree): Boolean =
      (fun.symbol eq ctx.currentMethod)
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?