uncurry.scala

来自「JAVA 语言的函数式编程扩展」· SCALA 代码 · 共 600 行 · 第 1/2 页

SCALA
600
字号
/* NSC -- new Scala compiler
 * Copyright 2005-2007 LAMP/EPFL
 * @author
 */
// $Id: UnCurry.scala 14083 2008-02-20 15:48:44Z odersky $

package scala.tools.nsc.transform

import symtab.Flags._
import scala.collection.mutable.{HashMap, HashSet}
import scala.tools.nsc.util.Position

/*<export>*/
/** - uncurry all symbol and tree types (@see UnCurryPhase)
 *  - for every curried parameter list:  (ps_1) ... (ps_n) ==> (ps_1, ..., ps_n)
 *  - for every curried application: f(args_1)...(args_n) ==> f(args_1, ..., args_n)
 *  - for every type application: f[Ts] ==> f[Ts]() unless followed by parameters
 *  - for every use of a parameterless function: f ==> f()  and  q.f ==> q.f()
 *  - for every def-parameter:  x: => T ==> x: () => T
 *  - for every use of a def-parameter: x ==> x.apply()
 *  - for every argument to a def parameter `x: => T':
 *      if argument is not a reference to a def parameter:
 *        convert argument `e' to (expansion of) `() => e'
 *  - for every repeated parameter `x: T*' --> x: Seq[T].
 *  - for every argument list that corresponds to a repeated parameter
 *       (a_1, ..., a_n) => (Seq(a_1, ..., a_n))
 *  - for every argument list that is an escaped sequence
 *       (a_1:_*) => (a_1)g
 *  - convert implicit method types to method types
 *  - convert non-trivial catches in try statements to matches
 *  - convert non-local returns to throws with enclosing try statements.
*/
/*</export>*/
abstract class UnCurry extends InfoTransform with TypingTransformers {
  import global._                  // the global environment
  import definitions._             // standard classes and methods
  import posAssigner.atPos         // for filling in tree positions

  val phaseName: String = "uncurry"

  /** Returns the transformer that applies this phase to `unit`. */
  def newTransformer(unit: CompilationUnit): Transformer = new UnCurryTransformer(unit)

  /** This phase reports that it does not change base classes. */
  override def changesBaseClasses = false

// ------ Type transformation --------------------------------------------------------

// @MAT: uncurry and uncurryType fully expand type aliases in their input and output.
// Note: don't normalize higher-kinded types -- @M TODO: maybe split those uses of normalize?
// NOTE(review): the original comment read "OTOH, should be a problem as calls to normalize
// only occur on types with kind * in principle (in well-typed programs)"; presumably it
// meant "shouldn't be a problem" -- confirm with the author.

  /** Returns `tp` normalized (type aliases expanded) unless it is higher-kinded,
   *  in which case `tp` is returned unchanged.
   */
  private def expandAlias(tp: Type): Type = if (!tp.isHigherKinded) tp.normalize else tp

  /** The main uncurrying type map. After alias expansion of its argument it
   *  - flattens a method type whose result is again a method type into a single
   *    parameter list;
   *  - asserts that no method type has an existential wrapping a further method type;
   *  - turns an implicit method type into a plain method type;
   *  - turns a parameterless type (a `PolyType` without type parameters) into a method
   *    type with an empty parameter list, and gives the result of a polymorphic type an
   *    empty parameter list, which the first case then merges with any parameter list
   *    that follows;
   *  - maps a by-name parameter type `=> T` to the function type `() => T`;
   *  - maps a repeated parameter type `T*` to `Seq[T]`;
   *  - otherwise maps over the type's components and alias-expands the result.
   */
  private val uncurry: TypeMap = new TypeMap {
    def apply(tp0: Type): Type = {
      val tp = expandAlias(tp0)
      tp match {
        case MethodType(formals, MethodType(formals1, restpe)) =>
          apply(MethodType(formals ::: formals1, restpe))
        case MethodType(formals, ExistentialType(tparams, restpe @ MethodType(_, _))) =>
          assert(false, "unexpected curried method types with intervening existential")
          tp0
        case mt: ImplicitMethodType =>
          apply(MethodType(mt.paramTypes, mt.resultType))
        case PolyType(List(), restpe) =>
          apply(MethodType(List(), restpe))
        case PolyType(tparams, restpe) =>
          PolyType(tparams, apply(MethodType(List(), restpe)))
        /*
        case TypeRef(pre, sym, List(arg1, arg2)) if (arg1.typeSymbol == ByNameParamClass) =>
          assert(sym == FunctionClass(1))
          apply(typeRef(pre, definitions.ByNameFunctionClass, List(expandAlias(arg1.typeArgs(0)), arg2)))
        */
        case TypeRef(pre, sym, List(arg)) if (sym == ByNameParamClass) =>
          apply(functionType(List(), arg))
        case TypeRef(pre, sym, args) if (sym == RepeatedParamClass) =>
          apply(rawTypeRef(pre, SeqClass, args))
        case _ =>
          expandAlias(mapOver(tp))
      }
    }
  }

  /** Type map used for the infos of type symbols (see `transformInfo`). After alias
   *  expansion it uncurries the parents of a `ClassInfoType` (declarations are kept
   *  as they are, and the expanded type itself is returned when no parent changed),
   *  maps over the components of a `PolyType`, and leaves every other type as it is.
   */
  private val uncurryType = new TypeMap {
    def apply(tp0: Type): Type = {
      val tp = expandAlias(tp0)
      tp match {
        case ClassInfoType(parents, decls, clazz) =>
          val parents1 = List.mapConserve(parents)(uncurry)
          if (parents1 eq parents) tp
          else ClassInfoType(parents1, decls, clazz) // @MAT normalize in decls??
        case PolyType(_, _) =>
          mapOver(tp)
        case _ =>
          tp
      }
    }
  }

  /** - return symbol's transformed type,
   *  - if symbol is a def parameter with transformed type T, return () => T
   *
   *  Type symbols are mapped with `uncurryType`; all other symbols with `uncurry`.
   *
   * @MAT: starting with this phase, the info of every symbol will be normalized
   */
  def transformInfo(sym: Symbol, tp: Type): Type =
    if (sym.isType) uncurryType(tp) else uncurry(tp)

  /** Traverse tree omitting local method definitions.
   *  If a `return' is encountered, set `returnFound' to true.
   *  Use `found(tree)` to reset the flag, traverse `tree` and read the result.
   *  Used for MSIL only.
*/
  private object lookForReturns extends Traverser {
    /** Becomes true once a `Return` node has been visited. */
    var returnFound = false

    override def traverse(tree: Tree): Unit = tree match {
      case _: DefDef =>
        ()                              // do not descend into local method definitions
      case _: Return =>
        returnFound = true
      case _ =>
        super.traverse(tree)
    }

    /** Clears the flag, traverses `tree`, and reports whether a `Return` was visited. */
    def found(tree: Tree): Boolean = {
      returnFound = false
      traverse(tree)
      returnFound
    }
  }

  class UnCurryTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {

    private var needTryLift: Boolean = false
    /** Read by `uncurryTreeType`: while true, additional parameter sections are dropped. */
    private var inPattern: Boolean = false
    /** Extra flags or-ed into the flags of generated anonymous function classes. */
    private var inConstructorFlag: Long = 0L
    /** Trees in this set are never reported as by-name references by `isByNameRef`. */
    private val byNameArgs: HashSet[Tree] = new HashSet[Tree]
    /** By-name references collected by `deEta` when it strips a `() => x` wrapper. */
    private val noApply: HashSet[Tree] = new HashSet[Tree]

    /** Runs `mainTransform` and then `postTransform` on `tree`.
     *  Debugging aid: if either throws, the offending tree is printed and the
     *  exception is rethrown unchanged.
     */
    override def transform(tree: Tree): Tree = {
      try {
        val transformed = mainTransform(tree)
        postTransform(transformed)
      } catch {
        case failure: Throwable =>
          Console.println("exception when traversing " + tree)
          throw failure
      }
    }

    /** Is `tree` a reference `x` to a call-by-name parameter that needs to be converted
     *  to `x.apply()`? This is not the case if `x` is itself used as an argument to
     *  another call-by-name parameter (presumably the trees collected in `byNameArgs`).
     */
    def isByNameRef(tree: Tree): Boolean =
      if (!tree.isTerm || !tree.hasSymbol) false
      else tree.symbol.tpe.typeSymbol == ByNameParamClass && !(byNameArgs contains tree)

    /** Uncurries the type of a tree node.
     *  The result depends on whether we are inside a pattern: there, additional
     *  parameter sections of a case class are skipped.
*/
    def uncurryTreeType(tp: Type): Type = tp match {
      // Inside a pattern, keep only the first parameter section and drop the next one.
      case MethodType(formals, MethodType(formals1, restpe)) if (inPattern) =>
        uncurryTreeType(MethodType(formals, restpe))
      // Everywhere else this is the ordinary `uncurry` type map.
      case _ =>
        uncurry(tp)
    }

// ------- Handling non-local returns -------------------------------------------------

    /** The type of a non-local return expression with given argument type,
     *  i.e. `NonLocalReturnException[argtype]`.
     */
    private def nonLocalReturnExceptionType(argtype: Type) =
      appliedType(NonLocalReturnExceptionClass.typeConstructor, List(argtype))

    /** A hashmap from method symbols to non-local return keys.
     *  Filled on demand by `nonLocalReturnKey`, so each method gets at most one key.
     */
    private val nonLocalReturnKeys = new HashMap[Symbol, Symbol]

    /** Return non-local return key for given method.
     *  On the first request for `meth`, a fresh SYNTHETIC value symbol of type Object
     *  (fresh name based on "nonLocalReturnKey") is created with `meth.newValue` and
     *  cached in `nonLocalReturnKeys`; later requests return the cached symbol.
     */
    private def nonLocalReturnKey(meth: Symbol) = nonLocalReturnKeys.get(meth) match {
      case Some(k) => k
      case None =>
        val k = meth.newValue(meth.pos, unit.fresh.newName("nonLocalReturnKey"))
          .setFlag(SYNTHETIC).setInfo(ObjectClass.tpe)
        nonLocalReturnKeys(meth) = k
        k
    }

    /** Generate a non-local return throw with given return expression from given method.
     *  I.e. for the method's non-local return key, generate:
     *
     *    throw new NonLocalReturnException(key, expr)
     *
     *  The exception's type argument is `expr.tpe`; the tree is typed with `localTyper`.
     */
    private def nonLocalReturnThrow(expr: Tree, meth: Symbol) =
      localTyper.typed {
        Throw(
          New(
            TypeTree(nonLocalReturnExceptionType(expr.tpe)),
            List(List(Ident(nonLocalReturnKey(meth)), expr))))
      }

    /** Transform (body, key) to:
     *
     *  {
     *    val key = new Object()
     *    try {
     *      body
     *    } catch {
     *      case ex: NonLocalReturnException[_] =>
     *        if (ex.key().eq(key)) ex.value()
     *        else throw ex
     *    }
     *  }
     *
     *  The generated code additionally casts `ex.value()` to the final result type
     *  of `meth`. `key` is a symbol supplied by the caller; this method emits its
     *  definition `val key = new Object()`. The whole block is typed with `localTyper`.
     */
    private def nonLocalReturnTry(body: Tree, key: Symbol, meth: Symbol) = {
      localTyper.typed {
        val extpe = nonLocalReturnExceptionType(meth.tpe.finalResultType)
        val ex = meth.newValue(body.pos, nme.ex) setInfo extpe
        // pattern:  ex @ (_: NonLocalReturnException[_])
        val pat = Bind(ex,
                       Typed(Ident(nme.WILDCARD),
                             AppliedTypeTree(Ident(NonLocalReturnExceptionClass),
                                             List(Bind(nme.WILDCARD.toTypeName,
                                                       EmptyTree)))))
        // if (ex.key() eq key) <ex.value() cast to the final result type> else throw ex
        val rhs =
          If(
            Apply(
              Select(
                Apply(Select(Ident(ex), "key"), List()),
                Object_eq),
              List(Ident(key))),
            Apply(
              TypeApply(
                Select(
                  Apply(Select(Ident(ex), "value"), List()),
                  Any_asInstanceOf),
                List(TypeTree(meth.tpe.finalResultType))),
              List()),
            Throw(Ident(ex)))
        // val key = new Object()
        val keyDef = ValDef(key, New(TypeTree(ObjectClass.tpe), List(List())))
        // try { body } catch { case <pat> => <rhs> }
        val tryCatch = Try(body, List(CaseDef(pat, EmptyTree, rhs)), EmptyTree)
        Block(List(keyDef), tryCatch)
      }
    }

// ------ Transforming anonymous functions and by-name-arguments ----------------

    /** Undo eta expansion for parameterless and nullary methods.
     *
     *  - `() => e()` becomes `e` when `e` is a pure expression (per `treeInfo.isPureExpr`),
     *    unless `e` has a symbol carrying the LAZY flag, in which case `fun` is kept;
     *  - `() => x`, where `x` is a by-name reference (see `isByNameRef`), becomes `x`,
     *    and `x` is added to `noApply` (presumably so that it is not later rewritten
     *    to `x.apply()` -- confirm against the uses of `noApply`);
     *  - any other function is returned unchanged.
     */
    def deEta(fun: Function): Tree = fun match {
      case Function(List(), Apply(expr, List())) if treeInfo.isPureExpr(expr) =>
        if (expr.hasSymbol && expr.symbol.hasFlag(LAZY))
          fun
        else
          expr
      case Function(List(), expr) if isByNameRef(expr) =>
        noApply += expr
        expr
      case _ =>
        fun
    }

    /*  Transform a function node (x_1,...,x_n) => body of type FunctionN[T_1, .., T_N, R] to
     *
     *    class $anon() extends Object() with FunctionN[T_1, .., T_N, R] with ScalaObject {
     *      def apply(x_1: T_1, ..., x_N: T_n): R = body
     *    }
     *    new $anon()
     *
     *  transform a function node (x => body) of type PartialFunction[T, R] where
     *    body = x match { case P_i if G_i => E_i }_i=1..n
     *  to:
     *
     *    class $anon() extends Object() with PartialFunction[T, R] with ScalaObject {
     *      def apply(x: T): R = (x: @unchecked) match {
     *        { case P_i if G_i => E_i }_i=1..n
     *      def isDefinedAt(x: T): boolean = (x: @unchecked) match {
     *        case P_1 if G_1 => true
     *        ...
*        case P_n if G_n => true     *        case _ => false     *      }     *    }     *    new $anon()     *       *  However, if one of the patterns P_i if G_i is a default pattern, generate instead     *     *      def isDefinedAt(x: T): boolean = true     */    def transformFunction(fun: Function): Tree = {      val fun1 = deEta(fun)      if (fun1 ne fun) fun1      else {        val anonClass = fun.symbol.owner.newAnonymousFunctionClass(fun.pos)           .setFlag(FINAL | SYNTHETIC | inConstructorFlag)        val formals = fun.tpe.typeArgs.init        val restpe = fun.tpe.typeArgs.last        anonClass setInfo ClassInfoType(          List(ObjectClass.tpe, fun.tpe, ScalaObjectClass.tpe), newScope, anonClass);        val applyMethod = anonClass.newMethod(fun.pos, nme.apply)          .setFlag(FINAL).setInfo(MethodType(formals, restpe));        anonClass.info.decls enter applyMethod;        for (vparam <- fun.vparams) vparam.symbol.owner = applyMethod;        new ChangeOwnerTraverser(fun.symbol, applyMethod).traverse(fun.body);        def applyMethodDef(body: Tree) =           DefDef(Modifiers(FINAL), nme.apply, List(), List(fun.vparams), TypeTree(restpe), body)            .setSymbol(applyMethod)        def mkUnchecked(tree: Tree) = tree match {          case Match(selector, cases) =>            atPos(tree.pos) {              Match(                Annotated(Annotation(New(TypeTree(UncheckedClass.tpe), List(List())), List()), selector),

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?