Generating Pythagorean triples in Scala

There is currently something of a ‘challenge’ going around in performant programming language circles: can each language implement an easy-to-read Pythagorean triple generator using (infinite) ranges? See, e.g., https://atilanevesoncode.wordpress.com/2018/12/31/comparing-pythagorean-triples-in-c-d-and-rust/ . I found this amusing because, coincidentally, I recently submitted a PR to the Scala documentation site that uses Pythagorean triple generation as an example of the power and elegance of for-comprehensions: https://github.com/scala/docs.scala-lang/pull/1235

So naturally I tried converting my example to use an infinite range (concretely, an Iterator):

def naturals = Iterator.from(1)

def pythagoreanTriples = for {
  a <- naturals
  b <- naturals if b > a
  c <- naturals if a * a + b * b == c * c
} yield (a, b, c)

pythagoreanTriples take 2 foreach println

However, this generator seems to get stuck indefinitely. I tried to rewrite it with slightly more precise functions:

def naturals = Iterator.from(1)

def pythagoreanTriples = naturals.flatMap { a =>
  naturals.dropWhile(a.>=).flatMap { b =>
    naturals.collect {
      case c if c * c == a * a + b * b => (a, b, c)
    }
  }
}

pythagoreanTriples take 2 foreach println

However, this version suffers from the same problem: it gets stuck indefinitely.

Can anyone shed light on why the iterators are seemingly not iterating?

If you use the same iterator, then once a value (say 1) has been used for a, I don’t think it will also be used for b. As far as I can see, there are two ways around this:

  • Use three different iterators, with the same definition
  • Use a stream

regards,
Siddhartha

Thank you. I believe I am using three different iterators at runtime (they are generated by the naturals method calls). I’ll keep investigating further as to where the freeze is happening.
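A quick way to check the point about fresh iterators: because naturals is a def, every reference re-evaluates Iterator.from(1) and returns a new, independent iterator, whereas a val would be a single iterator shared (and consumed) by every generator. The object name and sharedNaturals below are hypothetical, added only for the contrast:

```scala
object IteratorSharing {
  // `def`: Iterator.from(1) is evaluated anew on every call,
  // so each generator in a for-comprehension gets its own iterator.
  def naturals: Iterator[Int] = Iterator.from(1)

  // Hypothetical contrast: a `val` is evaluated once, so every use
  // shares, and advances, the same iterator.
  val sharedNaturals: Iterator[Int] = Iterator.from(1)

  def demo(): (Int, Int) = {
    val first = naturals
    first.next(); first.next()          // advance one iterator past 1 and 2
    val fresh = naturals.next()         // a new iterator still starts at 1

    sharedNaturals.next(); sharedNaturals.next()
    val shared = sharedNaturals.next()  // the shared iterator has moved on to 3
    (fresh, shared)
  }
}
```

So the three calls in the original code really are three separate iterators; the hang has a different cause.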

Your code loops over Ints indefinitely for c while a = 1 and b = 2 (the first b that passes b > a), so it never gets to any triple.

I’m not sure what you mean by short-circuiting traversal and how it will fix this. Perhaps you want to stop iterating after the first hit? You could call the find method for that. But it won’t help if for a given a and b there is no solution for c.
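To make both points concrete: find does stop at the first hit, but only if a hit exists. For a = 1 and b = 2 the condition is c * c == 5, and an Int square is never 5 mod 8 (even after wraparound, since 2^32 is a multiple of 8), so an unbounded search never returns. The object name, the limit parameter and the bound of 100000 are arbitrary choices to keep this sketch terminating:

```scala
object WhyItHangs {
  // find short-circuits: it returns as soon as the predicate holds.
  def firstRootOf25: Option[Int] = Iterator.from(1).find(c => c * c == 25)

  // With no matching c, an unbounded Iterator.from(1).find(...) would
  // never return; capping the iterator with take(limit) makes it finite.
  def boundedSearch(a: Int, b: Int, limit: Int): Option[Int] =
    Iterator.from(1).take(limit).find(c => c * c == a * a + b * b)
}
```

boundedSearch(3, 4, 100000) finds 5, while boundedSearch(1, 2, 100000) comes back empty; without the cap it would spin forever, which is exactly what the original for-comprehension does.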

You should have only one indefinite iterator, and it should be the outermost loop (i.e. the first generator in the for-comprehension). You could have an indefinite iterator for c and use 1 < a < b < c to get finite ranges for a and b.

But since you speak of performance: iterating over all three variables and filtering is, of course, extremely inefficient. It is much faster to iterate over two of them and calculate the third.
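A sketch of "iterate over two of them and calculate the third", keeping c as the only unbounded generator as suggested above: for each c and each b < c it computes a from c * c - b * b and keeps the pair only if that is a perfect square. The helper isqrt and the object name are hypothetical (written here, not library calls), and the Int arithmetic still overflows once c exceeds 46340:

```scala
object ComputeThird {
  // Hypothetical helper: floor of the square root of a non-negative Int,
  // using a floating-point estimate and then correcting it exactly.
  def isqrt(n: Int): Int = {
    var r = math.sqrt(n.toDouble).toInt
    while (r.toLong * r > n) r -= 1                // estimate too large
    while ((r + 1).toLong * (r + 1) <= n) r += 1   // estimate too small
    r
  }

  // c is the only infinite generator, b is finite, a is computed.
  def pythagoreanTriples: Iterator[(Int, Int, Int)] = for {
    c  <- Iterator.from(1)
    b  <- 1 until c
    a2 = c * c - b * b
    a  = isqrt(a2)
    if a < b && a * a == a2   // a < b keeps (3, 4, 5) but drops (4, 3, 5)
  } yield (a, b, c)
}
```

This does O(c) work per value of c instead of the O(c^2) of a triple loop.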

The main issue here is that Iterator.from(1) is infinite and wraps around on overflow. This is in agreement with the scaladoc, but it is also pretty surprising behaviour to me.

That means that Iterator.from(1).filter(_ => false).next() will loop forever.

A non-looping version, def positiveInts = Iterator.from(1).takeWhile(i => i + 1 > 0), should terminate, though I can’t check right now since I’m on my phone.
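Both behaviours can be observed without waiting two billion steps by starting near Int.MaxValue. The object name and positiveIntsFrom (a version of positiveInts parameterised by its start) are hypothetical, introduced only so the sketch finishes quickly:

```scala
object Wraparound {
  // Iterator.from never ends: after Int.MaxValue it silently wraps to Int.MinValue.
  def afterMax: List[Int] = Iterator.from(Int.MaxValue).take(2).toList

  // The takeWhile guard from the post: i + 1 > 0 first fails at
  // i == Int.MaxValue (where i + 1 overflows), so iteration stops just before it.
  def positiveIntsFrom(start: Int): Iterator[Int] =
    Iterator.from(start).takeWhile(i => i + 1 > 0)
}
```

afterMax is List(2147483647, -2147483648), and positiveIntsFrom(Int.MaxValue - 2) yields only Int.MaxValue - 2 and Int.MaxValue - 1 before stopping.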

It terminates, but you will find (1,4,869476073).

scala> 1 * 1 + 4 * 4
res2: Int = 17

scala> 869476073 * 869476073
res3: Int = 17
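If you would rather have overflow fail loudly than yield a bogus triple like (1,4,869476073), java.lang.Math.multiplyExact and Math.addExact throw an ArithmeticException on Int overflow instead of wrapping. A sketch of a checked test; the object and the name isTripleChecked are hypothetical:

```scala
import scala.util.Try

object CheckedArithmetic {
  // multiplyExact / addExact throw ArithmeticException on Int overflow
  // rather than wrapping around, so Try turns overflow into a Failure.
  def isTripleChecked(a: Int, b: Int, c: Int): Try[Boolean] = Try {
    val lhs = Math.addExact(Math.multiplyExact(a, a), Math.multiplyExact(b, b))
    lhs == Math.multiplyExact(c, c)
  }
}
```

isTripleChecked(3, 4, 5) is Success(true), while isTripleChecked(1, 4, 869476073) is a Failure because squaring 869476073 overflows.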

def naturals = Iterator.from(1).takeWhile(i => i * i / i == i) would work (it stops at the first i whose square overflows Int), but it would still take a very long time. And I guess the challenge is to implement an algorithm that works for unbounded numbers like BigInt.
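A minimal sketch of the unbounded variant, assuming BigInt arithmetic is acceptable and following the suggestion above that c be the only unbounded generator. Iterator.iterate builds the infinite sequence of BigInt naturals, and takeWhile cuts b and a down to finite ranges (the object name is hypothetical):

```scala
object UnboundedTriples {
  // Infinite naturals as BigInt: no overflow and no wraparound.
  def naturals: Iterator[BigInt] = Iterator.iterate(BigInt(1))(_ + 1)

  // Only c is unbounded; b < c and a < b are finite thanks to takeWhile.
  def pythagoreanTriples: Iterator[(BigInt, BigInt, BigInt)] = for {
    c <- naturals
    b <- naturals.takeWhile(_ < c)
    a <- naturals.takeWhile(_ < b)
    if a * a + b * b == c * c
  } yield (a, b, c)
}
```

This trades speed for correctness: BigInt arithmetic is much slower than Int, but every triple it reports is a genuine one.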

What is the goal exactly? What’s wrong with:

val p = for {
  b <- Iterator.from(1)
  a <- 1 until b
} yield (a, b, a * a + b * b)

We’re looking for a c equal to math.sqrt(a * a + b * b) where c is an integer. So the following would be a relatively efficient implementation, disregarding integer overflow:

def pythagoreanTriples = (for {
  b <- Iterator.from(1)
  a <- 1 until b
  c2 = a * a + b * b
} yield (a, b, math.sqrt(c2).toInt, c2)).collect { case (a, b, c, c2) if c * c == c2 => (a, b, c) }

Got you. I forgot what Pythagorean triples were. Could also be written as:

val p = for {
  b <- Iterator.from(1)
  a <- 1 until b
  c2 = a * a + b * b
  c = Math.sqrt(c2).toInt
  if c * c == c2
} yield (a, b, c)

Yes indeed 😅 I was just about to edit my post.

Sure, I intended to answer OP’s bafflement, though, not do their exercises for them.

Thank you! The main breakthrough for me was your suggestion that there should be only one infinite iteration. @charpov’s example helped too, although I felt that I could avoid the square root calculation. This works nicely:

def pythagoreanTriples = for {
  c <- Iterator.from(1)
  b <- 1 until c
  a <- 1 until b if c * c == a * a + b * b
} yield (a, b, c)

Just to clarify, I spoke about ‘performant programming language circles’ but I’m not necessarily looking to performance tune myself 🙂 I just want to avoid ‘accidentally quadratic’, which your hint achieved!

scala> pythagoreanTriples take 3 foreach println
(3,4,5)
(6,8,10)
(5,12,13)

Edit: I’d also like to say I achieved my original goal, which was to demonstrate (even if just to myself) that Scala’s expressive power is light-years ahead of several other languages even if they are more performant, etc.


Yes, that’s the way. The beauty of Scala is that this is much more readable than the equivalent:

def pythagoreanTriples =
  Iterator.from(1).flatMap { c =>
    (1 until c).flatMap { b =>
      (1 until b).withFilter(a => c * c == a * a + b * b).map(a => (a, b, c))
    }
  }
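One way to convince yourself that the combinator chain really is what the for-comprehension desugars to is to run both and compare their output. A small check, with hypothetical object and method names:

```scala
object DesugarCheck {
  // The for-comprehension version from the thread.
  def viaFor: Iterator[(Int, Int, Int)] = for {
    c <- Iterator.from(1)
    b <- 1 until c
    a <- 1 until b if c * c == a * a + b * b
  } yield (a, b, c)

  // The hand-written flatMap / withFilter / map version.
  def viaCombinators: Iterator[(Int, Int, Int)] =
    Iterator.from(1).flatMap { c =>
      (1 until c).flatMap { b =>
        (1 until b).withFilter(a => c * c == a * a + b * b).map(a => (a, b, c))
      }
    }
}
```

Taking the first few elements of each (for example with take(5).toList) gives identical lists, beginning (3,4,5), (6,8,10), (5,12,13).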