Need help understanding a generic fold for generic data

Hi all, I just started learning Scala and FP through the Essential Scala ebook, which underscore.io has made open source. I'm currently learning about generics, functions, and the fold method. I stumbled upon an exercise that asks:
"A Tree of type A is a Node with left and right Tree or a Leaf with an element of type A. Implement this algebraic data type along with a fold method."

My answer for the fold method was totally wrong: I used type A for everything (i.e. no type B at all), because I thought that if the tree is of type A, then the nodes and leaves would be of type A as well.
The code below is the correct answer. I had a hard time understanding it (i.e. where type B comes from) and how to decide when to declare a different type parameter (A, B, C, etc.). Any help would be appreciated, thank you.

```scala
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}

final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    leaf(value)
}
```

The basic idea for a fold is to transform some type (in this case, Tree[A]) into another, user-chosen type. We’ll arbitrarily call this user-chosen type B. Every subtype in the Tree hierarchy needs a way to be transformed into a B. These various transformations will be the parameters to our eventual fold method.
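To make "user-chosen type" concrete, here is a small sketch that folds the same `Tree[String]` into three different `B`s just by passing different functions. The `Tree`/`Node`/`Leaf` code is the answer from the question, repeated so the block compiles on its own; the sample tree and the `FoldExamples` object are made up for illustration.

```scala
// The Tree ADT and fold from the question, repeated so this block is self-contained.
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}

final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    leaf(value)
}

object FoldExamples {
  // A made-up Tree[String] with three leaves: ("a", ("bb", "ccc"))
  val tree: Tree[String] = Node(Leaf("a"), Node(Leaf("bb"), Leaf("ccc")))

  // B = Int: count the leaves. Each leaf counts as 1; a node adds its children's counts.
  val leafCount: Int = tree.fold[Int]((l, r) => l + r, _ => 1)

  // B = Int again: total length of all the strings stored in the leaves.
  val totalLength: Int = tree.fold[Int]((l, r) => l + r, s => s.length)

  // B = List[String]: collect the leaf values from left to right.
  val values: List[String] = tree.fold[List[String]]((l, r) => l ++ r, s => List(s))
}
```

The tree is a `Tree[String]` in every case; only `B` changes. That is why `fold` needs its own type parameter rather than reusing `A`.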

Our tree data structure is recursively defined, so our fold will naturally be recursively defined as well.

The base case (Leaf) is pretty straightforward. For example (hypothetical):

```scala
trait ExampleLeaf[A] {
  def value: A
  def foldLeaf[B](f: A => B): B = f(value)
}
```

That is, turn the contents of the Leaf into a value of whatever type the user wishes to transform the tree into.
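As a quick illustration of the base case on its own, the sketch below instantiates the hypothetical `ExampleLeaf` from this reply with an anonymous class. The leaf value `"hello"` and the `LeafExamples` object are invented for the example.

```scala
// The hypothetical ExampleLeaf from the reply, repeated so this block compiles on its own.
trait ExampleLeaf[A] {
  def value: A
  def foldLeaf[B](f: A => B): B = f(value)
}

object LeafExamples {
  // A made-up leaf holding a String (so A = String).
  val leaf: ExampleLeaf[String] = new ExampleLeaf[String] { val value = "hello" }

  // The caller picks B by picking f: here B = Int ...
  val asLength: Int = leaf.foldLeaf(s => s.length)

  // ... and here B = Boolean, from the very same leaf.
  val isEmpty: Boolean = leaf.foldLeaf(s => s.isEmpty)
}
```

`A` is fixed by the leaf itself, while `B` is decided at each call site by the function passed in.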

The Nodes encode the recursive portion of the data structure. They don't have a value of their own, but they do have two children, each of which can be recursed over. So the basic way a Node is transformed into a B is by first transforming each of its children into a B, then combining the two somehow. We already know how to transform the children: if they are Leafs, that's the foldLeaf example above; otherwise it's foldNode, recursively. We just need to know how to combine two Bs into one.

```scala
trait ExampleNode[A] {
  // These could actually be `Leaf` or `Node`, but I'm glossing over that a bit for simplicity
  def left: ExampleLeaf[A]
  def right: ExampleLeaf[A]

  def foldNode[B](f: (B, B) => B): B = {
    // Combine the results of the two children
    // Oops!  We need the parameter to `foldLeaf` here, right?
    f(left.foldLeaf(???), right.foldLeaf(???))
  }
}
```
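One way to see why both functions end up in one signature: a single recursive function that walks the whole tree has to carry the combining function *and* the leaf function down to every level, because any child might turn out to be a Leaf. The sketch below writes that as a standalone function using pattern matching. This is a swapped-in style (one function matching on the cases, instead of a `fold` method on each case class as in the reply), and the names `foldTree` and `PatternFold` are hypothetical.

```scala
// A plain ADT with no fold methods; the fold lives in a separate function instead.
sealed trait Tree[A]
final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A]
final case class Leaf[A](value: A) extends Tree[A]

object PatternFold {
  // Hypothetical standalone fold using pattern matching.
  def foldTree[A, B](tree: Tree[A])(node: (B, B) => B, leaf: A => B): B =
    tree match {
      // Base case: turn the leaf's value straight into a B.
      case Leaf(value) => leaf(value)
      // Recursive case: fold each child into a B first (passing BOTH functions along,
      // since either child may be a Leaf), then combine the two Bs with `node`.
      case Node(left, right) =>
        node(foldTree(left)(node, leaf), foldTree(right)(node, leaf))
    }

  // A made-up tree: ((1, 2), 3)
  val tree: Tree[Int] = Node(Node(Leaf(1), Leaf(2)), Leaf(3))

  // Depth: a leaf has depth 1; a node is one deeper than its deeper child.
  val depth: Int = foldTree[Int, Int](tree)((l, r) => 1 + math.max(l, r), _ => 1)

  // Sum of the leaf values.
  val sum: Int = foldTree[Int, Int](tree)(_ + _, identity)
}
```

The recursive case never looks inside a child; it only asks for the child's `B`, which is exactly the "transform the children, then combine" step described above.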

So we'll "fix" the type signature now by combining the transformation functions for each of {Node, Leaf} into a single fold method:

```scala
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B
}

final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    leaf(value)
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}
```
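Because `B` can be anything, it can even be another tree. The sketch below builds a map-like operation on top of this fold by choosing `B = Tree[C]`: each step returns a rebuilt subtree rather than a single value. The helper name `mapTree` and the `MapViaFold` object are hypothetical; Essential Scala may define map differently.

```scala
// The final Tree ADT and fold from above, repeated so this block is self-contained.
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B
}
final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B = leaf(value)
}
final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}

object MapViaFold {
  // Hypothetical helper: keep the tree's shape, transform each leaf value with f.
  // The fold's B is Tree[C] here, so each step produces a subtree.
  def mapTree[A, C](tree: Tree[A])(f: A => C): Tree[C] =
    tree.fold[Tree[C]](
      (l, r) => Node(l, r), // combine two already-mapped subtrees
      a => Leaf(f(a))       // transform the value inside a leaf
    )

  val ints: Tree[Int] = Node(Leaf(1), Node(Leaf(2), Leaf(3)))

  // Turn each Int n into a String of n '#' characters.
  val strings: Tree[String] = mapTree(ints)(n => "#" * n)
}
```

Here `A = Int`, `B = Tree[String]`, and the mapping function introduces a third type `C = String`, which is one answer to "when do I need A, B, C": whenever a value of a genuinely different type flows through the signature.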

Hope this helps explain the derivation a bit.


You can define fold as

```scala
def fold(node: (A, A) => A, leaf: A => A): A
```

But when you implement it, you’ll see that only the input parameter of leaf depends on type parameter A. That’s a clue that you can make the type signature of fold more general by introducing a type parameter B.
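To see that clue in action, here is a sketch with both signatures side by side on the same tree. `foldSame` is a hypothetical name for the A-only version. Its body is identical to the generic `fold`, which is the point: nothing in the implementation relies on the result being an `A`, so restricting it only limits what callers can compute.

```scala
sealed trait Tree[A] {
  // The A-only signature: the result must have the same type as the leaf values.
  def foldSame(node: (A, A) => A, leaf: A => A): A
  // The generalized signature: the result type B is chosen by the caller.
  def fold[B](node: (B, B) => B, leaf: A => B): B
}

final case class Leaf[A](value: A) extends Tree[A] {
  def foldSame(node: (A, A) => A, leaf: A => A): A = leaf(value)
  def fold[B](node: (B, B) => B, leaf: A => B): B = leaf(value)
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  // Same body in both methods; only the types in the signature differ.
  def foldSame(node: (A, A) => A, leaf: A => A): A =
    node(left.foldSame(node, leaf), right.foldSame(node, leaf))
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}

object SignatureComparison {
  val words: Tree[String] = Node(Leaf("fold"), Leaf("me"))

  // Fine with foldSame: the answer is a String, the same type as the leaves.
  val joined: String = words.foldSame((l, r) => l + " " + r, identity)

  // Counting leaves needs an Int result, which foldSame cannot produce for a Tree[String]:
  // words.foldSame((l, r) => l + r, _ => 1)  // does not compile: 1 is not a String
  val count: Int = words.fold[Int](_ + _, _ => 1)
}
```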
