*train.py
前個episode把data_generaotr.py裡頭的classDataset
的初始化部分講完,在train.py裡生成一個名為dataset的物件後,緊接著就呼叫了他的methodget_one_shot_iterator
。
今天就要接著從這個method開始講起,開始吧。
datasets/data_generator.py
(cont.)
字如其人,碼如其名(?)
這個method的用途就是取得tf.data的Iterator格式的資料,從呼叫這個程式開始,就會開始對tfrecord進行解析與處理。
323行呼叫了class內部的method,用途是根據347行的file pattern,將資料夾名稱與split結合,以取得tfrecord檔的路徑。
num_parallel_reads
是要一次從num_readers個tfrecord裡平行取出1個data來,所以有用這個argument,data會是interleave的方式取出;反之如果沒有設定的話,tfrecord檔會依次打開讀取。num_parallel_reads
的話,接下來的運算都需要用同樣的parallel數量。
325-328行還可以分開寫成:
這個parse function,接的是tf.data.TFRecordDataset的物件,這個物件會按照一個個tf.example進行解析。
所以他接的參數叫example_xxx。
要解析example,就要跟當初定義的key name一樣,然後value的部分放tf.feature(),並根據之前定義的方式,將變數的型態說明清楚。
接著將單個example,搭配他的檔案格式,用tf的運算,得到解析完成的feature。
238-243行,是分別將data與label的資料decode成為影像,還記得一般data會使用.jpg,label通常會用.png,為了一致性,238、242行皆使用一樣的function進行解析。並且因為不是每種split都需要label,像是test這個split就不需要label的存在,因此241行有額外做一個判斷式。
而jpg與png使用相同的function進行解析,就是用了包裝起來的程式去選擇要用的function。
tf.cond()是個方便的method,第一個參數是判別式,第二個參數是判別式為True要執行的;第三個參數則是判別式為False時要執行的。
感覺跟三元判斷子很像。
接著245行取得data的名稱,並且處理的時候讓image_name為None的tfrecord也還是可以通過。
要注意的是,在tensorflow裡面,常數的宣告是需要使用tf.constant()的。
今天先到這。