# Need help on understanding generic fold for generic data

Hi all, I just started learning Scala and FP through the Essential-Scala ebook, which is open source from underscore.io. I’m currently learning about generics, functions, and the fold method. I came across this exercise:
"A Tree of type A is a Node with left and right Tree or a Leaf with an element of type A. Implement this algebraic data type along with a fold method."

My answer for the fold method was totally wrong: I used type A for everything (i.e. no type B at all), because I thought that if the tree is of type A then the nodes and leaves must be of type A as well.
The code shown below is the correct answer. I had a hard time understanding it (i.e. where type B comes from) and how to decide when to declare additional type parameters (i.e. A, B, C, etc.). Any help would be appreciated, thank you.

```scala
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}

final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    leaf(value)
}
```

The basic idea for a fold is to transform some type (in this case, `Tree[A]`) into another, user-chosen type. We’ll arbitrarily call this user-chosen type `B`. Every subtype in the `Tree` hierarchy needs a way to be transformed into a `B`. These various transformations will be the parameters to our eventual `fold` method.

Now, our tree data structure is recursively defined, so our fold is naturally going to be recursively defined as well.

The base case (`Leaf`) is pretty straightforward. For example (hypothetical):

```scala
trait ExampleLeaf[A] {
  def value: A
  def foldLeaf[B](f: A => B): B = f(value)
}
```

that is, turn the contents of the `Leaf` node into a value of whatever type the user wishes to transform the tree into.

The `Node`s encode the recursive portion of the data structure. They don’t have a value of their own, but they do have two children, each of which can be recursed over. So the basic way a `Node` will be transformed into a `B` is by first transforming each of its children into `B`s, then combining them somehow. We know how to transform the children (that’s the `foldLeaf` example above if they are `Leaf`s; otherwise it’s `foldNode`, recursively). We just need to know how to combine two `B`s into one.

```scala
trait ExampleNode[A] {
  // These could actually be `Leaf` or `Node`, but I'm glossing over that a bit now for simplicity
  def left: ExampleLeaf[A]
  def right: ExampleLeaf[A]

  def foldNode[B](f: (B, B) => B): B = {
    // Combine the results of the two children
    // Oops!  We need the parameter to `foldLeaf` here, right?
    // (`???` is a deliberate placeholder; it throws if this is ever called)
    f(left.foldLeaf(???), right.foldLeaf(???))
  }
}
```

So we’ll “fix” the type signature now by combining the transformation types for each of {`Node`, `Leaf`} into a single `fold` method:

```scala
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B
}

final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    leaf(value)
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}
```
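To see where `B` pays off, here is a small usage sketch. The sample tree and the `FoldExamples` object are made up for illustration and are not part of the exercise; the `Tree` definitions are repeated so the snippet stands on its own. The same `Tree[Int]` is folded three times, each time with a different choice of `B`.

```scala
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B
}

final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    leaf(value)
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}

// Hypothetical examples, not from the book
object FoldExamples {
  val tree: Tree[Int] = Node(Node(Leaf(1), Leaf(2)), Leaf(3))

  // B = Int: each leaf contributes its own value, nodes add the two sides
  val sum: Int = tree.fold[Int](_ + _, identity)

  // B = String: each leaf is rendered as text, nodes wrap the two sides in parentheses
  val rendered: String = tree.fold[String]((l, r) => s"($l $r)", _.toString)

  // B = Int again, but the leaf function ignores the value and counts 1 per leaf
  val leafCount: Int = tree.fold[Int](_ + _, _ => 1)
}
```

Only the `rendered` case actually needs `B` to differ from `A`, but that one case is enough: if everything were typed as `A`, a `Tree[Int]` could only ever be folded into an `Int`.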

Hope this helps explain the derivation a bit.


You can define `fold` as

```scala
def fold(node: (A, A) => A, leaf: A => A): A
```

But when you implement it, you’ll see that only the input parameter of `leaf` depends on type parameter `A`. That’s a clue that you can make the type signature of `fold` more general by introducing a type parameter `B`.
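One way to make this concrete is to write the `A`-only version in terms of the general one. In the sketch below, `foldSame` and `SameTypeDemo` are hypothetical names chosen for illustration; the point is that the restricted signature is just `fold` with `B` fixed to `A`, so the general version loses nothing and gains results of other types.

```scala
sealed trait Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B

  // Hypothetical helper: the A-only fold is the general fold with B = A
  def foldSame(node: (A, A) => A, leaf: A => A): A =
    fold[A](node, leaf)
}

final case class Leaf[A](value: A) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    leaf(value)
}

final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A] {
  def fold[B](node: (B, B) => B, leaf: A => B): B =
    node(left.fold(node, leaf), right.fold(node, leaf))
}

// Hypothetical demo values
object SameTypeDemo {
  val words: Tree[String] = Node(Leaf("to"), Node(Leaf("be"), Leaf("or")))

  // Works with the A-only version: String in, String out
  val joined: String = words.foldSame(_ + _, identity)

  // Not expressible with the A-only version: the result is an Int, not a String
  val totalLength: Int = words.fold[Int](_ + _, _.length)
}
```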
