Translating Haskell Expression Problem Solution to Scala 3

Here’s my solution using type classes only:

case class Const(c: Int)
case class Add[A, B](l: A, r: B)
case class Neg[A](a: A)

trait Expr[A]
given Expr[Const] with { }
given AddExpr[A, B](using leftExpr: Expr[A], rightExpr: Expr[B]): Expr[Add[A, B]] with { }
given NegExpr[A](using expr: Expr[A]): Expr[Neg[A]] with { }

trait Eval[A] {
    def eval(a: A)(using expr: Expr[A]): Int
}
given ConstEval: Eval[Const] with {
    def eval(a: Const)(using expr: Expr[Const]) = a.c
}
given AddEval[A, B](using leftExpr: Expr[A], rightExpr: Expr[B], leftEval: Eval[A], rightEval: Eval[B]): Eval[Add[A, B]] with {
    def eval(a: Add[A, B])(using expr: Expr[Add[A, B]]) = leftEval.eval(a.l) + rightEval.eval(a.r)
}
given NegEval[A](using expr: Expr[A], subEval: Eval[A]): Eval[Neg[A]] with {
    def eval(a: Neg[A])(using expr: Expr[Neg[A]]) = -subEval.eval(a.a)
}

val four = Const(4)
val twoPlusThree = Add(Const(2), Const(3))
val twoPlusThreeNegated = Neg(twoPlusThree)

def eval[A](a: A)(using expr: Expr[A], eval: Eval[A]) = eval.eval(a)
println(eval(four))
println(eval(twoPlusThree))
println(eval(twoPlusThreeNegated))

(I left out Stringify because it can be done like Eval.)
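For completeness, here is a rough sketch of what Stringify might look like when written the same way as Eval. The names Stringify, stringify, ConstStringify, AddStringify and NegStringify, as well as the output format, are my own assumptions and not taken from the Haskell original; the data types and the Expr instances are repeated only so that the snippet compiles on its own.

```scala
// Data types and the empty Expr marker class, repeated from above.
case class Const(c: Int)
case class Add[A, B](l: A, r: B)
case class Neg[A](a: A)

trait Expr[A]
given ConstExpr: Expr[Const] with { }
given AddExpr[A, B](using leftExpr: Expr[A], rightExpr: Expr[B]): Expr[Add[A, B]] with { }
given NegExpr[A](using subExpr: Expr[A]): Expr[Neg[A]] with { }

// Hypothetical Stringify class: same shape as Eval, but it returns a String.
trait Stringify[A] {
    def stringify(a: A)(using expr: Expr[A]): String
}
given ConstStringify: Stringify[Const] with {
    def stringify(a: Const)(using expr: Expr[Const]): String = a.c.toString
}
given AddStringify[A, B](using leftExpr: Expr[A], rightExpr: Expr[B], leftStr: Stringify[A], rightStr: Stringify[B]): Stringify[Add[A, B]] with {
    // Parenthesize sums so that nesting stays unambiguous (assumed format).
    def stringify(a: Add[A, B])(using expr: Expr[Add[A, B]]): String =
        s"(${leftStr.stringify(a.l)} + ${rightStr.stringify(a.r)})"
}
given NegStringify[A](using subExpr: Expr[A], subStr: Stringify[A]): Stringify[Neg[A]] with {
    def stringify(a: Neg[A])(using expr: Expr[Neg[A]]): String =
        s"-${subStr.stringify(a.a)}"
}

// Entry point, analogous to the eval helper above.
def stringify[A](a: A)(using expr: Expr[A], str: Stringify[A]): String = str.stringify(a)

// Neg(Add(Const(2), Const(3))) is rendered as "-(2 + 3)".
val rendered = stringify(Neg(Add(Const(2), Const(3))))
```

Nothing in Const, Add, Neg or Eval has to change to add this operation, which is the point of keeping each operation in its own type class.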

The separate Expr and Eval classes make for a hard read, but the code can be simplified by merging them:

case class Const(c: Int)
case class Add[A, B](l: A, r: B)
case class Neg[A](a: A)
  
trait Expr[A] {
    def eval(a: A): Int
}
given Expr[Const] with {
    override def eval(a: Const) = a.c
}
given AddExpr[A, B](using leftExpr: Expr[A], rightExpr: Expr[B]): Expr[Add[A, B]] with {
    override def eval(a: Add[A, B]) = leftExpr.eval(a.l) + rightExpr.eval(a.r)
}
given NegExpr[A](using expr: Expr[A]): Expr[Neg[A]] with {
    override def eval(a: Neg[A]) = -expr.eval(a.a)
}
  
val four = Const(4)
val twoPlusThree = Add(Const(2), Const(3))
val twoPlusThreeNegated = Neg(twoPlusThree)
  
def eval[A](a: A)(using expr: Expr[A]) = expr.eval(a)
println(eval(four))
println(eval(twoPlusThree))
println(eval(twoPlusThreeNegated))
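The merged version should still allow new data types to be added without touching existing code. A minimal sketch, assuming a hypothetical Mul case class and MulExpr instance that are not part of the original solution (Const, Add and Expr are repeated so that the snippet compiles on its own):

```scala
case class Const(c: Int)
case class Add[A, B](l: A, r: B)

trait Expr[A] {
    def eval(a: A): Int
}
given ConstExpr: Expr[Const] with {
    override def eval(a: Const) = a.c
}
given AddExpr[A, B](using leftExpr: Expr[A], rightExpr: Expr[B]): Expr[Add[A, B]] with {
    override def eval(a: Add[A, B]) = leftExpr.eval(a.l) + rightExpr.eval(a.r)
}

def eval[A](a: A)(using expr: Expr[A]): Int = expr.eval(a)

// Hypothetical new variant: it only needs its own case class and its own instance.
case class Mul[A, B](l: A, r: B)
given MulExpr[A, B](using leftExpr: Expr[A], rightExpr: Expr[B]): Expr[Mul[A, B]] with {
    override def eval(a: Mul[A, B]) = leftExpr.eval(a.l) * rightExpr.eval(a.r)
}

// 2 * (3 + 4) = 14
val fourteen = eval(Mul(Const(2), Add(Const(3), Const(4))))
```

A new operation, on the other hand, can no longer be added as a method of Expr without editing every instance; it would have to go into a type class of its own.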

If the modularity provided by type classes is not important, then subclassing with type bounds will do the job:

trait Expr {
    def eval: Int
}

case class Const(c: Int) extends Expr {
    override def eval = c
}
case class Add[A <: Expr, B <: Expr](l: A, r: B) extends Expr {
    override def eval = l.eval + r.eval
}
case class Neg[A <: Expr](a: A) extends Expr {
    override def eval = -a.eval
}

val four = Const(4)
val twoPlusThree = Add(Const(2), Const(3))
val twoPlusThreeNegated = Neg(twoPlusThree)
println(four.eval)
println(twoPlusThree.eval)
println(twoPlusThreeNegated.eval)

In all three approaches, serialization (Stringify) can be implemented by overriding toString in the case classes.
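As a rough sketch of the toString route, here it is for the subclassing version. The output format is my own assumption; the same overrides could be added to the case classes of the two type class versions.

```scala
trait Expr {
    def eval: Int
}

case class Const(c: Int) extends Expr {
    override def eval = c
    // Replace the default "Const(4)" rendering with the bare number.
    override def toString: String = c.toString
}
case class Add[A <: Expr, B <: Expr](l: A, r: B) extends Expr {
    override def eval = l.eval + r.eval
    // String interpolation calls toString on both operands.
    override def toString: String = s"($l + $r)"
}
case class Neg[A <: Expr](a: A) extends Expr {
    override def eval = -a.eval
    override def toString: String = s"-$a"
}

// Neg(Add(Const(2), Const(3))) is rendered as "-(2 + 3)".
val rendered = Neg(Add(Const(2), Const(3))).toString
```

This ties the string format to the data types themselves, so it gives up the modularity that a separate Stringify type class would keep.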
