iT邦幫忙

2021 iThome 鐵人賽

DAY 7
0
AI & Data

30 天在 Colab 嘗試的 30 個影像分類訓練實驗系列 第 7

【7】Dataset 的三個API : Shuffle Batch Repeat 如果使用順序不同會產生的影響

  • 分享至 

  • xImage
  •  

Colab連結

今天的主題比較特殊一些,要來探討 tensorflow 中的 Dataset api : shuffle, batch 和 repeat 的順序,在我一開始使用這個API時,我完全沒想到他的順序會完全影響到訓練的結果而踩了好大的一坑。

Shuffle:
顧名思義,就是用來打亂資料集的API,只是需要注意的是在使用此 API 時,必須給予 buffer_size ,其用途是執行 shuffle 時,他並不是把全部的資料做 shuffle ,而是只會把前N個資料做 shuffle,這個N的數量就是 buffer_size 。

Batch:
前幾天的實驗已經用過了,就是將原本分散一筆一筆的資料以批次的方式包起來,每個訓練的step就會拿到同樣 batch size 的樣本來訓練。

Repeat:
其功能為要重覆這個 dataset 的元素幾次,如果是count=2,那你可以做到在一個 epoch 內對每筆資料掃過兩次的效果,當count=None時,即代表無限重覆,此時要注意你在 model.fit() 時,必須指定 steps_per_epoch,不然會永遠算不玩一個 epoch 而錯誤!

介紹完這三支 API ,以下就來實驗不同的組合之下,會有什麼樣的效果啦!Dataset 很簡單,就是1~13

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])

實驗一:Shuffle -> Batch -> Repeat

BATCH_SIZE=4
SHUFFLE_SIZE=13

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)
ds = ds.repeat()

for example in ds.take(12):
  batch = example.numpy()
  print(batch)

產出:

[13 11  9  4]
[12 10  3  5]
[2 1 6 8]
[7]
[ 6  7  4 10]
[11  2  5 13]
[ 9  1 12  3]
[8]
[ 5  4  8 11]
[ 6 13  3 10]
[7 1 2 9]
[12]

數據有成功 shuffle,但因為 batch size=4 無法整除13,所以會有個 batch 只有一筆資料。

實驗二:Repeat -> Shuffle -> Batch

BATCH_SIZE=4
SHUFFLE_SIZE=13

ds = dataset.repeat()
ds = ds.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)

for example in ds.take(12):
  batch = example.numpy()
  print(batch)
  

產出:

[11  9  6  3]
[1 1 4 5]
[ 2 12  2  3]
[10 13  7  8]
[2 1 4 3]
[ 9  4 11 10]
[ 7  5 10 13]
[ 8  2 13  6]
[11  4 12  1]
[10  6 12  5]
[6 7 8 8]
[3 5 9 7]

我們發現第二個batch 有兩個重複的1,這是因為資料集是先被 repeat 後才開始 shuffle ,所以第一次 batch 拿走 [11, 9, 6, 3]後,有兩個1都在 shuffle 的 buffer size 範圍裡,所以就有可能被同時拿出來,而也因為被 repeat 過,變成無限長的資料集,所以 batch 後不會像上面產生只有一筆資料的 batch。

實驗三:Shuffle -> Repeat -> Batch

BATCH_SIZE=4
SHUFFLE_SIZE=13

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for example in ds.take(12):
  batch = example.numpy()
  print(batch)

產出:

[ 7 11  2  4]
[ 6  5  9 13]
[ 8  1 12 10]
[ 3  6  5 12]
[4 7 8 9]
[11 13  3  1]
[ 2 10  3  4]
[ 5 11  9  2]
[12  6  1  8]
[ 7 10 13 12]
[11 10  2  7]
[ 5  9 13  6]

這個順序可以看到每個 batch 都是四個,而且每個數字在第二次被拿出來前,都有至少歷經完整個資料集,相對實驗二來說,每個資料集被拿到的機率平均了一些,這種組合是我自己比較常用的組合。

再來!我們要探討 shuffle 的 buffer size 設置問題!

x = np.array(range(100))
x = x.repeat(10)
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)

首先,我先準備好資料集,有數字0~99,每個重複10遍,用 list 還看大致長這樣:[ 0,0,0...1,1,1...2,2,2......99,99,99 ] 共1000個元素

現在,我們 buffer size 取 10,也就是故意讓它只對前10個元素做 shuffle,看看會發生什麼事?

SHUFFLE_SIZE = 10

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for idx, example in enumerate(ds.take(100)):
  batch = example.numpy()
  print(batch)

產出:

[0 0 0 0]
[0 1 0 1]
[1 1 0 0]
[2 1 0 0]
[2 1 2 1]
[1 2 3 2]
[3 1 3 1]
[2 3 3 2]
(略)

沒錯!前面幾個 step 拿到的數值都是很前面的元素,這樣的 shuffle 效果不彰...

若我們把 buffer size 增大到100又會發生什麼事呢?

SHUFFLE_SIZE=100

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for idx, example in enumerate(ds.take(100)):
  batch = example.numpy()
  print(batch)

產出:

[3 4 5 0]
[4 1 4 9]
[2 6 3 7]
[10  8  4  3]
[2 0 3 7]
[5 0 4 0]
[ 8  0 11  7]
[11  6  0 12]
[11 10  6  2]
[13 11  9  6]
(略)

有比上一個實驗有變化些,但是仍然沒有拿到比較尾端90,91,92..等元素。因此問題來了,如果今天我的 dataset 數量龐大,我如果把 buffer size 設定的和 dataset 數量一樣,結果遇上記憶體 OOM 問題而程式炸掉,那我應該怎麼辦?

其實更好的做法還是在包成 tfrecord 之前就先把資料打亂,我這邊以 numpy 簡單為例,在我把它變成 tf.data.Dataset 前,先自行用 np.random.shuffle 做 shuffle 後再使用 API。

x = np.array(range(100))
x = x.repeat(10)
np.random.shuffle(x) # shuffle
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)
SHUFFLE_SIZE=100

ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)

for idx, example in enumerate(ds.take(100)):
  batch = example.numpy()
  print(batch)

產出:

[19 12 88 32]
[13 43 76 91]
[85 24 63 58]
[48 52 44 82]
[82 58 46 26]
[24 20 85 63]

可以看到有拿到85, 91等較為後面的元素。

以上就是使用 shuffle, repeat, batch 這三支 API 需要注意的地方。


上一篇
【6】為什麼 Batch size 通常都是設成2的n次方
下一篇
【8】資料集有沒有事先 shuffle 對訓練所產生的影響
系列文
30 天在 Colab 嘗試的 30 個影像分類訓練實驗31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言