iT邦幫忙

第 12 屆 iT 邦幫忙鐵人賽

DAY 30
0
AI & Data

從零.4開始我的深度學習之旅:從 用tf.data處理資料 到 用tf.estimator或tf.keras 訓練模型系列 第 30

二、教你怎麼看source code,找到核心程式碼 ep.23:Deeplab的model 部署

文章說明

文章分段

  1. 文章說明
  2. deeplab的簡單介紹、於我的意義 ep.1
  3. tensorflow的程式碼特色 ep.2
  4. 訓練流程的細節 ep.3
  5. 逛deeplab的github程式

 

前情提要

上個ep稍微去看了deployment.py後,大概知道他在處理各個clone拿到的model的tensor name,那代表,要真的了解model在幹麻,還是得回歸到train.py裡作為model_fn的_build_deeplab()才行。

build model的過程伴隨著資料的輸入,第一件做的事情就是從iterator中取出一個batch的訓練樣本。

上次就講到這些,讓我們繼續開始吧。

 

逛deeplab的github程式(cont.)

train.py (cont.)


ModelOptions是來自common.py的class,目的就是在製作要設定此次模型的參數。

  • outputs_to_num_classes就是此次分類的數量
  • crop_size是input tensor shape,之前iterator已經按照這個尺寸處理好資料了
  • atrous_rates就是deeplab的論文核心,空洞的部分,順帶一提,蠻多API的conv layer都已經有提供這個參數做使用。
  • output_stride是input的尺寸與output尺寸的比例,我記得應該是input/output

 


這個部分就要去看model.pymulti_scale_logits(),我一直蠻好奇這個「logits」指的究竟是什麼意思?我們通常搭建的模型會是前面有許多conv layer層去找特徵,然後後面接幾個全連接層將其展開,然後最後用一層輸出跟分類數量一樣的神經層去輸出最後的output。而這個logits,一般就是最後一層的「前一層」的輸出。

也就是說,主要的模型架構應該就在這個method裡面。

 


接下來是將模型的logits取出來,把他們存在一個dict裡,並且給他tensor name。

 


接著最後就是將logits送去給add_softmax_cross_entropy_loss_for_each_scale()去計算最後的輸出,並且跟label進行比較。

有沒有做到更新參數這件事要等看了這個method才知道。

 


所以接著去utils資料夾看train_utils.py的程式吧。

 

train.utils.py


對不同scale的logits,都去計算softmax cross entropy的loss。

其實在這個method的第74行~140行,都在處理scale的問題,讓logits與label有對應到,所以會有一些resize跟reshape的部分。

這裡其實還有一個重點,data與label使用的resize method不一樣,也不應該一樣,data適合使用bilinear這類interpolation的,而label則是適合nearest neighbor這類的resize。
畢竟不會有人希望,label resize後,居然出現1.5這種label...

 


最終就是把label跟logits拿去計算loss,不過這裡已經到tf.nn這個原始API了,身為上層API的使用戶,我已經無法繼續解釋下去了,所以只能在這裡打住。

 

噢然後在算完loss後,其實有個叫top_k的東西,可以用來幫忙提高更新loss的效率,就是top_k。

他的使用時機就是在算完loss後,把loss交給這個method,他就會去取loss最高的前k個,只把這些loss拿出來,然後去把這些地方做更新。

 


然後之後把loss加進tf.losses這個地方,沒意外的話,我想這些loss,應該是可以被tf.GraphKeys.LOSSES找到的。

這樣,應該都結束了吧。


上一篇
二、教你怎麼看source code,找到核心程式碼 ep.22:Deeplab的model 部署
系列文
從零.4開始我的深度學習之旅:從 用tf.data處理資料 到 用tf.estimator或tf.keras 訓練模型30

2 則留言

0
阿瑜
iT邦新手 4 級 ‧ 2020-11-02 15:55:21

未完待續 XD

我要留言

立即登入留言