Minimum Cost Tree From Leaf Values

  26 Jul 2019

By Wen Xu

algorithm dynamic programming

Problem description

Given an array arr of positive integers, consider all binary trees such that:

Each node has either 0 or 2 children;
The values of arr correspond to the values of each leaf in an in-order traversal of the tree. (Recall that a node is a leaf if and only if it has 0 children.)
The value of each non-leaf node is equal to the product of the largest leaf value in its left and right subtree respectively.

Among all possible binary trees considered, return the smallest possible sum of the values of each non-leaf node. It is guaranteed this sum fits into a 32-bit integer.

Example 1:

Input: arr = [6,2,4]
Output: 32
Explanation:
There are two possible trees.  The first has non-leaf node sum 36, and the second has non-leaf node sum 32.

    24            24
   /  \          /  \
  12   4        6    8
 /  \               / \
6    2             2   4
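
The two sums quoted in the explanation can be checked mechanically. The sketch below is only an illustration: the helper name tree_cost and the nested-tuple encoding of a tree (an int is a leaf, a pair is a non-leaf node) are assumptions made here, not part of the problem statement.

```python
def tree_cost(tree):
    """Return (largest leaf, sum of non-leaf node values) for a tree.

    A leaf is an int; a non-leaf node is a (left, right) tuple.
    """
    if isinstance(tree, int):
        return tree, 0  # a leaf contributes no non-leaf value
    left_max, left_sum = tree_cost(tree[0])
    right_max, right_sum = tree_cost(tree[1])
    node_value = left_max * right_max
    return max(left_max, right_max), left_sum + right_sum + node_value


first = ((6, 2), 4)   # the tree drawn on the left
second = (6, (2, 4))  # the tree drawn on the right
print(tree_cost(first)[1])   # 12 + 24 = 36
print(tree_cost(second)[1])  # 8 + 24 = 32
```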
 

Constraints:

2 <= arr.length <= 40
1 <= arr[i] <= 15
It is guaranteed that the answer fits into a 32-bit signed integer (i.e., it is less than 2^31).

Solution

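We can solve the problem using dynamic programming. Let DP[i][j] be the answer for the subarray arr[i] to arr[j], with both i and j inclusive. The base case is DP[i][i] = 0, since a single leaf has no non-leaf node. For i < j, we try every split point k with i <= k < j, where arr[i..k] forms the left subtree and arr[k+1..j] forms the right subtree, and keep the cheapest:

DP[i][j] = min over i <= k < j of ( DP[i][k] + DP[k + 1][j] + max(arr[i:k+1]) * max(arr[k+1:j+1]) )

There are O(n^2) subarrays and up to n split points for each, so the table takes O(n^3) steps to fill if the maxima are available in constant time.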
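
As a minimal sketch, the recurrence can be translated almost literally into a memoized function. The name min_cost_tree is hypothetical, and the maxima are recomputed with slices purely for readability, which costs an extra factor of n compared with caching them.

```python
from functools import lru_cache
from typing import List


def min_cost_tree(arr: List[int]) -> int:
    """Literal, unoptimized translation of the recurrence (illustrative name)."""

    @lru_cache(maxsize=None)
    def dp(i: int, j: int) -> int:
        if i == j:
            return 0  # a single leaf has no non-leaf node
        # Try every split point k: left part arr[i..k], right part arr[k+1..j].
        return min(
            dp(i, k) + dp(k + 1, j) + max(arr[i:k + 1]) * max(arr[k + 1:j + 1])
            for k in range(i, j)
        )

    return dp(0, len(arr) - 1)


print(min_cost_tree([6, 2, 4]))  # 32
```

For arr = [6,2,4], splitting after the 6 gives 0 + 8 + 6*4 = 32 and splitting after the 2 gives 12 + 0 + 6*4 = 36, so the minimum is 32.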

Below is my Python implementation:

from typing import List


class Solution:
    def mctFromLeafValues(self, arr: List[int]) -> int:
        '''
        First idea (pseudocode): split the array and recurse on both parts.

        left, resleft = mctFromLeafValues(arrleft)
        right, resright = mctFromLeafValues(arrright)
        root = left*right
        res = resleft + resright + root
        return min(res) over all split points

        mctFromLeafValues(arr[i:j])
        '''
        '''
        Top-down DP (memoized recursion), kept for reference.

        import functools

        @functools.lru_cache(maxsize=None)
        def DP(i, j):
            # Returns (largest leaf in arr[i..j], minimum cost for arr[i..j]).
            if i == j:
                return arr[i], 0
            res = float('inf')
            mx = 0
            for k in range(i, j):
                left, leftres = DP(i, k)
                right, rightres = DP(k + 1, j)
                if left * right + leftres + rightres < res:
                    res = left * right + leftres + rightres
                    mx = max(left, right)
            return mx, res

        return DP(0, len(arr) - 1)[1]

        The bottom-up DP follows.
        '''
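        n = len(arr)
        # DP[i][j] = (largest leaf in arr[i..j], minimum cost for arr[i..j])
        DP = [[(-float('inf'), float('inf')) for j in range(n)] for i in range(n)]

        # Base case, length 1: a single leaf has no non-leaf node.
        for i in range(0, n):
            DP[i][i] = (arr[i], 0)

        # Fill in subarrays of increasing length l.
        for l in range(2, n + 1):
            for i in range(0, n + 1 - l):
                j = l + i - 1
                res = float('inf')
                mx = -float('inf')
                # Try every split point k: left arr[i..k], right arr[k+1..j].
                for k in range(i, j):
                    left, leftres = DP[i][k]
                    right, rightres = DP[k + 1][j]
                    if left * right + leftres + rightres < res:
                        res = left * right + leftres + rightres
                        mx = max(left, right)
                # Store the result once, after all split points were tried.
                DP[i][j] = (mx, res)
        return DP[0][n - 1][1]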
				