Smallest Subtree with all the Deepest Nodes (LC865)

  29 May 2019

By Wen Xu

leetcode algorithm binary tree depth first search

Problem description

Given a binary tree rooted at root, the depth of each node is the shortest distance to the root.

A node is deepest if it has the largest depth possible among any node in the entire tree.

The subtree of a node is that node, plus the set of all descendants of that node.

Return the node with the largest depth such that it contains all the deepest nodes in its subtree.

Example 1:

Input: [3,5,1,6,2,0,8,null,null,7,4]

Output: [2,7,4]

Explanation:

We return the node with value 2, colored in yellow in the diagram. The nodes colored in blue are the deepest nodes of the tree. The input “[3, 5, 1, 6, 2, 0, 8, null, null, 7, 4]” is a serialization of the given tree. The output “[2, 7, 4]” is a serialization of the subtree rooted at the node with value 2. Both the input and output have TreeNode type.
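
To try the example locally, it helps to convert between the list form and actual TreeNode objects. The following is a minimal sketch, assuming the usual level-order convention where null (None in Python) marks a missing child; the helper names build_tree and serialize are made up for illustration and are not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Level-order serialization with trailing None entries trimmed."""
    out = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out
```

With these helpers, the node with value 2 in Example 1 is build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]).left.right, and passing it to serialize gives the list form of the expected output.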

Solution

  1. First, find the depth of the tree with a DFS.
  2. Second, collect the set of nodes whose depth equals the depth of the entire tree. These are the deepest nodes.
  3. The problem then reduces to finding the lowest common ancestor of the nodes collected in step 2. Because these nodes all sit at the same depth, we can repeatedly replace every node in the set with its parent until the set shrinks to a single node.

Below is my Python implementation:

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

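class Solution:
    def subtreeWithAllDeepest(self, root: TreeNode) -> TreeNode:

        # Step 1: depth of the tree, counting the root as depth 1.
        def findD(root: TreeNode) -> int:
            if not root:
                return 0
            return 1 + max(findD(root.left), findD(root.right))

        dt = findD(root)

        nodes = set()   # the deepest nodes
        parents = {}    # node -> parent (the root maps to None)

        # Step 2: record every node's parent and collect the nodes at depth dt.
        def finddepth(root: TreeNode, d: int, parent: TreeNode) -> None:
            if not root:
                return
            parents[root] = parent
            if d == dt:
                nodes.add(root)
            finddepth(root.left, d + 1, root)
            finddepth(root.right, d + 1, root)

        finddepth(root, 1, None)

        # Step 3: all nodes in the set share the same depth, so moving each
        # one up to its parent keeps them level until they meet at the LCA.
        while len(nodes) > 1:
            nodes = {parents[node] for node in nodes}

        return nodes.pop()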
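
For comparison, the same answer can also be obtained in a single traversal. This is not the approach described above but a different, commonly used technique: each recursive call returns both the height of its subtree and the node that covers all of that subtree's deepest nodes. The function names subtree_with_all_deepest and visit are hypothetical, and the sketch assumes the same TreeNode definition as the LeetCode template.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def subtree_with_all_deepest(root: Optional[TreeNode]) -> Optional[TreeNode]:
    def visit(node: Optional[TreeNode]) -> Tuple[int, Optional[TreeNode]]:
        # Returns (height of this subtree, node covering its deepest nodes).
        if node is None:
            return 0, None
        left_height, left_answer = visit(node.left)
        right_height, right_answer = visit(node.right)
        if left_height > right_height:
            # All deepest nodes are on the left side.
            return left_height + 1, left_answer
        if right_height > left_height:
            # All deepest nodes are on the right side.
            return right_height + 1, right_answer
        # Deepest nodes on both sides (or a leaf): this node is the answer.
        return left_height + 1, node

    return visit(root)[1]
```

Both versions visit each node a constant number of times. The single-pass variant avoids the parents dictionary at the cost of a slightly less obvious recursion.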
				