uncurry.scala
来自「JAVA 语言的函数式编程扩展」· SCALA 代码 · 共 600 行 · 第 1/2 页
SCALA
600 行
/* NSC -- new Scala compiler * Copyright 2005-2007 LAMP/EPFL * @author */// $Id: UnCurry.scala 14083 2008-02-20 15:48:44Z odersky $package scala.tools.nsc.transformimport symtab.Flags._import scala.collection.mutable.{HashMap, HashSet}import scala.tools.nsc.util.Position/*<export>*//** - uncurry all symbol and tree types (@see UnCurryPhase) * - for every curried parameter list: (ps_1) ... (ps_n) ==> (ps_1, ..., ps_n) * - for every curried application: f(args_1)...(args_n) ==> f(args_1, ..., args_n) * - for every type application: f[Ts] ==> f[Ts]() unless followed by parameters * - for every use of a parameterless function: f ==> f() and q.f ==> q.f() * - for every def-parameter: x: => T ==> x: () => T * - for every use of a def-parameter: x ==> x.apply() * - for every argument to a def parameter `x: => T': * if argument is not a reference to a def parameter: * convert argument `e' to (expansion of) `() => e' * - for every repeated parameter `x: T*' --> x: Seq[T]. * - for every argument list that corresponds to a repeated parameter * (a_1, ..., a_n) => (Seq(a_1, ..., a_n)) * - for every argument list that is an escaped sequence * (a_1:_*) => (a_1)g * - convert implicit method types to method types * - convert non-trivial catches in try statements to matches * - convert non-local returns to throws with enclosing try statements. *//*</export>*/abstract class UnCurry extends InfoTransform with TypingTransformers { import global._ // the global environment import definitions._ // standard classes and methods import posAssigner.atPos // for filling in tree positions val phaseName: String = "uncurry" def newTransformer(unit: CompilationUnit): Transformer = new UnCurryTransformer(unit) override def changesBaseClasses = false// ------ Type transformation --------------------------------------------------------//@MAT: uncurry and uncurryType fully expand type aliases in their input and output// note: don't normalize higher-kined types -- @M TODO: maybe split those uses of normalize? // OTOH, should be a problem as calls to normalize only occur on types with kind * in principle (in well-typed programs) private def expandAlias(tp: Type): Type = if (!tp.isHigherKinded) tp.normalize else tp private val uncurry: TypeMap = new TypeMap { def apply(tp0: Type): Type = { val tp = expandAlias(tp0) tp match { case MethodType(formals, MethodType(formals1, restpe)) => apply(MethodType(formals ::: formals1, restpe)) case MethodType(formals, ExistentialType(tparams, restpe @ MethodType(_, _))) => assert(false, "unexpected curried method types with intervening exitential") tp0 case mt: ImplicitMethodType => apply(MethodType(mt.paramTypes, mt.resultType)) case PolyType(List(), restpe) => apply(MethodType(List(), restpe)) case PolyType(tparams, restpe) => PolyType(tparams, apply(MethodType(List(), restpe))) /* case TypeRef(pre, sym, List(arg1, arg2)) if (arg1.typeSymbol == ByNameParamClass) => assert(sym == FunctionClass(1)) apply(typeRef(pre, definitions.ByNameFunctionClass, List(expandAlias(arg1.typeArgs(0)), arg2))) */ case TypeRef(pre, sym, List(arg)) if (sym == ByNameParamClass) => apply(functionType(List(), arg)) case TypeRef(pre, sym, args) if (sym == RepeatedParamClass) => apply(rawTypeRef(pre, SeqClass, args)) case _ => expandAlias(mapOver(tp)) } } } private val uncurryType = new TypeMap { def apply(tp0: Type): Type = { val tp = expandAlias(tp0) tp match { case ClassInfoType(parents, decls, clazz) => val parents1 = List.mapConserve(parents)(uncurry) if (parents1 eq parents) tp else ClassInfoType(parents1, decls, clazz) // @MAT normalize in decls?? case PolyType(_, _) => mapOver(tp) case _ => tp } } } /** - return symbol's transformed type, * - if symbol is a def parameter with transformed type T, return () => T * * @MAT: starting with this phase, the info of every symbol will be normalized */ def transformInfo(sym: Symbol, tp: Type): Type = if (sym.isType) uncurryType(tp) else uncurry(tp) /** Traverse tree omitting local method definitions. * If a `return' is encountered, set `returnFound' to true. * Used for MSIL only. */ private object lookForReturns extends Traverser { var returnFound = false override def traverse(tree: Tree): Unit = tree match { case Return(_) => returnFound = true case DefDef(_, _, _, _, _, _) => ; case _ => super.traverse(tree) } def found(tree: Tree) = { returnFound = false traverse(tree) returnFound } } class UnCurryTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { private var needTryLift = false private var inPattern = false private var inConstructorFlag = 0L private val byNameArgs = new HashSet[Tree] private val noApply = new HashSet[Tree] override def transform(tree: Tree): Tree = try { //debug postTransform(mainTransform(tree)) } catch { case ex: Throwable => Console.println("exception when traversing " + tree) throw ex } /* Is tree a reference `x' to a call by name parameter that neeeds to be converted to * x.apply()? Note that this is not the case if `x' is used as an argument to another * call by name parameter. */ def isByNameRef(tree: Tree): Boolean = tree.isTerm && tree.hasSymbol && tree.symbol.tpe.typeSymbol == ByNameParamClass && !byNameArgs.contains(tree) /** Uncurry a type of a tree node. * This function is sensitive to whether or not we are in a pattern -- when in a pattern * additional parameter sections of a case class are skipped. */ def uncurryTreeType(tp: Type): Type = tp match { case MethodType(formals, MethodType(formals1, restpe)) if (inPattern) => uncurryTreeType(MethodType(formals, restpe)) case _ => uncurry(tp) }// ------- Handling non-local returns ------------------------------------------------- /** The type of a non-local return expression with given argument type */ private def nonLocalReturnExceptionType(argtype: Type) = appliedType(NonLocalReturnExceptionClass.typeConstructor, List(argtype)) /** A hashmap from method symbols to non-local return keys */ private val nonLocalReturnKeys = new HashMap[Symbol, Symbol] /** Return non-local return key for given method */ private def nonLocalReturnKey(meth: Symbol) = nonLocalReturnKeys.get(meth) match { case Some(k) => k case None => val k = meth.newValue(meth.pos, unit.fresh.newName("nonLocalReturnKey")) .setFlag(SYNTHETIC).setInfo(ObjectClass.tpe) nonLocalReturnKeys(meth) = k k } /** Generate a non-local return throw with given return expression from given method. * I.e. for the method's non-local return key, generate: * * throw new NonLocalReturnException(key, expr) */ private def nonLocalReturnThrow(expr: Tree, meth: Symbol) = localTyper.typed { Throw( New( TypeTree(nonLocalReturnExceptionType(expr.tpe)), List(List(Ident(nonLocalReturnKey(meth)), expr)))) } /** Transform (body, key) to: * * { * val key = new Object() * try { * body * } catch { * case ex: NonLocalReturnException[_] => * if (ex.key().eq(key)) ex.value() * else throw ex * } * } */ private def nonLocalReturnTry(body: Tree, key: Symbol, meth: Symbol) = { localTyper.typed { val extpe = nonLocalReturnExceptionType(meth.tpe.finalResultType) val ex = meth.newValue(body.pos, nme.ex) setInfo extpe val pat = Bind(ex, Typed(Ident(nme.WILDCARD), AppliedTypeTree(Ident(NonLocalReturnExceptionClass), List(Bind(nme.WILDCARD.toTypeName, EmptyTree))))) val rhs = If( Apply( Select( Apply(Select(Ident(ex), "key"), List()), Object_eq), List(Ident(key))), Apply( TypeApply( Select( Apply(Select(Ident(ex), "value"), List()), Any_asInstanceOf), List(TypeTree(meth.tpe.finalResultType))), List()), Throw(Ident(ex))) val keyDef = ValDef(key, New(TypeTree(ObjectClass.tpe), List(List()))) val tryCatch = Try(body, List(CaseDef(pat, EmptyTree, rhs)), EmptyTree) Block(List(keyDef), tryCatch) } }// ------ Transforming anonymous functions and by-name-arguments ---------------- /** Undo eta expansion for parameterless and nullaray methods */ def deEta(fun: Function): Tree = fun match { case Function(List(), Apply(expr, List())) if treeInfo.isPureExpr(expr) => if (expr.hasSymbol && expr.symbol.hasFlag(LAZY)) fun else expr case Function(List(), expr) if isByNameRef(expr) => noApply += expr expr case _ => fun } /* Transform a function node (x_1,...,x_n) => body of type FunctionN[T_1, .., T_N, R] to * * class $anon() extends Object() with FunctionN[T_1, .., T_N, R] with ScalaObject { * def apply(x_1: T_1, ..., x_N: T_n): R = body * } * new $anon() * * transform a function node (x => body) of type PartialFunction[T, R] where * body = x match { case P_i if G_i => E_i }_i=1..n * to: * * class $anon() extends Object() with PartialFunction[T, R] with ScalaObject { * def apply(x: T): R = (x: @unchecked) match { * { case P_i if G_i => E_i }_i=1..n * def isDefinedAt(x: T): boolean = (x: @unchecked) match { * case P_1 if G_1 => true * ... * case P_n if G_n => true * case _ => false * } * } * new $anon() * * However, if one of the patterns P_i if G_i is a default pattern, generate instead * * def isDefinedAt(x: T): boolean = true */ def transformFunction(fun: Function): Tree = { val fun1 = deEta(fun) if (fun1 ne fun) fun1 else { val anonClass = fun.symbol.owner.newAnonymousFunctionClass(fun.pos) .setFlag(FINAL | SYNTHETIC | inConstructorFlag) val formals = fun.tpe.typeArgs.init val restpe = fun.tpe.typeArgs.last anonClass setInfo ClassInfoType( List(ObjectClass.tpe, fun.tpe, ScalaObjectClass.tpe), newScope, anonClass); val applyMethod = anonClass.newMethod(fun.pos, nme.apply) .setFlag(FINAL).setInfo(MethodType(formals, restpe)); anonClass.info.decls enter applyMethod; for (vparam <- fun.vparams) vparam.symbol.owner = applyMethod; new ChangeOwnerTraverser(fun.symbol, applyMethod).traverse(fun.body); def applyMethodDef(body: Tree) = DefDef(Modifiers(FINAL), nme.apply, List(), List(fun.vparams), TypeTree(restpe), body) .setSymbol(applyMethod) def mkUnchecked(tree: Tree) = tree match { case Match(selector, cases) => atPos(tree.pos) { Match( Annotated(Annotation(New(TypeTree(UncheckedClass.tpe), List(List())), List()), selector),
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?