使用Dataset API可以讓我們在巨大的資料集上做訓練(沒辦法一次讀進記憶體的資料大小),概念上是因為訓練的時候我們只需要小批次的資料,所以可以將原本的大資料拆分的數個較小的資料,下面是範例的程式。
最下面 model.train()
的部分是開始訓練,並且需要給入 input function。而 input_fn()
裡面使用到了 dataset 的 iterator,其中 dataset
就是從Dataset API:tf.data
來將資料做拆分,並且可以搭配 .shuffle()
、.repeat()
、.batch()
來對資料做、洗牌、批次大小等動作。
那麼為什麼是使用iterator呢?這邊值得再多探討一些TensorFlow運作的模式,在宣告這類 “tf.XX” 的時候,實際上是建立運算的圖,這些圖只有在train或predict的時候才會被執行。建好模型的圖後,接著就是與輸入資料做連接,這就是input function的用處:回傳一個TensorFlow的node,其可以表示模型所需要的features和labels,如下圖示,並且當下個迭代(iteration)時,提供新的一筆批次資料。
所以簡單來說,Dataset API是用來傳送每個訓練迭代的小批次資料給input node 並且保證資料是逐步地讀取不會讓記憶體空間飽和。
在這個實作中,我們將學會:
登入GCP,開啟Notebooks後,複製課程 Github repo (如Day9的Part 1 & 2步驟)。
在左邊的資料夾結構,點進 training-data-analyst > courses > machine_learning > deepdive > 03_tensorflow,然後打開檔案 c_dataset.ipynb。
一開始就進入這次lab的重點,將input資料重構:
tf.data.Dataset.list_files()
是將所有符合pattern的資料檔名找出來.flat_map(tf.data.TextLineDataset)
是將每個檔案資料拆成多個文字行(text lines),既一對多的轉換.map(decode_csv)
則是將每個文字行讀進來,轉變成input features,這是個一對一的轉換。read_dataset()
後,再分別定義讀取train和valid的input function:我們將會在下個lab使用第三個重構,將評估放進訓練的過程中。
今天介紹了Dataset API和如何在巨量資料上做訓練,明天我們將實作如何 “使用Estimator API來分散式訓練模型”。