這次要介紹的是另一個比較偏向 Dataset 的主題,當你用 tensorflow 久了,你可能就會發現 tensorflow 就像一把瑞士刀,除了最主要的訓練功能以外,你還有 tensorboard 可以檢視,所以當然,在 Dataset 這塊,官方也做了不少支援,今天要介紹的功能,就是如何產生專屬的 tfrecord。
準備:
我這邊假定我是一位電商,我手上有9位會員的資料,資料包括名稱(A~I)、照片、年紀、身高、偏好產品的索引值(1~5),那們我該如何將以上資料轉換成tfrecord呢?
name_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
image_list = ['001-boy.png', '002-man.png', '003-man-1.png',
'004-man-2.png', '005-man-3.png', '006-girl.png',
'007-boy-1.png', '008-man-4.png', '009-girl-1.png']
age_list = [12, 33, 25, 55, 40, 31, 14, 37, 10]
height_list = [140.2, 174.6, 165.1, 170.9, 168.2, 177.8, 153.1, 164.3, 134.1]
prefer_prods_list = [[1, 2], [1, 5], [2], [3, 4], [1, 3, 5], [5], [], [1, 2], [2, 4]]
會員頭像:
因為要寫入檔案,所以必須先產生一個 writer 來執行這件事情,我們可以呼叫 tf.python_io.TFRecordWriter,並指定路徑。
writer = tf.python_io.TFRecordWriter('../tfrecord/member.tfrecord')
接著,我們對會員資料用一個 for 迴圈拿取各個會員資料。
for i, name in enumerate(name_list):
member_name = name.encode('utf8')
image = image_list[i]
age = age_list[i]
height = height_list[i]
prefer_prods = prefer_prods_list[i]
with tf.gfile.GFile(os.path.join('my-icons-collection', 'png', image), 'rb') as fid:
encoded_image = fid.read()
需要注意的是,tfrecord 裡的資訊是以一個 example 為單位,你可以把這個 example 想像成是一個 dictionary,裡面以(key ,value)方式儲存。
還有,每個 example 裡的 feature,一律使用 list 儲存,即使該 key 只存一個值,我們仍必須把該值以 list 型態包起來,因此我們先定義以下方法。
def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
有了上面的方法之後,我們在包 example 就不必特別小心有沒有漏包 list 的狀況,我們把包 example 這件事情宣告成一個方法。
def data_to_example(name, encoded_image, age, height, prefer_prods):
example = tf.train.Example(features=tf.train.Features(feature={
'member/name': bytes_feature(name),
'member/encoded': bytes_feature(encoded_image),
'member/age': int64_feature(age),
'member/height': float_feature(height),
'member/prefer_prods': int64_list_feature(prefer_prods),
}))
return example
回到原本的for迴圈,我們可以透過剛剛宣告的方法將example由writer寫入檔案。
tf_example = data_to_example(member_name, encoded_image, age, height, prefer_prods)
writer.write(tf_example.SerializeToString())
這樣就完成囉,你會在專案的tfevent資料夾裡看到檔案。
製作 tfrecord 的部分就到這邊,比較要注意的點大概就是 example 的值要用 list 包起來。