Sum of Distances in Tree (LC834)

  25 May 2019

By Wen Xu

leetcode algorithm depth first search tree

Problem description

An undirected, connected tree with N nodes labelled 0…N-1 and N-1 edges is given.

The ith edge connects nodes edges[i][0] and edges[i][1] together.

Return a list ans, where ans[i] is the sum of the distances between node i and all other nodes.

Example 1:

Input: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: 
Here is a diagram of the given tree:
  0
 / \
1   2
   /|\
  3 4 5
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.  Hence, answer[0] = 8, and so on.
Note: 1 <= N <= 10000
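
As a baseline for checking the example by hand, here is a minimal brute-force sketch that runs one breadth-first search per node. The function name brute_force_sums is my own and is not part of the problem statement. It takes O(N^2) time, so it is only meant as a reference for small inputs, not as a solution for N up to 10000.

```python
from collections import deque
from typing import List


def brute_force_sums(n: int, edges: List[List[int]]) -> List[int]:
    """Sum of distances from every node, using one BFS per node (O(n^2) time)."""
    adj = [[] for _ in range(n)]
    for x, y in edges:
        adj[x].append(y)
        adj[y].append(x)

    result = []
    for start in range(n):
        dist = [-1] * n  # -1 marks a node that has not been visited yet
        dist[start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        result.append(sum(dist))
    return result


print(brute_force_sums(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```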

Solution

  1. Suppose two nodes x and y are joined by an edge, with y the parent of x once the tree is rooted. Then the answer for x can be calculated from the answer for y. Let f(x) be the number of nodes in the subtree rooted at x (including x itself), and let dist(x) be the sum of distances from x to all other nodes. Then dist(x) = dist(y) - f(x) + (N - f(x)).
  2. The equation holds because when we move from node y to x, the distance decreases by 1 for each node in the subtree rooted at x. There are f(x) nodes in that subtree, so we need to subtract 1 x f(x).
  3. At the same time, the distance increases by 1 for every node outside the subtree rooted at x. There are N - f(x) such nodes, so we need to add 1 x (N - f(x)). N is the total number of nodes in the tree.
  4. So we first use a dfs to calculate f(x) for every node and the correct dist(x) for one node. We choose that node to be 0, the root.
  5. Then we use a second dfs from node 0, applying the equation along each edge, to calculate the correct dist(x) for all other nodes in the tree.
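
To make the equation concrete, here is a small sketch that applies it to Example 1 rooted at node 0. The helper name reroot is my own, and the subtree sizes (4 for node 2, 1 for nodes 1 and 3) are read directly off the diagram.

```python
def reroot(dist_parent: int, subtree_size: int, n: int) -> int:
    """Answer for a child x, given the answer for its parent y and the subtree size of x."""
    # subtree_size nodes get one step closer; the other n - subtree_size get one step farther.
    return dist_parent - subtree_size + (n - subtree_size)


n = 6
dist0 = 8                    # answer for the root, from the example
dist2 = reroot(dist0, 4, n)  # subtree of node 2 is {2, 3, 4, 5}
dist1 = reroot(dist0, 1, n)  # node 1 is a leaf
dist3 = reroot(dist2, 1, n)  # node 3 is a leaf whose parent is node 2
print(dist2, dist1, dist3)   # 6 12 10
```

These match the expected output [8,12,6,10,10,10] for nodes 2, 1 and 3.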

Below is my Python implementation:

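import collections
from typing import List, Tuple


class Solution:
    def sumOfDistancesInTree(self, N: int, edges: List[List[int]]) -> List[int]:
        # Adjacency list. The tree is undirected, so each node's list also holds its parent.
        children = collections.defaultdict(list)
        for x, y in edges:
            children[x].append(y)
            children[y].append(x)
        ans = [0] * N    # ans[i]: sum of distances from node i
        count = [0] * N  # count[i]: size of the subtree rooted at i, including i itself

        # Note: recursion depth can reach N on a path-shaped tree, which is more
        # than CPython's default recursion limit of 1000.
        def dfs1(node, parent) -> Tuple[int, int]:
            # Post-order pass: fill count[] and the distance sum inside each subtree.
            res = 0
            cn = 1
            for child in children[node]:
                if child != parent:
                    depth, c = dfs1(child, node)
                    # Each of the c nodes under child is one edge farther from node.
                    res += depth + c
                    cn += c
            ans[node] = res
            count[node] = cn
            return res, cn

        def dfs2(node, parent) -> None:
            # Pre-order pass: derive each child's answer from its parent's answer.
            for child in children[node]:
                if child != parent:
                    ans[child] = ans[node] - count[child] + (N - count[child])
                    dfs2(child, node)

        dfs1(0, None)  # ans[0] is now final; other entries are subtree-only sums
        dfs2(0, None)
        return ans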
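
Since N can be as large as 10000 and CPython's default recursion limit is 1000, a path-shaped tree can make a recursive solution fail. Below is a hedged sketch of the same two passes written without recursion. The function name sum_of_distances_iterative and the parent/order arrays are my own choices, not part of the original solution.

```python
from typing import List


def sum_of_distances_iterative(n: int, edges: List[List[int]]) -> List[int]:
    adj = [[] for _ in range(n)]
    for x, y in edges:
        adj[x].append(y)
        adj[y].append(x)

    # Visit nodes from root 0 with an explicit stack.
    # In `order`, every node appears before all of its descendants.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in adj[node]:
            if nxt != parent[node]:
                parent[nxt] = node
                stack.append(nxt)

    count = [1] * n  # subtree sizes, each node counts itself
    ans = [0] * n

    # Pass 1 (descendants before ancestors): subtree sizes and subtree distance sums.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            count[p] += count[node]
            ans[p] += ans[node] + count[node]

    # Pass 2 (ancestors before descendants): re-root from each parent to its child.
    for node in order:
        p = parent[node]
        if p != -1:
            ans[node] = ans[p] - count[node] + (n - count[node])
    return ans


print(sum_of_distances_iterative(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```

Both passes touch each edge a constant number of times, so the running time stays O(N).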
				