Implementing map-fusion with dotty macros

I’m one of the authors of scalding and several similar libraries for building distributed computation pipelines. One thing those libraries generally can’t do is real map-fusion of the code:

pipe.map { x => x * 2 }.map { x => x + 2 }

Here we would like to avoid boxing twice and have that optimized to

pipe.map { x => (x * 2) + 2 }

One way I thought we could solve this was with Dotty’s ability to do staging and to hold expressions.

Imagine a toy like:

object PipeExample {
  sealed trait Pipe[A] {
    def toExpr(using QuoteContext): Expr[List[A]]
  }
  case class Source[A](toList: List[A]) extends Pipe[A] {
    def toExpr(using QuoteContext) = ???
  }
  case class Mapped[A, B](first: Pipe[A], fn: Expr[A => B]) extends Pipe[B] {
    def toExpr(using QuoteContext) = ???
    // here we would recursively expand all the functions into a single method so the JVM doesn't see them as megamorphic calls to Function1.apply.
  }

  extension pipeOps {
    inline def [A, B](first: Pipe[A]).map(inline fn: => A => B)(using QuoteContext): Pipe[B] =
      Mapped(first, '{fn})
  }

  inline def compile[A](p: => Pipe[A])(using Toolbox): List[A] =
    run(p.toExpr)
}

This won’t compile; you get:

[error] 204 |      Mapped(first, '{fn})
[error]     |                      ^^
[error]     |                      access to value fn from wrong staging level:
[error]     |                       - the definition is at level 0,
[error]     |                       - but the access is at level 1.

A similar issue comes up with this pattern even without map fusion. If we build up nodes in a DAG that will later build Exprs, we generally want to tie the type of each node to the type of the expressions it generates, but every time I try to do that, I get the error:

[error]     |       access to type A from wrong staging level:
[error]     |        - the definition is at level 0,
[error]     |        - but the access is at level 1.
[error]     |
[error]     |        The access would be accepted with a given scala.quoted.Type[A]

I bypassed this error by just using Any and then casting to the correct type. I think my casts are sound, but I don’t like having to cast (even if users wouldn’t have to interact with the casts).

My questions:

  1. Can someone show a toy example without casts that achieves the map-fusion optimization (which is just an example, not the entirety of what one would want)?
  2. Are there real soundness bugs that these level 0 / level 1 errors are guarding against, or (as I suspect) do they rule out many sound programs in order to also rule out many unsound programs?
  3. If the above is true, is there a prospect for a more precise static check that would let us express this pattern without resorting to casts?

Thanks.


I pinged https://gitter.im/lampepfl/dotty alerting them to this question.

Cool idea! This deserves a closer look, but just to mention the obvious: what happens if you follow the compiler’s suggestion, “The access would be accepted with a given scala.quoted.Type[A]”?

Rationale: since types are erased, to manipulate type argument A you need a value representing A — so, a Type[A].
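The same principle shows up at runtime in ordinary Scala, which may make the rationale more concrete. The sketch below uses `scala.reflect.ClassTag` as a runtime analogue of `Type[A]` (an analogy only, not the quoted API): because `A` is erased, building an `Array[A]` needs a value describing `A`, supplied here through a context bound. `filled` is a hypothetical helper name.

```scala
import scala.reflect.ClassTag

// Runtime analogue of needing a Type[A]: A is erased, so creating an Array[A]
// requires a value describing A. Without the `: ClassTag` bound this does not compile.
def filled[A: ClassTag](n: Int, a: A): Array[A] = {
  val arr = new Array[A](n) // the ClassTag picks the right runtime array class
  for (i <- 0 until n) arr(i) = a
  arr
}

val ints = filled(3, 7)   // backed by a primitive int[]
val strs = filled(2, "x") // backed by a String[]
```

In both cases the evidence value is what lets generic code act on an erased type, just as a given `Type[A]` lets a quote mention `A`.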

The other error is more substantial, I’ll get back to you in a bit.

As an aside: in practice, I suspect this thread still belongs to contributors.

Oversimplifying, programs with level errors on terms are never safe: the Dotty type system is supposed to keep terms and quoted terms separated, and those errors complain about a mix-up between the two. Oversimplifying even more, Dotty is complaining about mixing Foo and Expr[Foo], but the formal details are more complicated… There are docs, but I think they could be clearer.
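To make the Foo vs. Expr[Foo] distinction concrete, here is a deliberately tiny model of two levels in plain Scala. `Code`, `Lift`, `plus` and `addConst` are hypothetical names for this sketch and are not part of scala.quoted (the real counterparts are `Expr` and lifting type classes such as `Liftable`, later `ToExpr` in Scala 3.0). The point is only that a level-0 value reaches level-1 code through an explicit conversion, and a value with no such conversion cannot cross at all.

```scala
// Toy model of staging levels (hypothetical; not the scala.quoted API).
// Level 0: ordinary values such as Int. Level 1: Code[T], a fragment that produces a T later.
final case class Code[T](src: String)

// The only way a level-0 value enters level-1 code: an explicit lift.
trait Lift[T] { def lift(t: T): Code[T] }

object Lift {
  given Lift[Int] = new Lift[Int] {
    def lift(t: Int): Code[Int] = Code(t.toString)
  }
  // Deliberately no Lift[Int => Int]: in this toy a function value has no source form,
  // the analogue of the "access to value fn from wrong staging level" error.
}

// Both operands must already be level-1 fragments; plus(x, 3) would not typecheck.
def plus(x: Code[Int], y: Code[Int]): Code[Int] = Code(s"(${x.src} + ${y.src})")

// `n` lives at level 0, so it has to be lifted before it can be used inside code.
def addConst(x: Code[Int], n: Int)(using l: Lift[Int]): Code[Int] = plus(x, l.lift(n))

val generated = addConst(Code("x"), 3) // generated.src == "(x + 3)"
```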

I think you’re right to use staging, but I’m not sure that mixes well with inline methods — those combine better with macros. Here’s a prototype using macros that should work (ignoring bugs with quasiquotes):

import scala.quoted._

//import scala.quoted.staging._
object PipeExample {
  sealed trait Pipe[A] { def toList: List[A] }
  case class Source[A](toList: List[A]) extends Pipe[A]
  case class Mapped[A, B](first: Pipe[A], fn: A => B) extends Pipe[B] {
    def toList = first.toList.map(fn)
  }

  def mapImpl[A: Type, B: Type] (first: Expr[Pipe[A]], fn: Expr[A => B])(using qctx: QuoteContext): Expr[Pipe[B]] = {
    first match {
      case '{ Mapped($iFirst, $iFn) } =>
        '{ Mapped($iFirst, $iFn andThen $fn) }
      case _ =>
        '{ Mapped($first, $fn) }
    }
  }

  extension pipeOps {
    inline def [A, B](inline first: => Pipe[A]).map(inline fn: => A => B): Pipe[B] =
      ${ mapImpl('{ first }, '{ fn }) }
  }
}

In a second file (it must be compiled separately since it uses macros, just like in Scala 2, though a single project seems to work):

import PipeExample._

object Main {
  def main(args: Array[String]): Unit = {

    val foo = Source(List(1, 2, 3)).map(x => x + 1).map(x => x + 2)
    println(foo)
    println(foo.toList)
  }
}

This has a couple issues:

  • from your comments, you probably want something more advanced than andThen; I’m just fusing the map calls to avoid building an intermediate collection
  • it compiles, but the quasiquote match against case '{ Mapped($iFirst, $iFn) } => never succeeds, though something like that should hopefully work (either now or in a future release). I managed to make it work using a lower-level API, Tasty reflection; I’m cleaning up the code and will post it soon.
  • this approach won’t work if your two maps are defined separately, unlike in your example, but staging would have a chance. For instance:

    def bar(foo: Pipe[Int]): Pipe[Int] = foo.map(x => x + 2)
    val foo = Source(List(1, 2, 3)).map(x => x + 1)

    val baz = bar(foo) // fusing these map calls will likely require staging.

Fixing the match using Tasty reflection

To avoid issues with quasiquotes, we can use Tasty reflection. While this API is closer to Scala 2 macros, it is safer and better encapsulated. For simplicity I’ve specialized Mapped to Int, but that could be generalized. With this code, map fusion works correctly. I left in some printf debugging for when you need it.

import scala.quoted._

object PipeExample {
  sealed trait Pipe[A] { def toList: List[A] }
  case class Source[A](toList: List[A]) extends Pipe[A]

  import scala.tasty._
  case class Mapped(first: Pipe[Int], fn: Int => Int) extends Pipe[Int] {
    def toList = first.toList.map(fn)
  }

  def mapImpl (first: Expr[Pipe[Int]], fn: Expr[Int => Int])(using qctx: QuoteContext): Expr[Pipe[Int]] = {
    import qctx.tasty._
    def go(trm : Term): Expr[Pipe[Int]] = {
      import qctx.tasty._
      trm match {
        case Typed(t, _) =>
          go(t)
        case Inlined(_, _, t) =>
          go(t)
        case Apply(mapped, args@List(iFirst, iFn)) =>
          //  qctx.tasty.warning(s"debug4 ${mapped.showExtractors} ", mapped.pos)
          //  qctx.warning(s"debug5 ${args.map(_.showExtractors)} ${args.length}")
          '{ Mapped(${iFirst.seal.cast[Pipe[Int]]}, ${iFn.seal.cast[Int => Int]} andThen $fn) }
        case _ =>
          //qctx.tasty.warning(s"debug6 ${trm.showExtractors}", trm.pos)
          '{ Mapped($first, $fn) }
      }
    }
    go(first.unseal)
  }

  extension pipeOps {
    inline def (inline first: => Pipe[Int]).map(inline fn: => Int => Int): Pipe[Int] =
      ${ mapImpl('{ first }, '{ fn }) }
  }
}
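The `go` function above peels `Typed` and `Inlined` wrappers before looking for an `Apply`. Here is a minimal sketch of that peel-then-match idea on a hypothetical mini-AST: the node names echo Tasty’s, but this `Term` enum is invented for illustration and is far simpler than the real trees (a real `Mapped(x, y)` call also carries a `Select` of `apply` and type arguments). One design choice worth noting: the `Apply` case above accepts any two-argument application, whereas this sketch also checks the callee, so an unrelated call such as `foo(a, b)` falls through to the default case.

```scala
// Hypothetical mini-AST; node names echo Tasty's, but the shapes are simplified.
enum Term {
  case Ident(name: String)
  case Typed(term: Term, tpe: String) // type ascription inserted by the compiler
  case Inlined(term: Term)            // marker left behind by inlining
  case Apply(fun: Term, args: List[Term])
}
import Term.*

// Peel wrapper nodes until something structural is exposed.
def strip(t: Term): Term = t match {
  case Typed(inner, _) => strip(inner)
  case Inlined(inner)  => strip(inner)
  case other           => other
}

// Recognise `Mapped(first, fn)` specifically, not just any two-argument call.
def asMapped(t: Term): Option[(Term, Term)] = strip(t) match {
  case Apply(Ident("Mapped"), List(first, fn)) => Some((first, fn))
  case _                                       => None
}

val wrapped = Inlined(Typed(Apply(Ident("Mapped"), List(Ident("src"), Ident("f"))), "Pipe[Int]"))
val unrelated = Apply(Ident("foo"), List(Ident("a"), Ident("b")))
```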

Full version at https://github.com/Blaisorblade/dotty-map-fusion-staging-experiment.

Using staging

I’m out of time for now, but I might try the staging version later.


I can get the most trivial macro to work, but I am totally lost once I try to generalize it to handle anywhere from 1 to n chained map operations…

import scala.quoted._

object Example {
  inline def opt[A](inline list: List[A]) = ${opt_m('{list})}

  def opt_m[C: Type](body: Expr[List[C]])(using QuoteContext) = {
    println(body.show)
    val result = body match {
      case '{($list: List[$a]).map[$b]($fx).map[C]($fy)} =>
        '{ ${list}.map{a => ${Expr.betaReduce(fy)(Expr.betaReduce(fx)('a))} } }
    }
    println(result.show)
    result
  }
}
@main def mainMethod(): Unit = {
  val theList = List(1,2,3)
  println(Example.opt(theList.map(_ + 2).map(_ * 2.5)))
}
[info] Compiling 2 Scala sources
theList.map[scala.Int](((_$1: scala.Int) => _$1.+(2))).map[scala.Double](((_$2: scala.Int) => _$2.*(2.5)))
theList.map[scala.Double](((a: scala.Int) => {
  val x$1: scala.Int = a.+(2)
  x$1.*(2.5)
}))
[info] running mainMethod 
List(7.5, 10.0, 12.5)

Just doing andThen isn’t what I want. You can certainly do that without a macro later as an optimization step.
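For comparison, here is a minimal sketch of that no-macro approach, assuming a `Pipe` shaped like the toy earlier in the thread but with `map` as an ordinary method (the `stages` member is a hypothetical addition used only to count nodes). `Mapped` overrides `map` as a smart constructor, so consecutive maps collapse into one node by composing closures with `andThen`. Because the fusion happens at runtime, it also covers maps defined separately, as in the `bar(foo)` example. What it does not do is inline the user’s lambdas: each element still goes through `Function1.apply` once per original map call, which is the dispatch the macro approaches try to remove.

```scala
// Runtime (non-macro) map fusion: Mapped.map composes closures instead of adding a node.
sealed trait Pipe[A] {
  def toList: List[A]
  def stages: Int                                  // number of Mapped nodes above the Source
  def map[B](g: A => B): Pipe[B] = Mapped(this, g)
}

final case class Source[A](toList: List[A]) extends Pipe[A] {
  def stages: Int = 0
}

final case class Mapped[A, B](first: Pipe[A], fn: A => B) extends Pipe[B] {
  def toList: List[B] = first.toList.map(fn)       // one pass, no intermediate list
  def stages: Int = 1 + first.stages
  // Smart constructor: fold the new function into this node rather than nesting a new one.
  override def map[C](g: B => C): Pipe[C] = Mapped(first, fn.andThen(g))
}

def bar(foo: Pipe[Int]): Pipe[Int] = foo.map(x => x + 2)
val foo = Source(List(1, 2, 3)).map(x => x + 1)
val baz = bar(foo)                                 // separately defined maps, still one node
```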

What I want is to inline the call to the lambda that the user writes, but the second reply seems to get closer to it.

What about iterating the same transformation recursively?

After a retry it seems that I imagined some problems where there were none after all…

Sorry for the code; I hacked it together quickly. I wish I could avoid the asInstanceOf[Expr[B]], though.

import scala.quoted._

object Example {
  inline def opt[A](inline list: List[A]) = ${opt_m('{list}, '{list})}

  def opt_m[B: Type, C: Type](body: Expr[List[B]], full: Expr[List[C]])(using QuoteContext): Expr[List[C]] = {
    val result = body match {
      case '{ ($list: List[$a]).map[B]($fx) } =>
        opt_m(list, full)
      case '{ ($list: List[B]) } =>
        '{ $list.map(a => ${reduce(full, 'a)}) }
    }
    println(result.show)
    result
  }

  def reduce[B: Type, A: Type]
    (body: Expr[List[B]], init: Expr[A])
    (using QuoteContext): Expr[B] = {
    val result = body match {
      case '{ ($list: List[$a]).map[B]($fx) } =>
        val prev = reduce(list, init)
        Expr.betaReduce(fx)(prev)
      case '{ ($list: List[A]) } =>
        init.asInstanceOf[Expr[B]]
    }
    result
  }

}
@main def mainMethod(): Unit = {
  val theList = List(1,2,3)
  println(
    Example.opt(
      theList.map(_ + 2).map(_ * 2.5).map(_ / 3).map(_.toString).map(_ + "d")
    )
  )
}
theList.map[java.lang.String](((a: scala.Int) => {
  val x$4: java.lang.String = {
    val x$3: scala.Double = {
      val x$2: scala.Double = {
        val x$1: scala.Int = a.+(2)
        x$1.*(2.5)
      }
      x$2./(3)
    }
    x$3.toString()
  }
  x$4.+("d")
}))
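The recursive transformation can also be mimicked outside the compiler on a toy expression language, which may help when reasoning about what `reduce` does. In this sketch all names are hypothetical and `Ex` merely stands in for quoted trees. Fusing `map(f).map(g)` binds `f`’s body to `g`’s parameter with a `Let`, much like the `val x$1 = ...` blocks that `Expr.betaReduce` produced in the output above, so the result is a single lambda with the user’s code inline rather than calls to composed closures. It assumes each lambda mentions only its own parameter. `run` interprets a pipeline with one pass per `MapP` node, so it can check that fusion preserves results.

```scala
// Hypothetical toy expression language standing in for quoted trees.
enum Ex {
  case Var(name: String)
  case Lit(v: Double)
  case Add(a: Ex, b: Ex)
  case Mul(a: Ex, b: Ex)
  case Let(name: String, value: Ex, body: Ex) // like the `val x$1 = ...` blocks
}
import Ex.*

final case class Lam(param: String, body: Ex) // assumed closed over `param` only

enum Pipeline {
  case Src(values: List[Double])
  case MapP(in: Pipeline, f: Lam)
}
import Pipeline.*

def eval(e: Ex, env: Map[String, Double]): Double = e match {
  case Var(n)       => env(n)
  case Lit(v)       => v
  case Add(a, b)    => eval(a, env) + eval(b, env)
  case Mul(a, b)    => eval(a, env) * eval(b, env)
  case Let(n, v, b) => eval(b, env + (n -> eval(v, env)))
}

// Recursively fuse a chain: map(f).map(g) becomes map(x => { val y = f.body; g.body }).
def fuse(p: Pipeline): Pipeline = p match {
  case MapP(in, g) =>
    fuse(in) match {
      case MapP(src, f) => MapP(src, Lam(f.param, Let(g.param, f.body, g.body)))
      case other        => MapP(other, g)
    }
  case src => src
}

// Interpreter: every MapP node is one pass over the data.
def run(p: Pipeline): List[Double] = p match {
  case Src(vs)     => vs
  case MapP(in, f) => run(in).map(x => eval(f.body, Map(f.param -> x)))
}

def mapCount(p: Pipeline): Int = p match {
  case Src(_)      => 0
  case MapP(in, _) => 1 + mapCount(in)
}

// theList.map(_ + 2).map(_ * 2.5).map(_ + 1), as a tree
val chain: Pipeline =
  MapP(MapP(MapP(Src(List(1.0, 2.0, 3.0)),
    Lam("a", Add(Var("a"), Lit(2.0)))),
    Lam("b", Mul(Var("b"), Lit(2.5)))),
    Lam("c", Add(Var("c"), Lit(1.0))))
```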

I created a PR for @Blaisorblade’s solution that uses scala.quoted; you can check it out here: