LeetCode 834: Sum of Distances in Tree

dume0011
3 min read · Aug 2, 2023


There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.

You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.

Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.

Example 1:

Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: The tree is shown above.
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.
Hence, answer[0] = 8, and so on.

Example 2:

Input: n = 1, edges = []
Output: [0]

Example 3:

Input: n = 2, edges = [[1,0]]
Output: [1,1]

Constraints:

  • 1 <= n <= 3 * 10^4
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= ai, bi < n
  • ai != bi
  • The given input represents a valid tree.

Problem Analysis:

Given the edges, the straightforward approach is to traverse the tree from every node and add up the distances to all other nodes. It’s simple, but the time complexity is O(n²).
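As a baseline, the quadratic approach can be sketched as one breadth-first search per node. This is only an illustration of the idea, not part of the final solution; the function name bruteForceSums is made up for this sketch.

```cpp
#include <queue>
#include <vector>

// Brute-force baseline: run one BFS from every node and add up the depths.
// Each BFS costs O(n) on a tree, so the total cost is O(n^2).
// bruteForceSums is a hypothetical name used only in this sketch.
std::vector<int> bruteForceSums(int n, const std::vector<std::vector<int>>& edges) {
    // Build an adjacency list from the edge list.
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }

    std::vector<int> answer(n, 0);
    for (int src = 0; src < n; ++src) {
        std::vector<int> dist(n, -1);  // -1 marks "not visited yet"
        std::queue<int> q;
        dist[src] = 0;
        q.push(src);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            answer[src] += dist[u];
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
    }
    return answer;
}
```

With n up to 3 * 10^4 this means roughly 9 * 10^8 node visits in the worst case, which is the motivation for the linear approach.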

To optimize, we treat the graph as a rooted tree and, for every node, compute the number of nodes in its subtree and the sum of distances from that node to the nodes in its subtree (the node acts as the root of that subtree). Note that we can choose an arbitrary node as the root of the whole tree.

For the root of the whole tree, the subtree is the entire tree, so its subtree sum is already the answer we want for that node. What remains is to compute the answers for the other nodes.

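Then we traverse the tree a second time from the root, visiting each parent before its children. Breadth-first or depth-first order both work; the solution below uses a second depth-first search. Assume the sum of distances of a node x is sum_x. For a child node y of x, every node in the subtree of y is 1 closer to y than to x, and every other node is 1 farther, so sum_y = sum_x - count(nodes in subtree y) + count(nodes not in subtree y).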
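To make the re-rooting step concrete, here is the formula as a tiny helper, checked against Example 1. The name rerootSum and its parameters are made up for this sketch; the count of nodes outside the subtree is written as n minus the subtree size.

```cpp
// Moving the root from x to its child y: the subtreeSize nodes under y
// (including y itself) get one step closer, and the remaining
// n - subtreeSize nodes get one step farther.
// rerootSum is a hypothetical name used only in this sketch.
int rerootSum(int sumParent, int subtreeSize, int n) {
    return sumParent - subtreeSize + (n - subtreeSize);
}
```

In Example 1 with node 0 as the root, sum_0 = 8. The subtree of node 2 is {2, 3, 4, 5}, so rerootSum(8, 4, 6) gives 8 - 4 + 2 = 6. Node 1 is a leaf, so rerootSum(8, 1, 6) gives 12. Going one level further, node 3 is a leaf under node 2, so rerootSum(6, 1, 6) gives 10. These match the expected output [8,12,6,10,10,10].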

Then we get the answer.

Solution

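#include <unordered_set>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
        vector<unordered_set<int>> tree(n);  // adjacency sets
        vector<int> res(n);                  // res[i]: sum of distances for node i
        vector<int> count(n, 1);             // count[i]: nodes in subtree of i
        for (const auto& item : edges) {
            tree[item.front()].insert(item.back());
            tree[item.back()].insert(item.front());
        }
        dfs(tree, res, count, 0, -1);   // pass 1: subtree sizes and subtree sums
        dfs2(tree, res, count, 0, -1);  // pass 2: re-root from parent to child

        return res;
    }

    // Post-order pass: after it, count[root] is the size of root's subtree and
    // res[root] is the sum of distances from root to the nodes in its subtree.
    void dfs(vector<unordered_set<int>>& tree, vector<int>& res,
             vector<int>& count, int root, int pre) {
        for (const auto& item : tree[root]) {
            if (item == pre) continue;
            dfs(tree, res, count, item, root);
            count[root] += count[item];
            // Every node under item is one edge farther from root than from item.
            res[root] += res[item] + count[item];
        }
    }

    // Pre-order pass: res[root] is already final, so derive each child's answer.
    void dfs2(vector<unordered_set<int>>& tree, vector<int>& res,
              vector<int>& count, int root, int pre) {
        const int n = static_cast<int>(count.size());  // avoid mixing int and size_t
        for (const auto& item : tree[root]) {
            if (item == pre) continue;
            // count[item] nodes get 1 closer, n - count[item] nodes get 1 farther.
            res[item] = res[root] - count[item] + (n - count[item]);
            dfs2(tree, res, count, item, root);
        }
    }
};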

Time complexity is O(n)

Space complexity is O(n)
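One caveat about the recursive version: on a path-shaped tree the recursion depth reaches n. If that is a concern, the same two passes can be driven by a breadth-first visiting order instead of recursion. The following is a sketch of that variant, not the solution above; the name sumOfDistancesIterative is made up, and it assumes node 0 is the root.

```cpp
#include <cstddef>
#include <vector>

// Iterative variant of the two-pass idea, using a BFS visiting order.
// sumOfDistancesIterative is a hypothetical name used only in this sketch.
std::vector<int> sumOfDistancesIterative(int n, const std::vector<std::vector<int>>& edges) {
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }

    // BFS from node 0: order lists every parent before its children.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (std::size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            order.push_back(v);
        }
    }

    std::vector<int> count(n, 1);  // subtree sizes
    std::vector<int> res(n, 0);    // sums of distances

    // Pass 1: reverse order, so children are folded into parents first.
    for (int i = n - 1; i > 0; --i) {
        int v = order[i];
        int p = parent[v];
        count[p] += count[v];
        res[p] += res[v] + count[v];
    }

    // Pass 2: forward order, so a parent's answer is final before its children.
    for (int i = 1; i < n; ++i) {
        int v = order[i];
        int p = parent[v];
        res[v] = res[p] - count[v] + (n - count[v]);
    }
    return res;
}
```

Time and space stay O(n); the only change is that the visiting order is stored in a vector rather than implied by the call stack.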