After a lot of screwing around I finally got this to work. Next question: how do I make `==` work symmetrically, so that I can write `3 == Complex(3, 0)`? Is that possible?
After a lot of screwing around I finally got this to work. Next question: how do I make `==` work symmetrically, so that I can write `3 == Complex(3, 0)`? Is that possible?
You won’t be able to change the `==` operator: implicit conversions are only applied when the method/operator being called is not already defined on the receiver, and `==` takes `Any`, so it is always available. What you can do is use a different operator, e.g. `===`, together with a typeclass that says whether two things can be compared with it (like the `Eq` typeclass in cats). Or you could use `==:`, which, because its name ends in a colon and is therefore right-associative, would be called on the `Complex` object on the right-hand side; but that would mean using different operators depending on whether the `Complex` is on the left or on the right.
Here is my Complex number class in case anyone is interested. A cubic polynomial solution (analytical, not numerical) is included in the companion object. (128 lines total)
package scalar_
import types_._
import mathx_._
import scala.language.implicitConversions
/** A complex number with `Scalar` real and imaginary parts.
  *
  * The imaginary part defaults to 0, so `Complex(x)` represents a purely real value.
  * Arithmetic operators are overloaded for a `Complex` or a `Scalar` right-hand operand.
  */
case class Complex(real: Scalar, imag: Scalar=0) {

  /** Formats the value as "<real> + <|imag|>i" (or with "-" when `imag` is negative),
    * using the `%g` format for both parts.
    */
  override def toString: Text = {
    val sign = if (imag < 0) "-" else "+"
    s"%g $sign %gi".form(real, abs(imag))
  }

  /** Magnitude, `hypot(real, imag)`; evaluated on first access. */
  lazy val mag = hypot(real, imag)

  /** Direction (argument), `atan2(imag, real)`; evaluated on first access. */
  lazy val dir = atan2(imag, real)

  /** Absolute value of the imaginary part. */
  def abs_imag = abs(imag)

  /** Component-wise sum with another complex number. */
  def + (c: Complex): Complex = Complex(real + c.real, imag + c.imag)

  /** Adds a scalar to the real part; the imaginary part is unchanged. */
  def + (s: Scalar): Complex = copy(real = real + s)

  /** Component-wise difference with another complex number. */
  def - (c: Complex): Complex = Complex(real - c.real, imag - c.imag)

  /** Subtracts a scalar from the real part; the imaginary part is unchanged. */
  def - (s: Scalar): Complex = copy(real = real - s)

  /** Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i. */
  def * (c: Complex): Complex = {
    val productReal = real * c.real - imag * c.imag
    val productImag = real * c.imag + imag * c.real
    Complex(productReal, productImag)
  }

  /** Scales both parts by a scalar. */
  def * (s: Scalar): Complex = Complex(s * real, s * imag)

  /** Complex quotient: this value times the conjugate of `c`, divided by the squared magnitude of `c`. */
  def / (c: Complex): Complex = {
    val squaredMag = sqr(c.real) + sqr(c.imag)
    val numerator = Complex(real * c.real + imag * c.imag, imag * c.real - real * c.imag)
    numerator / squaredMag
  }

  /** Divides both parts by a scalar. */
  def / (s: Scalar): Complex = Complex(real / s, imag / s)

  /** Negation of both parts. */
  def unary_- : Complex = Complex(-real, -imag)

  /** Complex conjugate: same real part, negated imaginary part. */
  def conjugate: Complex = copy(imag = -imag)

  /** Raises this value to a real power in polar form: the magnitude is raised to `exp`
    * and the direction is multiplied by `exp`.
    */
  def pow(exp: Real): Complex = {
    val poweredMag = scalar_.pow(mag, exp)
    val scaledDir = exp * dir
    Complex(cos(scaledDir), sin(scaledDir)) * poweredMag
  }

  /** Equality against another `Complex` (both parts equal) or against a `Scalar`, `Real`
    * or `Int` (real part equal and imaginary part zero). Anything else is unequal.
    */
  override def equals(that: Any): Bool = {
    def imagIsZero = imag == 0
    that match {
      case c: Complex => real == c.real && imag == c.imag
      case s: Scalar => real == s && imagIsZero
      case r: Real => real == r && imagIsZero // ignore compiler warning
      case i: Int => real == i && imagIsZero
      case _ => false
    }
  }

  /** Right-associative equality test, so the complex value may appear on the right:
    * `x ==: c` evaluates `c.equals(x)`.
    */
  def ==: (a: Any): Bool = this.equals(a)

  /** Returns the real part when the imaginary part is zero.
    *
    * @throws RuntimeException if the imaginary part is non-zero
    */
  def toScalar: Scalar =
    if (imag == 0) real
    else throw new RuntimeException(s"Cannot convert $this to Scalar")

  /** Converts via [[toScalar]] and then `Real(...)`; fails the same way `toScalar` does. */
  def toReal = Real(toScalar)
}
/** Companion of the `Complex` case class: function-style helpers (`sqr`, `sqrt`, `cbrt`,
  * `pow`, ...), the implicit `Scalar` to `Complex` conversion, and analytical solvers for
  * quadratic and cubic polynomials with `Scalar` coefficients.
  */
object Complex {

  /** Square of a complex number. */
  def sqr(a: Complex): Complex = a * a

  /** Cube of a complex number. */
  def cube(a: Complex): Complex = a * a * a

  /** Square root, computed as `a.pow(0.5)`. */
  def sqrt(a: Complex): Complex = a.pow(0.5)
  def sqrt(a: Real): Complex = sqrt(Complex(a))
  def sqrt(a: Int): Complex = sqrt(Complex(a))

  /** Cube root, computed as `a.pow(1/3.0)`. */
  def cbrt(a: Complex): Complex = a.pow(1/3.0)
  def cbrt(a: Real): Complex = cbrt(Complex(a))
  def cbrt(a: Int): Complex = cbrt(Complex(a))

  /** Function form of `c.pow(exp)`. */
  def pow(c: Complex, exp: Real): Complex = c.pow(exp)

  /** Function form of `c.toScalar`; throws if the imaginary part is non-zero. */
  def toScalar(c: Complex): Scalar = c.toScalar

  /** Function form of `c.toReal`; throws if the imaginary part is non-zero. */
  def toReal(c: Complex) = c.toReal

  /** Two-argument equality helper: compares `a` (a `Complex`, `Scalar`, `Real` or `Int`)
    * with the complex number `b`. Any other type of `a` compares unequal.
    *
    * The `Real` case mirrors the one in `Complex.equals`, so both tests accept the same types.
    */
  def ==: (a: Any, b: Complex) = a match {
    case a: Complex => a == b
    case a: Scalar => Complex(a) == b
    case a: Real => Complex(a) == b // ignore compiler warning (same as in Complex.equals)
    case a: Int => Complex(a) == b
    case _ => false
  }

  /** Lets a `Scalar` be used wherever a `Complex` is expected. */
  implicit def ScalarToComplex(r: Scalar): Complex = Complex(r)

  /** Roots of the quadratic polynomial ax^2 + bx + c = 0.
    *
    * @return two roots when `a` is non-zero; the single root `-c/b` of the linear equation
    *         when `a` is zero and `b` is not; an empty vector when both `a` and `b` are
    *         zero (the equation then has no isolated root, and `-c/b` would divide by zero)
    */
  def quadraticRoots(a: Scalar, b: Scalar, c: Scalar): Vector[Complex] =
    if (a == 0) {
      if (b == 0) Vector.empty[Complex]
      else Vector[Complex](-c/b)
    } else {
      val d = Complex.sqrt(b*b - 4*a*c)
      val x1 = (-b + d) / a / 2
      val x2 = (-b - d) / a / 2
      Vector(x1, x2)
    }

  /** Roots of the cubic polynomial ax^3 + bx^2 + cx + d = 0, based on
    * https://en.wikipedia.org/wiki/Cubic_equation#General_cubic_formula
    *
    * Falls back to [[quadraticRoots]] when `a` is zero. Otherwise three roots are returned,
    * ordered by the absolute value of their imaginary part; the first one has its
    * imaginary part forced to exactly 0.
    */
  def cubicRoots(a: Scalar, b: Scalar, c: Scalar, d: Scalar): Vector[Complex] =
    if (a == 0) quadraticRoots(b, c, d)
    else {
      val d0 = b*b - 3*a*c
      val d1 = 2*b*b*b - 9*a*b*c + 27*a*a*d
      if (d0 == 0 && d1 == 0) {
        // Triple root: the general formula would evaluate d0/C as 0/0 here.
        Vector.fill(3)(Complex(-b) / a / 3)
      } else {
        val g = (Complex.sqrt(-3) - 1) / 2 // primitive cube root of unity
        val C1 =
          if (d0 == 0) {
            // With d0 == 0 the square root equals |d1|, so "d1 + sqrt" cancels to 0 for
            // negative d1 and d0/C1 becomes 0/0. The sign must be chosen so that C != 0,
            // which makes the radicand d1 itself.
            Complex.cbrt(Complex(d1))
          } else {
            val z = Complex.sqrt(d1*d1 - 4*d0*d0*d0)
            Complex.cbrt((d1 + z) / 2)
          }
        val C2 = C1 * g
        val C3 = C2 * g
        def root(C: Complex): Complex = -(b + C + d0/C) / a / 3
        val Vector(r1, r2, r3) = Vector(root(C1), root(C2), root(C3)).sortBy(_.abs_imag)
        Vector(r1.copy(imag=0), r2, r3) // avoid roundoff error for known real root
      }
    }

  // The following trick avoids conflicts with Scalar implicit conversions:
  trait implicits1 { implicit def RealToComplex(r: Real): Complex = Complex(r) }
  object implicits extends implicits1
  // to activate this implicit conversion, add the following line:
  // import Complex.implicits._
}