Non-local returns in Scala 3.2

I have a complete example of a loop abstraction that erases to labelled jumps:

import scala.util.boundary, boundary.break

object Loops:

  type ExitToken = Unit { type Exit }
  type ContinueToken = Unit { type Continue }

  type Exit = boundary.Label[ExitToken]
  type Continue = boundary.Label[ContinueToken]

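  object loop:
    // The outer boundary is the target of exit(); the inner one is re-entered
    // on every iteration and is the target of continue().
    inline def apply(inline op: (Exit, Continue) ?=> Unit): Unit =
      boundary[ExitToken]:
        while true do
          boundary[ContinueToken]:
            op
    // Both helpers break with a Unit value cast to the matching token type.
    inline def exit()(using Exit) = break(().asInstanceOf[ExitToken])
    inline def continue()(using Continue) = break(().asInstanceOf[ContinueToken])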

Note that I am using refinements on the Unit type to generate two distinct label types; relying on the fact that both erase to Unit is possibly unstable.
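
To illustrate why the two labels need distinct types, here is a minimal sketch that uses plain boundary[Unit] labels instead of the refined tokens. With both labels typed Label[Unit], a bare break() would (as far as I understand Scala 3's given resolution) pick the innermost one, so each label has to be named and passed explicitly. The helper oddsBelow and the label names exit and continue are hypothetical, chosen only for this illustration.

```scala
import scala.util.boundary, boundary.break

// Hypothetical helper: collects the odd numbers strictly below `limit`.
def oddsBelow(limit: Int): List[Int] =
  val out = List.newBuilder[Int]
  var i = 0
  boundary[Unit] { exit ?=>           // outer label, named explicitly
    while true do
      boundary[Unit] { continue ?=>   // inner label, same type as the outer one
        i += 1
        if i == limit then break()(using exit)       // leave the whole loop
        if i % 2 == 0 then break()(using continue)   // skip to the next iteration
        out += i
      }
  }
  out.result()
```

Giving the labels distinct types is what lets loop.exit() and loop.continue() find the right label through ordinary context resolution, without naming it at the call site.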

Here is a usage example that prints the odd numbers up to (but not including) a limit:

import Loops.*
@main def oddsUpToLimit(limit: Int) =
  var i = 0
  loop:
    i += 1
    if i == limit then loop.exit()
    if i % 2 == 0 then loop.continue()
    println(i)

CFR decompiler output:

public void oddsUpToLimit(int limit) {
    int i = 0;
    while (true && ++i != limit) {
        if (i % 2 == 0) continue;
        Predef$.MODULE$.println((Object)BoxesRunTime.boxToInteger((int)i));
    }
}

The same abstraction can encode the loops of tail-recursive functions. For example, factorial1 and factorial2 are basically equivalent in the following code:

//> using scala "3.3.0-RC2"

import scala.annotation.tailrec

import scala.util.boundary

@tailrec
def factorial1(n: Long, acc: Long): Long =
  if n == 0L then acc
  else factorial1(n - 1L, acc * n)

def factorial2(n0: Long, acc0: Long): Long = boundary { result ?=>
  var n = n0
  var acc = acc0
  while true do
    boundary[Unit] { tailcall ?=>
      boundary.break {
        if n == 0L then acc
        else
          acc = acc * n
          n = n - 1L
          boundary.break()(using tailcall)
      } (using result)
    }
  throw AssertionError("unreachable")
}

@main def Factorial(): Unit = {
  println(factorial1(5L, 1L))
  println(factorial2(5L, 1L))
}

You can compare the bytecode for both:

  public long factorial1(long, long);
    Code:
       0: lload_1
       1: lconst_0
       2: lcmp
       3: ifne          8
       6: lload_3
       7: lreturn
       8: lload_1
       9: lconst_1
      10: lsub
      11: lstore        5
      13: lload_3
      14: lload_1
      15: lmul
      16: lstore        7
      18: lload         5
      20: lstore_1
      21: lload         7
      23: lstore_3
      24: goto          0

  public long factorial2(long, long);
    Code:
       0: lload_1
       1: lstore        5
       3: lload_3
       4: lstore        7
       6: iconst_1
       7: ifeq          36
      10: lload         5
      12: lconst_0
      13: lcmp
      14: ifne          20
      17: lload         7
      19: lreturn
      20: lload         7
      22: lload         5
      24: lmul
      25: lstore        7
      27: lload         5
      29: lconst_1
      30: lsub
      31: lstore        5
      33: goto          6
      36: new           #35                 // class java/lang/AssertionError
      39: dup
      40: ldc           #37                 // String unreachable
      42: invokespecial #40                 // Method java/lang/AssertionError."<init>":(Ljava/lang/Object;)V
      45: athrow

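The locals are laid out differently (factorial1 goes through temporaries, factorial2 copies its parameters up front), and factorial2 keeps the while true test and the unreachable throw, but otherwise the loop is the same.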
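
For comparison, both listings amount to the loop one would write by hand. A minimal sketch, where factorial3 is a hypothetical name and I am not claiming it produces byte-for-byte identical output:

```scala
// Hand-written loop with the same shape as the two bytecode listings:
// test n against zero, otherwise update acc and n and jump back.
def factorial3(n0: Long, acc0: Long): Long =
  var n = n0
  var acc = acc0
  while n != 0L do
    acc = acc * n
    n = n - 1L
  acc
```

Called as factorial3(5L, 1L), it computes the same product as factorial1 and factorial2.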

This is no coincidence: the compiler uses the same underlying abstraction (a Labeled block, with Returns that target that Labeled) to compile both optimized boundary/break and @tailrec defs!
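
The word "optimized" matters here. As far as I understand the scala.util.boundary documentation, the jump-based compilation only applies when the break sits directly in the same method as its boundary; when the break is inside a closure, it falls back to throwing an exception that the enclosing boundary catches. A minimal sketch of the two situations, with hypothetical function names:

```scala
import scala.util.boundary, boundary.break

// break sits directly in the method body, next to its boundary:
// this is the shape that can be compiled to a plain jump.
def firstNegativeLocal(xs: Array[Int]): Option[Int] =
  boundary[Option[Int]]:
    var i = 0
    while i < xs.length do
      if xs(i) < 0 then break(Some(xs(i)))
      i += 1
    None

// break sits inside the closure passed to foreach: the result is the
// same, but the exit has to cross a method boundary (assumption: this
// goes through the exception-based path rather than a jump).
def firstNegativeClosure(xs: List[Int]): Option[Int] =
  boundary[Option[Int]]:
    xs.foreach(x => if x < 0 then break(Some(x)))
    None
```

Both functions return the first negative element, if any; the difference is only in how the early exit is compiled.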