今天的主題比較特殊一些,要來探討 tensorflow 中的 Dataset api : shuffle, batch 和 repeat 的順序,在我一開始使用這個API時,我完全沒想到他的順序會完全影響到訓練的結果而踩了好大的一坑。
Shuffle:
顧名思義,就是用來打亂資料集的API,只是需要注意的是在使用此 API 時,必須給予 buffer_size ,其用途是執行 shuffle 時,他並不是把全部的資料做 shuffle ,而是只會把前N個資料做 shuffle,這個N的數量就是 buffer_size 。
Batch:
前幾天的實驗已經用過了,就是將原本分散一筆一筆的資料以批次的方式包起來,每個訓練的step就會拿到同樣 batch size 的樣本來訓練。
Repeat:
其功能為要重覆這個 dataset 的元素幾次,如果是count=2,那你可以做到在一個 epoch 內對每筆資料掃過兩次的效果,當count=None時,即代表無限重覆,此時要注意你在 model.fit() 時,必須指定 steps_per_epoch,不然會永遠算不玩一個 epoch 而錯誤!
介紹完這三支 API ,以下就來實驗不同的組合之下,會有什麼樣的效果啦!Dataset 很簡單,就是1~13
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
實驗一:Shuffle -> Batch -> Repeat
BATCH_SIZE=4
SHUFFLE_SIZE=13
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)
ds = ds.repeat()
for example in ds.take(12):
batch = example.numpy()
print(batch)
產出:
[13 11 9 4]
[12 10 3 5]
[2 1 6 8]
[7]
[ 6 7 4 10]
[11 2 5 13]
[ 9 1 12 3]
[8]
[ 5 4 8 11]
[ 6 13 3 10]
[7 1 2 9]
[12]
數據有成功 shuffle,但因為 batch size=4 無法整除13,所以會有個 batch 只有一筆資料。
實驗二:Repeat -> Shuffle -> Batch
BATCH_SIZE=4
SHUFFLE_SIZE=13
ds = dataset.repeat()
ds = ds.shuffle(SHUFFLE_SIZE)
ds = ds.batch(BATCH_SIZE)
for example in ds.take(12):
batch = example.numpy()
print(batch)
產出:
[11 9 6 3]
[1 1 4 5]
[ 2 12 2 3]
[10 13 7 8]
[2 1 4 3]
[ 9 4 11 10]
[ 7 5 10 13]
[ 8 2 13 6]
[11 4 12 1]
[10 6 12 5]
[6 7 8 8]
[3 5 9 7]
我們發現第二個batch 有兩個重複的1,這是因為資料集是先被 repeat 後才開始 shuffle ,所以第一次 batch 拿走 [11, 9, 6, 3]後,有兩個1都在 shuffle 的 buffer size 範圍裡,所以就有可能被同時拿出來,而也因為被 repeat 過,變成無限長的資料集,所以 batch 後不會像上面產生只有一筆資料的 batch。
實驗三:Shuffle -> Repeat -> Batch
BATCH_SIZE=4
SHUFFLE_SIZE=13
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for example in ds.take(12):
batch = example.numpy()
print(batch)
產出:
[ 7 11 2 4]
[ 6 5 9 13]
[ 8 1 12 10]
[ 3 6 5 12]
[4 7 8 9]
[11 13 3 1]
[ 2 10 3 4]
[ 5 11 9 2]
[12 6 1 8]
[ 7 10 13 12]
[11 10 2 7]
[ 5 9 13 6]
這個順序可以看到每個 batch 都是四個,而且每個數字在第二次被拿出來前,都有至少歷經完整個資料集,相對實驗二來說,每個資料集被拿到的機率平均了一些,這種組合是我自己比較常用的組合。
再來!我們要探討 shuffle 的 buffer size 設置問題!
x = np.array(range(100))
x = x.repeat(10)
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)
首先,我先準備好資料集,有數字0~99,每個重複10遍,用 list 還看大致長這樣:[ 0,0,0...1,1,1...2,2,2......99,99,99 ] 共1000個元素
現在,我們 buffer size 取 10,也就是故意讓它只對前10個元素做 shuffle,看看會發生什麼事?
SHUFFLE_SIZE = 10
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for idx, example in enumerate(ds.take(100)):
batch = example.numpy()
print(batch)
產出:
[0 0 0 0]
[0 1 0 1]
[1 1 0 0]
[2 1 0 0]
[2 1 2 1]
[1 2 3 2]
[3 1 3 1]
[2 3 3 2]
(略)
沒錯!前面幾個 step 拿到的數值都是很前面的元素,這樣的 shuffle 效果不彰...
若我們把 buffer size 增大到100又會發生什麼事呢?
SHUFFLE_SIZE=100
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for idx, example in enumerate(ds.take(100)):
batch = example.numpy()
print(batch)
產出:
[3 4 5 0]
[4 1 4 9]
[2 6 3 7]
[10 8 4 3]
[2 0 3 7]
[5 0 4 0]
[ 8 0 11 7]
[11 6 0 12]
[11 10 6 2]
[13 11 9 6]
(略)
有比上一個實驗有變化些,但是仍然沒有拿到比較尾端90,91,92..等元素。因此問題來了,如果今天我的 dataset 數量龐大,我如果把 buffer size 設定的和 dataset 數量一樣,結果遇上記憶體 OOM 問題而程式炸掉,那我應該怎麼辦?
其實更好的做法還是在包成 tfrecord 之前就先把資料打亂,我這邊以 numpy 簡單為例,在我把它變成 tf.data.Dataset 前,先自行用 np.random.shuffle 做 shuffle 後再使用 API。
x = np.array(range(100))
x = x.repeat(10)
np.random.shuffle(x) # shuffle
print(f'length: {len(x)}')
dataset = tf.data.Dataset.from_tensor_slices(x)
SHUFFLE_SIZE=100
ds = dataset.shuffle(SHUFFLE_SIZE)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
for idx, example in enumerate(ds.take(100)):
batch = example.numpy()
print(batch)
產出:
[19 12 88 32]
[13 43 76 91]
[85 24 63 58]
[48 52 44 82]
[82 58 46 26]
[24 20 85 63]
可以看到有拿到85, 91等較為後面的元素。
以上就是使用 shuffle, repeat, batch 這三支 API 需要注意的地方。