Using flatMap to chain conditional operations

If you use IntelliJ IDEA, there’s a “desugar for comprehension” action that shows the map/flatMap equivalent of a for expression. It can help you figure out what’s really going on.


To me, for comprehensions are black magic

Here is an in-depth explanation of the de-sugaring of for comprehensions by Odersky et al.:
https://www.artima.com/pins1ed/for-expressions-revisited.html#23.4
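As a rough sketch of what that de-sugaring produces (the names xs, ys, viaFor and viaCalls are made up for illustration), a for expression with two generators and a yield becomes a flatMap over the first generator and a map over the last one:

```scala
val xs = List(1, 2)
val ys = List("a", "b")

// A for expression with two generators and a yield...
val viaFor = for {
  x <- xs
  y <- ys
} yield s"$x$y"

// ...is rewritten by the compiler into roughly this:
// every generator except the last becomes flatMap, the last becomes map.
val viaCalls = xs.flatMap(x => ys.map(y => s"$x$y"))

println(viaFor)   // List(1a, 1b, 2a, 2b)
println(viaCalls) // List(1a, 1b, 2a, 2b)
```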


Thank you, I will look into that!

I believe my problem with understanding for comprehensions is that I don’t get the naming. I understand neither why flatMap is called flatMap, nor why for comprehensions are called that.

I mean, I get the map part, but what is flat? Flat in what sense?

And regarding the other: what does for have to do with all that? I understand the for as in for loop and the for in forEach, but I don’t get what either of those has to do with for comprehensions.

flatMap is called that way because it first maps and then flattens:

scala> assert( Some(Some(3)).flatten == Some(3) )

scala> assert( "3".toIntOption == Some(3) )

scala> assert( Some("3").map(_.toIntOption) == Some(Some(3)) )

scala> assert( Some("3").map(_.toIntOption).flatten == Some(3) )

scala> assert( Some("3").flatMap(_.toIntOption) == Some(3) )

I mean, I get the map part, but what is flat? Flat in what sense?

Let’s start with map.
The map function has the following signature:

def map[F[_], A, B](fa: F[A])(f: A => B): F[B]

Which we can read like:
For a given effect F (also called a context or a container), if we have an effectual value fa of type F[A] (read as F of A) and a function f from A to B (A => B), map returns a new effectual value of type F[B].

So, the idea of the map function is to apply a normal (also called plain) function to an effectual value.
In other words, map lets us forget about managing the effect and just focus on the values; it takes care of the unwrapping and rewrapping that has to be done.

So a typical example would be the following.
We will use one of the most basic effects, Option (which represents the possible absence of a value), to model a safe division.

def safeDivision(x: Int, y: Int): Option[Int] =
  if (y != 0) Some(x / y) else None

Now, if we have to compute the following arithmetic expression: y = a + b / c
We could do something like this:

def foo(a: Int, b: Int, c: Int): Option[Int] =
  safeDivision(b, c) match {
    case Some(temp) => Some(temp + a)
    case None => None
  }

However, look at all that boilerplate; we have to manually unwrap and wrap again all the time.
A more complex expression would be a nightmare and we could easily make mistakes.
Enter map:

def map[A, B](oa: Option[A])(f: A => B): Option[B] = oa match {
  case Some(a) => Some(f(a))
  case None => None
}

def foo(a: Int, b: Int, c: Int): Option[Int] =
  map(safeDivision(b, c))(temp => temp + a)

Great!
Now, what happens if we want to compute this new expression: y = (a / b) / c
Well, we could try to do the same as before…

def bar(a: Int, b: Int, c: Int): Option[Option[Int]] =
  map(safeDivision(a, b))(temp => safeDivision(temp, c))

We may say it works… but having to handle that nesting is not really nice.
And again a complex expression would result in a very nested structure.
Also, we actually do not care if the first division failed or the second, we just care that the complete expression failed.
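To see what information that nesting carries, here is a self-contained sketch that repeats safeDivision and the map-based bar from above (renamed barNested here, a name chosen only for this example): a failure in the first division gives None, while a failure in the second gives Some(None).

```scala
def safeDivision(x: Int, y: Int): Option[Int] =
  if (y != 0) Some(x / y) else None

// The hand-written map for Option, as defined above.
def map[A, B](oa: Option[A])(f: A => B): Option[B] = oa match {
  case Some(a) => Some(f(a))
  case None    => None
}

// y = (a / b) / c using only map: the result ends up doubly wrapped.
def barNested(a: Int, b: Int, c: Int): Option[Option[Int]] =
  map(safeDivision(a, b))(temp => safeDivision(temp, c))

println(barNested(8, 2, 2)) // Some(Some(2))
println(barNested(8, 0, 2)) // None: the first division failed
println(barNested(8, 2, 0)) // Some(None): the second division failed
```

In most programs both failures mean the same thing (the whole expression failed), which is why collapsing the two layers into one is useful.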

So, it is time to meet another helper function, flatten.

def flatten[F[_], A](ffa: F[F[A]]): F[A]

Which again, reads as follows:
Given a nested effectual value, it returns a non-nested effectual value. We call this process flattening: we flatten the two Fs into one F.

Applying that to our previous function we get the following:

def flatten[A](ooa: Option[Option[A]]): Option[A] = ooa match {
  case Some(Some(a)) => Some(a)
  case Some(None) => None
  case None => None
}

def bar(a: Int, b: Int, c: Int): Option[Int] =
  flatten(map(safeDivision(a, b))(temp => safeDivision(temp, c)))

But, as you may have already guessed, this process of mapping and then flattening is so common that we may create a helper, flatMap:

def flatMap[F[_], A, B](fa: F[A])(f: A => F[B]): F[B] =
  flatten(map(fa)(f))

Which we can read as:
Given an effectual value and a function that returns a new effectual value, map the function over the value and then flatten the result.
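The generic flatMap above needs a generic map and flatten in scope to compile. As a runnable sketch specialised to Option (an assumption made only to keep it self-contained), flatMap can be built literally as map followed by flatten and compared with the standard library’s Option.flatMap; the function parse is a hypothetical helper reusing toIntOption from earlier in the thread:

```scala
// Option-only versions of the helpers from above.
def map[A, B](oa: Option[A])(f: A => B): Option[B] = oa match {
  case Some(a) => Some(f(a))
  case None    => None
}

def flatten[A](ooa: Option[Option[A]]): Option[A] = ooa match {
  case Some(inner) => inner // Some(Some(a)) gives Some(a), Some(None) gives None
  case None        => None
}

// flatMap is literally "map, then flatten".
def flatMap[A, B](oa: Option[A])(f: A => Option[B]): Option[B] =
  flatten(map(oa)(f))

// Compare against the standard library's Option.flatMap.
val parse: String => Option[Int] = _.toIntOption
for (input <- List(Some("3"), Some("three"), None))
  assert(flatMap(input)(parse) == input.flatMap(parse))

println(flatMap(Some("3"))(parse)) // Some(3)
```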

So we can again refactor our previous example as:

def bar(a: Int, b: Int, c: Int): Option[Int] =
  flatMap(safeDivision(a, b))(temp => safeDivision(temp, c))

// Or, equivalently, with normal method-like syntax:
def bar2(a: Int, b: Int, c: Int): Option[Int] =
  safeDivision(a, b).flatMap(temp => safeDivision(temp, c))

// Or with for:
def bar3(a: Int, b: Int, c: Int): Option[Int] =
  for {
    temp   <- safeDivision(a, b)
    result <- safeDivision(temp, c)
  } yield result
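To connect the for version back to flatMap, here is a sketch of what the compiler roughly rewrites it into (barFor and barDesugared are names used only for this example). The first generator becomes flatMap and the last one becomes map; here that final map is just the identity, because we yield result unchanged:

```scala
def safeDivision(x: Int, y: Int): Option[Int] =
  if (y != 0) Some(x / y) else None

// The for expression version...
def barFor(a: Int, b: Int, c: Int): Option[Int] =
  for {
    temp   <- safeDivision(a, b)
    result <- safeDivision(temp, c)
  } yield result

// ...and roughly what it de-sugars to.
def barDesugared(a: Int, b: Int, c: Int): Option[Int] =
  safeDivision(a, b).flatMap(temp => safeDivision(temp, c).map(result => result))

println(barFor(8, 2, 2))       // Some(2)
println(barDesugared(8, 2, 2)) // Some(2)
println(barFor(8, 2, 0))       // None
```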

Hope this helps 🙂


TIL 👍

First of all: Thank you so much for your detailed explanations, they really help to understand what’s happening under the hood!

But there are some things I (still) cannot follow:

Which map function has this signature? If I look at Either’s map function I see the following signature:

def map[B1](f: B => B1): Either[A, B1] 

Is this equivalent to the signature you mentioned? Also: until now I believed that map does nothing more than apply a function to a value to transform it into another value. So is there more to map than I thought?

Yeah so.

map as a concept is very general; it comes from category theory and all that.
What you see in Scala is a (simplified?) implementation of such concept.

So, given Scala is a mix between FP and OOP, the creators of the language decided to model that function as a method on some classes, like Either, Option, List, etc.

Actually, my definition is also a simplification, since map comes from Functors, which I didn’t mention.

BTW, if you want to learn more about FP, you can check my comment here.

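So, TL;DR:
Yes, it is the same function. Either becomes the F[_] (more precisely, Either with its left type A fixed), and the method form fa.map(f) is just map(fa)(f) with fa as the receiver.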
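To make the connection concrete, here is a sketch of the standalone map(fa)(f) style from earlier, written for Either with the left type fixed to String. The alias StringOr and the function mapEither are hypothetical names chosen for illustration; the standard Either.map method does the same job, just with fa as the receiver:

```scala
// Fixing the Left type leaves a type constructor with one hole,
// which is exactly the F[_] shape.
type StringOr[B] = Either[String, B]

// The map(fa)(f) shape from earlier, with F = StringOr.
def mapEither[A, B](fa: StringOr[A])(f: A => B): StringOr[B] = fa match {
  case Right(a)  => Right(f(a))
  case Left(err) => Left(err) // errors pass through untouched
}

val ok: StringOr[Int]  = Right(3)
val bad: StringOr[Int] = Left("boom")

// Same results as the standard library method fa.map(f).
println(mapEither(ok)(_ + 1))  // Right(4)
println(mapEither(bad)(_ + 1)) // Left(boom)
```

So map still just applies a function to a value; the extra work is only in knowing where the value lives (inside Right) and what to do when there is none (pass the Left along).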

@dubaut one thing you can do, for educational purposes, to help you see the connection between the map function mentioned by @BalmungSan and the map functions you see in the Scala standard library, is to define a Functor abstraction, which is characterised by the fact that it has a map function, and then instantiate this Functor abstraction for different type constructors:

trait Functor[F[_]] {
  def map[A,B](fa: F[A])(f: A => B): F[B]
}

val listFunctor: Functor[List] = new Functor[List] {
  def map[A,B](fa: List[A])(f: A => B): List[B] = fa map f
}

val optionFunctor: Functor[Option] = new Functor[Option] {
  def map[A,B](fa: Option[A])(f: A => B): Option[B] = fa map f
}

val tryFunctor: Functor[Try] = new Functor[Try] {
  def map[A,B](fa: Try[A])(f: A => B): Try[B] = fa map f
}

val futureFunctor: Functor[Future] = new Functor[Future] {
  def map[A,B](fa: Future[A])(f: A => B): Future[B] = fa map f
}

assert(   listFunctor.map(List(1,2,3))(_ + 1) == List(2,3,4) )

assert(     optionFunctor.map(Some(3))(_ + 1) == Some(4) )

assert( tryFunctor.map(Try{"3".toInt})(_ + 1) == Success(4) )

assert( Await.result(
          futureFunctor.map(
                 Future.successful(3))(_ + 1),
          Duration.Inf)                       == 4 )

Here are the imports:

import scala.concurrent.{Await,Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.util.{Try,Success}

@dubaut you may find this useful (if you can stomach the garish colours 😄): https://www.slideshare.net/pjschwarz/functors

Also, in order for a function to qualify as a map function, it is not sufficient for it to have the right signature, it must also obey the functor laws: https://www.slideshare.net/pjschwarz/functor-laws
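The two functor laws are identity (mapping the identity function changes nothing) and composition (mapping f and then g is the same as mapping f andThen g). As an informal spot check on a few sample values, not a proof, here they are for the standard Option and List map methods (the sample values and names are made up):

```scala
val f: Int => Int    = _ + 1
val g: Int => String = n => s"#$n"

val options: List[Option[Int]] = List(Some(3), None)
val lists: List[List[Int]]     = List(List(1, 2, 3), Nil)

// Identity law: fa.map(identity) == fa
val identityHolds =
  options.forall(fa => fa.map(identity) == fa) &&
  lists.forall(fa => fa.map(identity) == fa)

// Composition law: fa.map(f).map(g) == fa.map(f andThen g)
val compositionHolds =
  options.forall(fa => fa.map(f).map(g) == fa.map(f andThen g)) &&
  lists.forall(fa => fa.map(f).map(g) == fa.map(f andThen g))

println(identityHolds && compositionHolds) // true
```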


I have one more question: how does the for comprehension know that it should result in an Either, since yield only produces the value that is supposed to go in the Right if everything goes well?

Remember that for is just syntactic sugar for flatMap and map.
As such, if you start with an Either, the result is an Either: none of the functions we discussed changes the effect type (the F[_]).

So, the one that actually knows how to return a Right if everything went fine is the implementation of flatMap (and of map, which the final yield turns into) on Either.
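A small sketch of that, using a hypothetical parse helper that returns Either[String, Int]: the for expression below de-sugars to parse(a).flatMap(x => parse(b).map(y => x + y)). The final map is what puts the yielded value into a Right, and flatMap on a Left simply returns that Left, so the first error short-circuits the rest:

```scala
// Hypothetical helper: Right(number) on success, Left(message) on failure.
def parse(s: String): Either[String, Int] =
  s.toIntOption.toRight(s"not a number: $s")

def sum(a: String, b: String): Either[String, Int] =
  for {
    x <- parse(a)
    y <- parse(b)
  } yield x + y

// Roughly what the for expression above is rewritten into.
def sumDesugared(a: String, b: String): Either[String, Int] =
  parse(a).flatMap(x => parse(b).map(y => x + y))

println(sum("1", "2"))     // Right(3)
println(sum("1", "two"))   // Left(not a number: two)
println(sum("one", "two")) // Left(not a number: one)
```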

As it says e.g. in https://typelevel.org/cats/typeclasses/functor.html#a-different-view, another way of viewing a Functor[F] is that F allows the lifting of a pure function A => B into the effectful function F[A] => F[B].

Let’s add the lift function to the Functor abstraction we saw earlier:

trait Functor[F[_]] {
  def map[A,B](fa: F[A])(f: A => B): F[B]
  
  def lift[A, B](f: A => B): F[A] => F[B] =
    fa => map(fa)(f) 
}

Here are some tests:

val square: Int => Int = x => x * x

assert(    listFunctor.lift(square)(List(1,2,3)) == List(1,4,9) )
assert(      optionFunctor.lift(square)(Some(3)) == Some(9) )
assert(  tryFunctor.lift(square)(Try{"3".toInt}) == Success(9) )
assert( Await.result(
          futureFunctor.lift(square)
            (Future.successful(3)),
          Duration.Inf)                          == 9 )

I think maybe playing with this simpler approach to understanding flatMap as in the REPL session below might help:

scala> Vector(1, 2, 3).map(x => x + 1)
val res0: Vector[Int] = Vector(2, 3, 4)

scala> Vector(1, 2, 3).map(x => Vector(x + 1))
val res2: Vector[Vector[Int]] = Vector(Vector(2), Vector(3), Vector(4))

scala> Vector(1, 2, 3).map(x => Vector(x + 1)).flatten
val res3: Vector[Int] = Vector(2, 3, 4)

scala> Vector(1, 2, 3).flatMap(x => Vector(x + 1))
val res4: Vector[Int] = Vector(2, 3, 4)

So you can think of flatMap as first map then flatten, and flatten will remove one level of nesting. You often end up with nesting, e.g. when working with Option-wrapped values in a collection.
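For instance (a sketch with made-up input strings), parsing a collection of strings gives Option-wrapped values, and flatMap removes both the nesting and the failures in one step:

```scala
val inputs = Vector("1", "two", "3")

// map alone leaves one Option per element.
val wrapped: Vector[Option[Int]] = inputs.map(_.toIntOption)
println(wrapped) // Vector(Some(1), None, Some(3))

// map then flatten drops the Nones and unwraps the Somes...
val flattened: Vector[Int] = wrapped.flatten

// ...which is exactly what flatMap does in one go.
val viaFlatMap: Vector[Int] = inputs.flatMap(_.toIntOption)
println(viaFlatMap) // Vector(1, 3)
```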

(My experience with beginners is that all the Functor/Monad etc. lingo is not necessary when starting to grasp flatMap.)


Seconded. Feel free to consider that stuff an interesting path you might choose to go down, but don’t need to go down in order to grasp for and flatMap.


Absolutely. When it comes to understanding a concept, I find that it is usually much easier to start with specific examples before trying to grasp a general abstraction such as flatMap.

I would add to your example the fact that the main usefulness of flatMap is that the inner Vectors can be of different sizes. For example:

scala> Vector(Vector(1), Vector(2,3)).flatten
res0: Vector[Int] = Vector(1, 2, 3)

Finally, I would point out that Option is essentially just a Vector that can only have zero or one element.
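One way to see this (a quick illustration with made-up values) is that converting an Option to a List gives either zero or one element, and that map and flatMap on the Option agree with map and flatMap on that List:

```scala
val some: Option[Int] = Some(3)
val none: Option[Int] = None

// An Option converts to a List with zero or one element.
println(some.toList) // List(3)
println(none.toList) // List()

// A function that may or may not produce a value.
val f: Int => Option[Int] = n => if (n > 0) Some(n * 10) else None

// map and flatMap behave as they would on the zero-or-one-element List.
println(some.map(_ + 1).toList == some.toList.map(_ + 1))               // true
println(some.flatMap(f).toList == some.toList.flatMap(n => f(n).toList)) // true
println(none.flatMap(f).toList == none.toList.flatMap(n => f(n).toList)) // true
```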

Yes, good points!

@Russ @bjornregnell hi, yes, I agree, that’s why my answer to the question of what flatMap and flatten are was the following, aimed at illustrating their behaviour with as simple an example as I could think of:

scala> assert( Some(Some(3)).flatten == Some(3) )
scala> assert( "3".toIntOption == Some(3) )
scala> assert( Some("3").map(_.toIntOption) == Some(Some(3)) )
scala> assert( Some("3").map(_.toIntOption).flatten == Some(3) )
scala> assert( Some("3").flatMap(_.toIntOption) == Some(3) )

I like the example with Option because Option is very simple. Seeing an example with List (or Vector) is also important, to begin to grasp the idea that each monad is very different: while they all have a flatMap function, it is so different in each monad that it is where the smarts/DNA/uniqueness of each monad lives.

I found it very useful, early on, to compare the implementation of the flatMap function of different monads, e.g. Option and List below:

trait Functor[F[_]] {
  def map[A,B](ma: F[A])(f: A => B): F[B]
}

trait Monad[F[_]] extends Functor[F] {
  def unit[A](a: => A): F[A]  
  def flatMap[A,B](ma: F[A])(f: A => F[B]): F[B]
  def map[A,B](ma: F[A])(f: A => B): F[B] = 
    flatMap(ma)(a => unit(f(a)))  
  def flatten[A](ma: F[F[A]]): F[A] =
    flatMap(ma)(identity)
}

val optionMonad = new Monad[Option] {  
  def unit[A](a: => A): Option[A] = Some(a)  
  def flatMap[A,B](ma: Option[A])(f: A => Option[B]): Option[B] = ma match {
    case Some(a) => f(a)
    case None => None
  }
}

val listMonad = new Monad[List] {
  def unit[A](a: => A): List[A] = List(a)  
  def flatMap[A,B](ma: List[A])(f: A => List[B]): List[B] = ma match {
    case Nil => Nil
    case head::tail => f(head) ::: flatMap(tail)(f)
  }
}

assert(   optionMonad.map(Some(3))(_ + 1) == Some(4) )
assert( listMonad.map(List(1,2,3))(_ + 1) == List(2,3,4) )

assert( optionMonad.flatMap(Some("3"))(_.toIntOption) == Some(3) )
assert( listMonad.flatMap(List(1,0,4)){case 0 => List[Int]() case x => List(x,x+1,x+2)} == List(1,2,3,4,5,6) )

assert( optionMonad.flatten(Some(Some(3))) == Some(3) )
assert( listMonad.flatten(List(List(1), List(2,3), List(4,5,6))) == List(1,2,3,4,5,6) )

One irony, though, is that the flatMap and flatten methods found in the Scala standard library are more complicated than the ones provided by the monad abstraction, because they are more powerful:

That slide is from https://www.slideshare.net/pjschwarz/scala-collection-methods-flatmap-and-flatten-are-more-powerful-than-monadic-flatmap-and-flatten (note: not updated for Scala 2.13).
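One concrete difference, as a sketch assuming Scala 2.13 (where Option is an IterableOnce): the monadic flatMap for List shown above only accepts A => List[B], while the standard List.flatMap accepts any A => IterableOnce[B]. So it can be handed a function that returns an Option, and flatten likewise works on a List of Options (oddOnly is a made-up example function):

```scala
val xs = List(1, 2, 3, 4)

// A function returning Option, not List.
val oddOnly: Int => Option[Int] = n => if (n % 2 == 1) Some(n) else None

// Standard List.flatMap takes A => IterableOnce[B], and Option qualifies.
println(xs.flatMap(oddOnly)) // List(1, 3)

// Likewise, flatten accepts a List whose elements are Options.
println(xs.map(oddOnly).flatten) // List(1, 3)

// The monadic listMonad.flatMap would need an explicit conversion to List.
println(xs.flatMap(n => oddOnly(n).toList)) // List(1, 3)
```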


Once again thank you for all your elaborations! Now I have content worth a week of study and I will go through all of the resources provided here!
