Problem description
An undirected, connected tree with N nodes labelled 0…N-1 and N-1 edges is given.
The ith edge connects nodes edges[i][0] and edges[i][1] together.
Return a list ans, where ans[i] is the sum of the distances between node i and all other nodes.
Example 1:
Input: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation:
Here is a diagram of the given tree:
  0
 / \
1   2
   /|\
  3 4 5
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8. Hence, ans[0] = 8, and so on.
Note: 1 <= N <= 10000
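As a sanity check on the example, the expected output can be reproduced with a brute-force sketch that runs one BFS per node. This is only a reference for small inputs, not the intended solution: it takes O(N^2) time, and the name brute_force_sums is a hypothetical helper that does not appear in the original problem.

```python
from collections import deque
from typing import List


def brute_force_sums(n: int, edges: List[List[int]]) -> List[int]:
    """Hypothetical helper: run one BFS per node and sum the distances."""
    # Build an undirected adjacency list.
    adj = [[] for _ in range(n)]
    for x, y in edges:
        adj[x].append(y)
        adj[y].append(x)

    result = []
    for start in range(n):
        dist = [-1] * n          # -1 marks "not visited yet"
        dist[start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        # The tree is connected, so every entry is a real distance.
        result.append(sum(dist))
    return result


print(brute_force_sums(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```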
Solution
- Suppose nodes x and y are joined by an edge, with y the parent of x once the tree is rooted. Then the answer for x can be computed from the answer for y. Let dist(x) be the sum of distances from x to every other node, and let f(x) be the number of nodes in the subtree rooted at x (including x itself). Then dist(x) = dist(y) - f(x) + (N - f(x)).
- The equation holds because moving from y to x decreases the distance by 1 for each node in the subtree rooted at x. There are f(x) such nodes, so we subtract 1 × f(x).
- At the same time, the distance increases by 1 for every node outside the subtree rooted at x. There are N - f(x) such nodes, so we add 1 × (N - f(x)), where N is the total number of nodes in the tree.
- So we first use a DFS to compute f(x) for every node and the correct dist(x) for one node. We choose that node to be 0, the root.
- Then we use a second DFS to compute the correct dist(x) for all other nodes in the tree rooted at 0.
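To make the formula concrete, here is a small sketch that applies it to the tree from Example 1. The subtree sizes are read off the diagram by hand, and reroot is a hypothetical helper name used only for this illustration.

```python
def reroot(dist_parent: int, f_child: int, n: int) -> int:
    """Hypothetical helper: dist(x) = dist(y) - f(x) + (N - f(x))."""
    return dist_parent - f_child + (n - f_child)


N = 6
dist_0 = 8                   # dist(0), taken from Example 1
f = {1: 1, 2: 4, 3: 1}       # subtree sizes with the tree rooted at 0

dist_1 = reroot(dist_0, f[1], N)   # 8 - 1 + 5
dist_2 = reroot(dist_0, f[2], N)   # 8 - 4 + 2
dist_3 = reroot(dist_2, f[3], N)   # 6 - 1 + 5
print(dist_1, dist_2, dist_3)      # 12 6 10
```

These values agree with the expected output [8,12,6,10,10,10] at indices 1, 2 and 3.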
Below is my Python implementation:
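import collections
import sys
from typing import List, Tuple

# Both DFS passes recurse once per tree level. A path-shaped tree with
# N = 10000 nodes is deeper than Python's default recursion limit of 1000.
sys.setrecursionlimit(20000)


class Solution:
    def sumOfDistancesInTree(self, N: int, edges: List[List[int]]) -> List[int]:
        # Adjacency list: each entry holds a node's neighbours (its parent included).
        children = collections.defaultdict(list)
        for x, y in edges:
            children[x].append(y)
            children[y].append(x)
        ans = [0] * N    # ans[i]: sum of distances from node i
        count = [0] * N  # count[i]: size of the subtree rooted at i, including i

        def dfs1(node, parent) -> Tuple[int, int]:
            # Post-order pass: returns (distance sum inside node's subtree, subtree size).
            res = 0
            cn = 1
            for child in children[node]:
                if child != parent:
                    depth, c = dfs1(child, node)
                    res += depth + c  # each node under child is one edge farther from node
                    cn += c
            ans[node] = res
            count[node] = cn
            return res, cn

        def dfs2(node, parent) -> None:
            # Pre-order pass: derive each child's answer from its parent's answer.
            for child in children[node]:
                if child != parent:
                    ans[child] = ans[node] - count[child] + N - count[child]
                    dfs2(child, node)

        dfs1(0, None)  # after this, only ans[0] is the full answer
        dfs2(0, None)
        return ans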
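Since N can reach 10000, a recursive DFS can run very deep on a path-shaped tree. One possible alternative, sketched here as an assumption rather than as part of the solution above, replaces both passes with loops over an explicit visiting order. The function name sum_of_distances_iterative is hypothetical.

```python
from typing import List


def sum_of_distances_iterative(n: int, edges: List[List[int]]) -> List[int]:
    """Hypothetical non-recursive variant of the two-pass approach."""
    adj = [[] for _ in range(n)]
    for x, y in edges:
        adj[x].append(y)
        adj[y].append(x)

    # Walk the tree from node 0 with an explicit stack. Every parent is
    # appended to `order` before any of its children.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in adj[node]:
            if nxt != parent[node]:
                parent[nxt] = node
                stack.append(nxt)

    count = [1] * n  # subtree sizes, each node counts itself
    ans = [0] * n

    # First pass, children before parents: accumulate subtree sizes and
    # the distance sums inside each subtree.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            count[p] += count[node]
            ans[p] += ans[node] + count[node]

    # Second pass, parents before children: apply the re-rooting formula.
    for node in order:
        p = parent[node]
        if p != -1:
            ans[node] = ans[p] - count[node] + (n - count[node])
    return ans


print(sum_of_distances_iterative(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
# [8, 12, 6, 10, 10, 10]
```

The logic is the same as dfs1 followed by dfs2; only the traversal mechanism changes, so the recursion limit no longer matters.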