在Scala 3中查找lambda捕获值(或它们的类)

oxcyiej7  于 2022-12-26  发布在  Scala
关注(0)|答案(1)|浏览(153)

我正在寻找一种方法来查找Scala 3中lambda捕获的值(或它们的类)(用于序列化-类似于Spark)(我不需要Scala 2支持):

val a = "abc"
val f = () => a + "xyz"
serialize(f) // Should detect a / String as captured value

在运行时做这件事有点容易(在f.getClass.getDeclaredFields上迭代),但我想在编译时做。
我试图在宏中检查lambda的时间,但它被检测为普通scala.Function0,没有任何有趣的信息。
我想知道我是否可以做一些树检查,但我真的希望避免这样做-我觉得我必须复制编译器内部来捕捉所有的边缘情况。

nxagd54h

nxagd54h1#

尝试以下macro

import scala.quoted.*

inline def serialize(x: Any): Unit = ${serializeImpl('x)}

def serializeImpl(x: Expr[Any])(using Quotes): Expr[Unit] = {
  import quotes.reflect.*

  def owners(s: Symbol): List[Symbol] = s :: List.unfold(s)(s1 => Option.when(s1.maybeOwner != Symbol.noSymbol)((s1.maybeOwner, s1.maybeOwner)))

  val symbol = x.asTerm.underlying.symbol
  val rhs = symbol.tree match {
    case ValDef(_, _, Some(rhs)) => rhs
  }

  val traverser = new TreeTraverser {
    override def traverseTree(tree: Tree)(owner: Symbol): Unit = {
      tree match {
        case Ident(name) =>
          val symbol1 = tree.symbol
          val pos1 = symbol1.pos.get
          println(s"identifier: $name, defined inside lambda: ${owners(symbol1).contains(symbol)}, defined in current file: ${pos1.sourceFile == SourceFile.current}")

        case _ =>
      }

      super.traverseTree(tree)(owner)
    }
  }

  traverser.traverseTree(rhs)(rhs.symbol)

  '{()}
}

用法:
x一个一个一个一个x一个一个二个x

相关问题