How is tail recursion implemented in Scala, and why can't the tail call be to some other function?


#1

The second part of the question is what I'm interested in, and I guess the first part will give me the answer.
So a tail-recursive function calls itself as the last operation in one of its branches. As a consequence, the memory it was consuming is handed over to its successor. Right?
But if the last operation in a function is a call to some other function, the memory is not freed! Why is that? All the vals used in the calling function are no longer needed, just as in the case where the function calls itself!
Thanks!


#2

Tail recursion can be turned into a loop, and that's what the Scala compiler does. Tail calls in general require stack-frame replacement support from the underlying platform, and the JVM doesn't provide that.
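A minimal sketch of what this looks like from the source side, assuming Scala 2.13 or 3 (the object and method names are made up for illustration). The `@tailrec` annotation from `scala.annotation` doesn't perform the optimization itself; scalac rewrites eligible self-calls either way, and the annotation just makes compilation fail if the method can't be turned into a loop.

```scala
import scala.annotation.tailrec

object TailRecDemo {
  // The self-call is the last thing evaluated in its branch,
  // so scalac compiles it into a jump back to the top of the method.
  @tailrec
  def sum(ints: List[Int], aggregate: Int): Int =
    if (ints.isEmpty) aggregate
    else sum(ints.tail, aggregate + ints.head)

  // A tail call to a *different* method stays an ordinary JVM call.
  // Putting @tailrec on this method would be a compile error,
  // because it contains no self-recursive call to turn into a loop.
  def sumViaHelper(ints: List[Int]): Int = sum(ints, 0)

  def main(args: Array[String]): Unit = {
    val big = List.fill(1000000)(1)
    println(sum(big, 0)) // 1000000, without growing the stack
  }
}
```

Inside a class rather than an object, scalac also requires the method to be `final` or `private`, since an override could otherwise change which method the "self" call actually invokes.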


#3

I'm not an expert in computer science, but the answer is of a theoretical nature.

I guess technically you could always drop a stack frame each time there's a function call in tail position, but that would throw away stack-trace information that is valuable both for debugging errors and exceptions and, possibly, for understanding what happens when you return from a method.

My unconfirmed opinion is that recursion is a special case in this regard: the stack is bounded, but the compiler can't know in advance how deep the recursion will go at runtime, since that usually depends on the inputs and on the recursion logic.

This means recursive calls run the risk of "blowing the stack" much more often than ordinary function calls, which is why it makes sense to optimize this case.
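To make the depth argument concrete, here is a small sketch (names are hypothetical) contrasting a non-tail-recursive sum, whose stack depth grows with the input, with a tail-recursive one that scalac compiles into a loop. The depth at which the first one fails depends on the JVM's thread stack size (`-Xss`), so the overflow is expected rather than guaranteed.

```scala
import scala.annotation.tailrec

object StackDepthDemo {
  // Not tail-recursive: the addition happens *after* the recursive call
  // returns, so every level needs its own stack frame.
  def naiveSum(n: Int): Long =
    if (n == 0) 0L else n + naiveSum(n - 1)

  // Tail-recursive: the recursive call is the last thing evaluated,
  // so the stack depth stays constant regardless of n.
  @tailrec
  def loopSum(n: Int, acc: Long): Long =
    if (n == 0) acc else loopSum(n - 1, acc + n)

  def main(args: Array[String]): Unit = {
    val n = 10000000
    println(loopSum(n, 0L)) // 50000005000000
    // With default JVM stack sizes this call is expected to overflow.
    try println(naiveSum(n))
    catch { case _: StackOverflowError => println("naiveSum overflowed the stack") }
  }
}
```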

You might reason that, in general, it makes sense to optimize every tail call that is "cyclic" even if it isn't directly self-recursive (e.g. two functions calling each other), and in fact there's a technique called "trampolining" that was invented precisely to generalize tail-call optimization to this case.
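The Scala standard library ships a trampoline in `scala.util.control.TailCalls`. Here's a minimal sketch of the classic mutually recursive even/odd example using it (the object name is just for illustration, and `n` is assumed to be non-negative).

```scala
import scala.util.control.TailCalls._

object TrampolineDemo {
  // isEven and isOdd call each other in tail position. Instead of calling
  // each other directly, they return a TailRec value describing the next step.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))

  def main(args: Array[String]): Unit = {
    // .result runs the steps one after another in a loop (the "trampoline"),
    // so the stack does not grow with n.
    println(isEven(1000000).result) // true
    println(isOdd(1000001).result)  // true
  }
}
```

The trade-off is that each step allocates a small object on the heap instead of using a stack frame, so it's slower than a plain loop but no longer limited by stack depth.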


#4

Why settle for believing people when you can believe the compiler?

```
scala> :paste
// Entering paste mode (ctrl-D to finish)

def tailRecursiveSum(ints: List[Int], aggregate: Int): Int =
  if (ints.isEmpty) aggregate
  else {
    val tail = ints.tail
    val sum = aggregate + ints.head
    tailRecursiveSum(tail, sum)
  }

// Exiting paste mode, now interpreting.

tailRecursiveSum: (ints: List[Int], aggregate: Int)Int

scala> :javap tailRecursiveSum
Constant pool:
<elided for brevity>
{

  public int tailRecursiveSum(scala.collection.immutable.List<java.lang.Object>, int);
    descriptor: (Lscala/collection/immutable/List;I)I
    flags: ACC_PUBLIC
    Code:
      stack=2, locals=6, args_size=3
         0: aload_1                           /* push the list from slot 1 (ints) onto the stack */
         1: invokevirtual #27                 // Method scala/collection/immutable/List.isEmpty:()Z
         4: ifeq          11                  /* if isEmpty returned false (0), jump to 11 */
         7: iload_2                           /* otherwise push aggregate from slot 2 onto the stack */
         8: goto          40                  /* ...and jump to the return at 40 */
        11: aload_1                           /* push the list from slot 1 onto the stack */
        12: invokevirtual #31                 // Method scala/collection/immutable/List.tail:()Ljava/lang/Object;
        15: checkcast     #23                 // class scala/collection/immutable/List
        18: astore        4                   /* store the tail in slot 4 (tail) */
        20: iload_2                           /* push aggregate from slot 2 onto the stack */
        21: aload_1                           /* push the list from slot 1 onto the stack */
        22: invokevirtual #34                 // Method scala/collection/immutable/List.head:()Ljava/lang/Object;
        25: invokestatic  #40                 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
        28: iadd                              /* add the head value (unboxed at 25) to the aggregate beneath it (pushed at 20); 22 consumed the list to produce the head */
        29: istore        5                   /* store the sum in slot 5 (sum) */
        31: aload         4                   /* push the tail from slot 4 (stored at 18) onto the stack */
        33: iload         5                   /* push the sum from slot 5 onto the stack */
        35: istore_2                          /* store the sum in slot 2 (aggregate) */
        36: astore_1                          /* store the tail in slot 1 (ints) */
        37: goto          0                   /* jump back to the top: slot 1 now holds tail instead of the first parameter, and slot 2 holds sum instead of the second */
        40: ireturn
      LocalVariableTable:
        Start  Length  Slot  Name   Signature
           18      22     4  tail   Lscala/collection/immutable/List;
           29      11     5   sum   I
            0      41     0  this   L$line4/$read$$iw$$iw$;
            0      41     1  ints   Lscala/collection/immutable/List;
            0      41     2 aggregate   I
      LineNumberTable:
        line 12: 0
        line 14: 11
        line 15: 20
        line 16: 31
      StackMapTable: number_of_entries = 3
        frame_type = 0 /* same */
        frame_type = 10 /* same */
        frame_type = 92 /* same_locals_1_stack_item */
          stack = [ int ]
    Signature: #56                          // (Lscala/collection/immutable/List<Ljava/lang/Object;>;I)I
    MethodParameters:
      Name                           Flags
      ints                           final
      aggregate                      final
```

The /* */ comments were added by me to explain what's going on.

This optimization hinges on being able to overwrite the method's parameters so that its state looks as if it had just been called with the new arguments, and then jumping back to the top, pretending it's a fresh call.
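For anyone who doesn't read bytecode, here's a rough Scala rendering of the same loop. It's a sketch written by hand to mirror the javap output above, not something scalac emits as source; the comments map each line to the corresponding bytecode offsets and local-variable slots.

```scala
object LoweredSum {
  // The parameters live in mutable locals, and the self-call becomes
  // "overwrite the parameters, then jump back to the top".
  def tailRecursiveSum(initialInts: List[Int], initialAggregate: Int): Int = {
    var ints = initialInts             // slot 1
    var aggregate = initialAggregate   // slot 2
    while (!ints.isEmpty) {            // offsets 0-4: call isEmpty, branch
      val tail = ints.tail             // offsets 11-18: store in slot 4
      val sum = aggregate + ints.head  // offsets 20-29: store in slot 5
      aggregate = sum                  // offset 35: istore_2
      ints = tail                      // offset 36: astore_1
    }                                  // offset 37: goto 0
    aggregate                          // offsets 7, 8, 40: load aggregate, return
  }

  def main(args: Array[String]): Unit =
    println(tailRecursiveSum(List(1, 2, 3), 0)) // 6
}
```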

But the JVM's goto can only jump within a single method, so this trick doesn't work when the tail call is to a different method.
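A small sketch of that limitation (names are hypothetical): two methods that call each other in tail position. Since neither call targets the method it appears in, scalac can't rewrite it as a jump, so each step costs a real stack frame. On a default-sized JVM stack the deep call is expected to overflow.

```scala
object CrossMethodDemo {
  // Both calls are in tail position, but each targets a *different* method,
  // so they stay ordinary invocations and the stack grows with n.
  def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
  def isOdd(n: Int): Boolean = if (n == 0) false else isEven(n - 1)

  def main(args: Array[String]): Unit = {
    println(isEven(10)) // true
    try println(isEven(10000000))
    catch { case _: StackOverflowError => println("no tail-call elimination across methods") }
  }
}
```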