https://leetcode.com/problems/sum-of-distances-in-tree/
有一顆由n個節點(node)組成的無向樹,節點標示著0 到 n-1 的數字,且這棵樹的節點只會有一個邊(edge)連接。
我們要回傳每個點到其他點的距離(邊)之總和
節點0的距離總和是 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) --> 1 + 1 + 2 + 2 + 2 = 8
節點1的距離總和是 dist(1,0) + dist(1,2) + dist(1,3) + dist(1,4) + dist(1,5) --> 1 + 2 + 3 + 3 + 3 = 12
首先,先上程式碼吧!!!
class Solution(object):
def sumOfDistancesInTree(self, N, edges):
graph = collections.defaultdict(set)
for u, v in edges:
# 把每個有互相連接的nodes存起來
graph[u].add(v)
graph[v].add(u)
count = [1] * N
ans = [0] * N
def dfs(node = 0, parent = None):
for child in graph[node]:
if child != parent:
dfs(child, node)
count[node] += count[child] #子樹有幾個節點
ans[node] += ans[child] + count[child]
def dfs2(node = 0, parent = None):
for child in graph[node]:
if child != parent:
ans[child] = ans[node] - count[child] + N - count[child]
dfs2(child, node)
dfs()
dfs2()
return ans
這個解法有兩個function,我們一個個來看
首先,dfs() 有兩個任務:
第一個是計算每個節點的子樹有幾個節點
例如範例1: 節點2的子樹有4個節點
第二個是計算根(節點0)的距離總和,而其他節點則記下各自的子節點數量
為什麼要記下各自的子節點數量呢?
因為節點的距離總和計算方式是這樣: ans[i] = 獨立子樹的根的距離總和 + 獨立子樹的子節點數量
以節點1為例: ans[1] = 0 + 0
節點2為例: ans[2] = (0 + 1) + (0 + 1) + (0 + 1)
節點0就是: ans[0] = (0 + 1) + (3 + 4) = 8
到這裡為止其實已經算是解完這題了,只要用個迴圈讓dfs()輪流跑過每個節點並記錄答案就可以回傳了
但是這樣會超過系統限制的時間(Time Limit Exceeded),所以才需要dfs2()
這個function是用來計算每個節點的距離總和
而funciton裡面的下面這條公式讓我一直想不明白,所以昨天就決定先卡個位然後去睡覺了
ans[child] = ans[node] - count[child] + N - count[child]
今天起床後再看一次才終於知道它在幹嘛
這個公式要拆成兩半來看會比較清楚
我們以 ans[1] = ans[0] - count[1] + N - count[1] 來舉例
前半段的 ans[0] - count[1] 代表和節點1相比,有些點離節點0更遠而離節點1更近,因此用節點0的距離總和,減掉這些距離節點1更近的節點數量,就能得到節點1距離總合;而這些離節點1更近的點,其實就是節點1的子節點數量。
既然有些點離節點1更近,那其他節點就會反而離節點1更遠吧,所以才要加上 N - count[1]
連續兩天都是hard難度的題目讓我有點吃不消啊
過去做hard題目都覺得很困難,即使有答案可以看還是需要多想一下,所以都會偷懶避開不寫
不過,都參加鐵人賽了就硬幹下去吧 !!!