`withDefaultValue` on a TreeMap makes it slow

I was solving one of the popular LeetCode questions: Find Median from Data Stream. Here is my attempt, which causes a memory limit error on LeetCode's online evaluator and runs noticeably slowly on my local machine:

    class MedianFinder:
        import scala.collection.mutable.TreeMap

        private val leftCounts = TreeMap[Int, Int]() withDefaultValue 0
        private var leftCumSize = 0
        private val rightCounts = TreeMap[Int, Int]() withDefaultValue 0
        private var rightCumSize = 0

        def addNum(num: Int) =
            if leftCumSize == 0 || num <= leftCounts.lastKey then
                leftCounts(num) += 1
                leftCumSize += 1
            else if num > leftCounts.lastKey then
                rightCounts(num) += 1
                rightCumSize += 1
            end if

            if leftCumSize == rightCumSize + 2 then
                val (maxLeft, maxLeftCount) = leftCounts.last
                require(maxLeftCount > 0)
                leftCounts(maxLeft) -= 1
                leftCumSize -= 1
                if maxLeftCount == 1 then leftCounts.remove(maxLeft)

                rightCounts(maxLeft) += 1
                rightCumSize += 1
            else if leftCumSize + 1 == rightCumSize then
                val (minRight, minRightCount) = rightCounts.head
                require(minRightCount > 0)
                rightCounts(minRight) -= 1
                rightCumSize -= 1
                if minRightCount == 1 then rightCounts.remove(minRight)

                leftCounts(minRight) += 1
                leftCumSize += 1
            else require(leftCumSize == rightCumSize + 1 || leftCumSize == rightCumSize)
            end if

            require(leftCumSize == rightCumSize + 1 || leftCumSize == rightCumSize)
        end addNum

        def findMedian(): Double =
            if leftCumSize == rightCumSize then
                if leftCumSize != 0 then (leftCounts.lastKey + rightCounts.firstKey) / 2.0
                else 0
            else if leftCumSize == rightCumSize + 1 then
                leftCounts.lastKey.toDouble
            else ???
            end if
        end findMedian

    end MedianFinder

I didn't want to change the structure of my solution, as it is good enough, and I also didn't want to spend time profiling it to find which intermediate objects were making it slow. So I asked Gemini, which suggested removing withDefaultValue from my instantiation of TreeMap and using .getOrElse(num, 0) instead. Surprisingly, that solves the problem!
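
For reference, the change amounts to something like the following. This is only a minimal sketch: the helper names `increment` and `decrement` are hypothetical and do not appear in my solution, where the same pattern simply replaces each `+= 1` and `-= 1` on the count maps.

```scala
import scala.collection.mutable.TreeMap

// Hypothetical helper: bump the count for `key` on a plain TreeMap,
// supplying the default at the call site instead of via withDefaultValue.
def increment(m: TreeMap[Int, Int], key: Int): Unit =
    m(key) = m.getOrElse(key, 0) + 1

// Hypothetical helper: lower the count for `key`, removing the entry
// once the count would reach zero.
def decrement(m: TreeMap[Int, Int], key: Int): Unit =
    val current = m.getOrElse(key, 0)
    if current <= 1 then m.remove(key)
    else m(key) = current - 1
```

With this, the maps are declared as plain `TreeMap[Int, Int]()` and calls such as `lastKey`, `last` and `head` go directly to the `TreeMap` rather than through a wrapper.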

I checked the source of SortedMap.WithDefault and it looks like this:

    override def subtractOne(elem: K): WithDefault.this.type = { underlying.subtractOne(elem); this }

    override def addOne(elem: (K, V)): WithDefault.this.type = { underlying.addOne(elem); this }

and

    def get(key: K): Option[V] = underlying.get(key)

There are no extra object creations here; the methods are simply called on the underlying map. Why should withDefaultValue be so slow?
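
One way to narrow this down without a full profiler would be to time a single operation on both variants. The sketch below does that for `lastKey`, chosen only because `addNum` calls it on every insertion; I have not established that it is the culprit. The helpers `timed` and `timeLastKey` and the sizes used are assumptions made for illustration.

```scala
import scala.collection.mutable

// Hypothetical helper: runs `body` once and returns its result
// together with the elapsed wall-clock time in milliseconds.
def timed[A](body: => A): (A, Double) =
    val start = System.nanoTime()
    val result = body
    (result, (System.nanoTime() - start) / 1e6)

// Fills `m` with the keys 0 until n, then times `reps` calls to lastKey.
// Returns the last key that was read and the elapsed milliseconds.
def timeLastKey(m: mutable.SortedMap[Int, Int], n: Int, reps: Int): (Int, Double) =
    for i <- 0 until n do m(i) = 1
    timed {
        var last = -1
        for _ <- 0 until reps do last = m.lastKey
        last
    }

@main def compareLastKey(): Unit =
    val plain = mutable.TreeMap.empty[Int, Int]
    val wrapped = mutable.TreeMap.empty[Int, Int].withDefaultValue(0)
    val (a, plainMs) = timeLastKey(plain, 100_000, 1_000)
    val (b, wrappedMs) = timeLastKey(wrapped, 100_000, 1_000)
    println(s"plain: $plainMs ms, wrapped: $wrappedMs ms, same result: ${a == b}")
```

A single untimed warm-up run and several repetitions would make the numbers more trustworthy; the same harness can be pointed at `last`, `head` or `apply` by swapping the call inside the loop.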
