如何在scala3中编写内联for循环?

ubbxdtey  于 2023-08-05  发布在  Scala
关注(0)|答案(2)|浏览(95)

我用scala写过一些性能关键的代码,遇到的代码如下:

var index = start - 1
while {index += 1; index < end} do something

字符串
然后我想写一个内联版本,以避免再次使用var indexwhile { .. }

inline def forward(start: Int, end: Int)(inline body: Int => Unit): Unit =
    if start < end then
        body(start)
        forward(start + 1, end)(body)


我假设forward(start, end){ something }将生成与手写while {}几乎相同的代码。
但是我遇到一个有线的compile message让我很困惑:

object InlineFor:
    val n: Int = 1

    inline def forward(start: Int, end: Int)(inline body: Int => Unit): Unit =
        if start < end then
            body(start)
            forward(start + 1, end)(body)

    def main(args: Array[String]): Unit =
        forward(0, 1)(println)  // This is OK
        forward(0, n)(println)  // This is failed


超过了连续内联的最大数量(32),这可能是由递归内联方法引起的?您可以使用-Xmax-inlines来更改限制。
密码有什么错误吗?3x非常多

w6lpcovy

w6lpcovy1#

错误消息说明了一切:你有一个内联递归方法。
这只支持已知的迭代次数(在编译时)和编译器标志-Xmax-inlines定义的最大次数。
如果不知道代码中的n,编译器就无法内联代码。
请注意,您的代码中根本不需要inline,只需使用@tailrec注解您的方法以强制它是尾部递归的,编译器将生成优化的代码

9vw9lbht

9vw9lbht2#

如果将n的类型更改为1,就可以解决这个问题:

object InlineFor:
    val n: 1 = 1

    inline def forward(start: Int, end: Int)(inline body: Int => Unit): Unit =
        if start < end then
            body(start)
            forward(start + 1, end)(body)

    def main(args: Array[String]): Unit =
        forward(0, n)(println)  // ok again

字符串
原因是编译器需要知道startend * 在类型级别 * 上的值,以正确确定if start < end
然而:我不认为内联递归是你正在寻找的。如果调用forward(0, 32)(println),您将再次遇到相同的编译器错误,因为这样就真的达到了递归内联的默认限制(32)。
你应该意识到(大大简化了)每次调用内联函数都会导致调用被函数体替换。因此,结果字节码将包含body字节码的n个示例,而不是循环。(这就是为什么递归内联调用有限制的原因。
如果您使用Intellij IDEA:在“视图”菜单下有一个“显示字节码”选项。我建议您使用它,并观察如果您增加n的值,您的字节码将如何增长。

相关问题