Is this tail recursion?

I am trying to understand tail recursion. Is this an example of a tail recursive function?

  def fib(n: Int): Int = {
    var count = 2
    @annotation.tailrec
    def go(a: Int, b: Int): Int = {
      count += 1
      val c = a + b
      if (n==0) 0
      if (n==1) 1
      if (count==n) c
      else go(b, c)
    }
    go(0, 1)
  }

go is tail-recursive.

It is not a pure function, though: it has the side effect of modifying the count variable. To handle this kind of state in pure functions, you usually pass it along as a parameter. It would end up looking like this (I also corrected the n == 0 / n == 1 cases to use else if):

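def fib(n: Int): Int = {
  @annotation.tailrec
  def go(a: Int, b: Int, count: Int): Int = {
    val nextCount = count + 1 // the counter is now a parameter instead of a var
    val c = a + b
    if (n == 0) 0
    else if (n == 1) 1
    else if (nextCount == n) c
    else go(b, c, nextCount)
  }
  go(0, 1, 2) // same starting count as the var version
}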

The function go is still recursive with the var, as it calls itself. It just isn’t a pure function (in the functional programming sense). But of course recursion is often used instead of loops to avoid mutable state.

Also, go in fib is a tail-recursive function. This is actually ensured by the compiler, because you put @annotation.tailrec above it: if it weren't tail-recursive, compilation would fail with an error. To check for yourself whether a function is tail-recursive, look at every position where the function calls itself. If every such call is the last thing the function does, i.e. the value returned by the recursive call is the value returned by the outer go call, the function is tail-recursive. If you do further operations on the value after the call, it is no longer in tail position and cannot be optimized. For example, if you replaced the call in your code with go(b, c) + 1, the + 1 would have to happen after the go call returned, so the recursive call would no longer be the last operation.
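A minimal illustration of the same point, using a hypothetical sum-of-1-to-n function rather than fib (the names TailPosition, sumTo and sumToAcc are made up for this sketch):

```scala
object TailPosition {
  // Not tail-recursive: after sumTo(n - 1) returns, `n + _` still has to run,
  // so each call's stack frame must stay alive until the inner call finishes.
  def sumTo(n: Long): Long =
    if (n == 0) 0 else n + sumTo(n - 1)

  // Tail-recursive: the recursive call's result is returned unchanged,
  // so @tailrec accepts it and the compiler turns it into a loop.
  @annotation.tailrec
  def sumToAcc(n: Long, acc: Long): Long =
    if (n == 0) acc else sumToAcc(n - 1, acc + n)
}
```

Putting @annotation.tailrec on sumTo would be rejected at compile time, because its recursive call is not in tail position.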


A recursive function is a function that calls itself. His example still does that, and hence is still recursive. The use of a var to count the number of invocations is a detail that doesn’t change that fact. Though I have to admit that I’m really unsure what is going on with this function. I feel like there should be two additional occurrences of else.

It also happens to be tail-recursive because the only recursive call is in the “tail” position: it is the last thing the function does on every branch where it occurs.

If you really want to understand tail recursion, it might help to take functions that aren’t tail recursive, understand why they aren’t, and then change them to be tail recursive. For example, start with a basic List length.

def length(lst: List[Int]): Int = lst match {
  case Nil => 0
  case h :: t => 1 + length(t)
}

This is not tail recursive in Scala because after the call to length produces a result, there is an addition operation that needs to occur. To make this tail recursive you need to pass in an extra argument that accumulates the result so that the addition happens on the way down the call stack and not on the way back up.

While I wrote that to be a basic example, I realized how many variations there are to that code. I really wanted to make it generic, because the type held in the list doesn’t matter. I will provide a version using if just in case the reader isn’t familiar with pattern matching on lists.

def length(lst: List[Int]): Int = if(lst.isEmpty) 0 else 1+length(lst.tail)
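The generic version mentioned above might look like this, as a sketch (the wrapper object name GenericLength is hypothetical). It is still not tail-recursive; only the element type changed:

```scala
object GenericLength {
  // The element type A is irrelevant to the length, so a type parameter
  // lets this work for any list. The `1 + ...` still runs after the
  // recursive call returns, so this is NOT tail-recursive.
  def length[A](lst: List[A]): Int = lst match {
    case Nil    => 0
    case _ :: t => 1 + length(t)
  }
}
```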

Whichever version you prefer, the exercise for the learner is to make it tail recursive by including an extra argument of the accumulated value and moving the addition into the calculation of that argument.
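One possible solution to that exercise, as a sketch (the object name TailLength and the parameter name acc are my own choices): the accumulator carries the count so far, so the addition happens before the recursive call rather than after it returns.

```scala
object TailLength {
  // `acc` holds the number of elements already seen. The recursive call is
  // now the last thing evaluated, so @tailrec compiles it into a loop.
  @annotation.tailrec
  def length[A](lst: List[A], acc: Int = 0): Int = lst match {
    case Nil    => acc
    case _ :: t => length(t, acc + 1)
  }
}
```

Because the stack no longer grows with the list, this handles very long lists that would overflow the stack with the non-tail version.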

Unrelated to your question, the statements if (n==0) 0 and if (n==1) 1 do nothing and can be removed. Maybe you meant to put else in front of some of your ifs?
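A small sketch of why those lines do nothing (the names DiscardedIf, withoutElse and withElse are hypothetical): an if without else in the middle of a block is just a statement whose value is thrown away, so execution continues to the next line regardless. The compiler typically warns about a pure expression in statement position here.

```scala
object DiscardedIf {
  // Without `else`, `if (n == 0) 0` computes 0 and then discards it;
  // the block's result is always the last expression, 42.
  def withoutElse(n: Int): Int = {
    if (n == 0) 0
    42
  }

  // With `else`, the whole if/else chain is a single expression and
  // becomes the method's result.
  def withElse(n: Int): Int =
    if (n == 0) 0
    else 42
}
```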

Thanks for pointing out that those two ifs should be else ifs. I corrected my code but forgot to update this post. :)

I agree that this function would seem to meet the definition of tail recursive (at least when the missing else clauses or explicit returns are added).

I am not sure how tail call optimization works in the Scala compiler, but I would guess that it might not be able to reuse the stack frame because of the local mutable variable count. So there would seem to be no practical advantage to the tail recursion.

Why? I don’t think the presence or absence of the var has anything to do with it. I personally wouldn’t recommend writing a recursive function that uses a var like this (I generally recommend avoiding vars when possible), but it’s still tail-recursive, and looks like it ought to be correctly optimized…

Tail-recursion elimination is implemented by rewriting the method into a while loop. The @tailrec annotation makes compilation fail if the tail call doesn’t get eliminated: that’s the only thing the annotation does, and the reason you use it.
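To make the "rewritten into a loop" idea concrete, here is a hand-written sketch of a loop equivalent to the pure go(a, b, count) version from earlier in the thread. It is not literally the compiler's output (that is a jump in bytecode), just Scala with the same behavior. The object name FibLoop is hypothetical. Like the recursive versions in this thread, it follows their indexing (fib(10) is 34) and assumes n is 0, 1, or at least 3.

```scala
object FibLoop {
  // The parameters of go become mutable locals, and the self-call
  // go(b, c, count) becomes "update the locals and loop again".
  def fib(n: Int): Int =
    if (n == 0) 0
    else if (n == 1) 1
    else {
      var a = 0
      var b = 1
      var count = 2
      var result = 0
      var done = false
      while (!done) {
        count += 1
        val c = a + b
        if (count == n) {
          result = c
          done = true
        } else {
          // Formerly go(b, c, count): rebind the "parameters" instead of calling.
          a = b
          b = c
        }
      }
      result
    }
}
```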

If you want to check what the compiler does, look at the decompiled bytecode; :javap in the REPL is a quick and easy way to do that.


Okay, I should have checked this specific example. I am just getting back to Scala after several years away. :)

I have learned to be careful about making assertions about tail call optimization in any language. Sometimes there are easily overlooked issues that prevent the optimization. For example, in some languages the exception mechanism may prevent it because the traceback needs to unwind the stack. In this case, I was concerned that the mutable variable from the outer scope, incremented in the inner scope of the tail-recursive “function”, would interfere with the optimization.

When I was programming in (and teaching) Scala several years ago, I adopted the practice of making all variables “val” and only backing out to “var” if there was a good reason (such as hand-optimizing a recursion into a loop, or writing an explicitly imperative mutator for a class). In general, I try to make all functions pure unless I explicitly choose to design them otherwise.

Since the call to go is the last computation, go is tail-recursive. That is the only criterion for a function to be tail-recursive.

Here is another version of Fibonacci:

def fibonacci(n: Int): Int = {
  @annotation.tailrec
  def go(counter: Int, previous: Int, current: Int): Int =
    if (counter == n) previous + current
    else go(counter + 1, current, previous + current)

  if (n == 1) 0
  else if (n == 2) 1
  else go(3, 0, 1)
}


No worries. It’s a common misunderstanding that @tailrec is what performs tail-recursion elimination. It doesn’t. The compiler eliminates tail recursion regardless of whether the annotation is present, and the annotation only checks that the elimination happened, just like @switch checks whether a match is translated to a switch instruction.
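A small sketch of that point (the names NoAnnotation, countDown and loop are hypothetical): there is no @tailrec here, yet the self-call in tail position is still turned into a loop. If it were not, a million nested calls would normally overflow the default JVM thread stack.

```scala
object NoAnnotation {
  def countDown(start: Int): Int = {
    // No @tailrec annotation: the compiler still eliminates this self-call,
    // because it is in tail position and a local method cannot be overridden.
    def loop(n: Int, steps: Int): Int =
      if (n == 0) steps else loop(n - 1, steps + 1)
    loop(start, 0)
  }
}
```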

I understood @tailrec correctly; I was just unsure whether the compiler would be able to do the TCO in that situation (and had not checked).

That’s the point though: if it wouldn’t be able to, the snippet wouldn’t compile.

This function is tail-recursive. How can we prove it?
As other answers have pointed out, if it were not tail-recursive, the @tailrec annotation would cause a compilation error.

Another way to show that this function is tail-recursive is to use the getStackTrace() method of the Thread class, as shown below.

package tailRecursion

object Fibonaaci {
  def fib(n: Int): Int = {
    var count = 2
    def go(a: Int, b: Int): Int = {
      count += 1
      val c = a + b
      if (n==0) {
        0
      }
      else if (n==1) {
        1
      }
      else if (count==n) {
        val x = Thread.currentThread.getStackTrace
        x.foreach{println _}
        c
      }
      else go(b, c)
    }
    go(0, 1)
  }

  def main(args: Array[String]): Unit = {
    val result = fib(10)
    println(result)
  }
}

The output of the above code is given below.

java.lang.Thread.getStackTrace(Thread.java:1552)
tailRecursion.Fibonaaci$.go$1(Fibonaaci.scala:17)
tailRecursion.Fibonaaci$.fib(Fibonaaci.scala:23)
tailRecursion.Fibonaaci$.main(Fibonaaci.scala:27)
tailRecursion.Fibonaaci.main(Fibonaaci.scala)
34

At the point where we call getStackTrace on the current thread (line 17), only one stack frame is used by the go method. That is why we see only one “go” in the above output.

Let us consider another example, one that is not tail-recursive: a non-tail-recursive factorial method.

package tailRecursion

object FactorialNoTailRec {

  def factorial(n:Int) : Int = if (n == 0) {
    val x = Thread.currentThread.getStackTrace
    x.foreach{println _}
    1
  } else {
    n * factorial(n-1)
  }

  def main(args: Array[String]): Unit = {
    println("factorial without tail rec started")
    val result = factorial(10)
    println(result)
  }
}

The output of the above program is given below.

factorial without tail rec started
java.lang.Thread.getStackTrace(Thread.java:1552)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:6)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.factorial(FactorialNoTailRec.scala:10)
tailRecursion.FactorialNoTailRec$.main(FactorialNoTailRec.scala:15)
tailRecursion.FactorialNoTailRec.main(FactorialNoTailRec.scala)
3628800

From the above output it is clear that 11 stack frames are created for the factorial method, because it is not tail-recursive.

Now let us look at the tail-recursive version of the factorial method.

package tailRecursion

object FactorialTailRec {

  def factorial(n: Int, acc: Int): Int = if (n == 0) {
    val x = Thread.currentThread.getStackTrace
    x.foreach{println _}
    acc // accumulated product n * (n-1) * ... * 1
  } else {
    factorial(n - 1, acc * n) // recursive call is in tail position
  }

  def main(args: Array[String]): Unit = {
    println("factorial without tail rec started")
    val result = factorial(10, 1)
    println(result)
  }
}

The output of the above tail-recursive version of factorial is given below.

factorial without tail rec started
java.lang.Thread.getStackTrace(Thread.java:1552)
tailRecursion.FactorialTailRec$.factorial(FactorialTailRec.scala:6)
tailRecursion.FactorialTailRec$.main(FactorialTailRec.scala:15)
tailRecursion.FactorialTailRec.main(FactorialTailRec.scala)
3628800

The tail-recursive version produces the same result, 3628800. However, it uses only one stack frame for the factorial method instead of 11.

I looked for instrumentation techniques that analyze the Java stack memory used by a thread, to prove this more directly.

But the best tool I found for this case is the getStackTrace() method, which shows whether the recursive calls were eliminated.
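A self-checking variant of the same idea, as a sketch: instead of printing the whole trace, count the frames whose method name matches. The names StackDepthCheck, framesNamed and factorialAcc are made up for illustration; the only real APIs used are Thread.currentThread.getStackTrace and StackTraceElement.getMethodName. It assumes ordinary methods in an object keep their source names in the stack trace, which is the case in the traces printed above.

```scala
object StackDepthCheck {
  // Number of frames on the current thread's stack whose method name equals `name`.
  def framesNamed(name: String): Int =
    Thread.currentThread.getStackTrace.count(_.getMethodName == name)

  // Not tail-recursive: the multiplication runs after the recursive call returns.
  // Returns (n!, number of `factorial` frames seen at the base case).
  def factorial(n: Int): (Int, Int) =
    if (n == 0) (1, framesNamed("factorial"))
    else {
      val (rest, frames) = factorial(n - 1)
      (n * rest, frames)
    }

  // Tail-recursive: @tailrec guarantees the self-call becomes a loop,
  // so only one `factorialAcc` frame should ever be on the stack.
  @annotation.tailrec
  def factorialAcc(n: Int, acc: Int): (Int, Int) =
    if (n == 0) (acc, framesNamed("factorialAcc"))
    else factorialAcc(n - 1, acc * n)
}
```

With n = 10, the non-tail version should report 11 frames and the tail-recursive one a single frame, matching the printed traces.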
