iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 12
0
Google Developers Machine Learning

How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文系列 第 12

【12】tensorflow 資料集應用:製作 tfrecord 篇

這次要介紹的是另一個比較偏向 Dataset 的主題,當你用 tensorflow 久了,你可能就會發現 tensorflow 就像一把瑞士刀,除了最主要的訓練功能以外,你還有 tensorboard 可以檢視,所以當然,在 Dataset 這塊,官方也做了不少支援,今天要介紹的功能,就是如何產生專屬的 tfrecord。

準備:
我這邊假定我是一位電商,我手上有9位會員的資料,資料包括名稱(A~I)、照片、年紀、身高、偏好產品的索引值(1~5),那們我該如何將以上資料轉換成tfrecord呢?

name_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
image_list = ['001-boy.png', '002-man.png', '003-man-1.png', 
            '004-man-2.png', '005-man-3.png', '006-girl.png', 
            '007-boy-1.png', '008-man-4.png', '009-girl-1.png']
age_list = [12, 33, 25, 55, 40, 31, 14, 37, 10]
height_list = [140.2, 174.6, 165.1, 170.9, 168.2, 177.8, 153.1, 164.3, 134.1]
prefer_prods_list = [[1, 2], [1, 5], [2], [3, 4], [1, 3, 5], [5], [], [1, 2], [2, 4]]

會員頭像:
https://ithelp.ithome.com.tw/upload/images/20190920/20107299pZV6nEfOJo.png

因為要寫入檔案,所以必須先產生一個 writer 來執行這件事情,我們可以呼叫 tf.python_io.TFRecordWriter,並指定路徑。

writer = tf.python_io.TFRecordWriter('../tfrecord/member.tfrecord')

接著,我們對會員資料用一個 for 迴圈拿取各個會員資料。

for i, name in enumerate(name_list):
    member_name = name.encode('utf8')
    image = image_list[i]
    age = age_list[i]
    height = height_list[i]
    prefer_prods = prefer_prods_list[i]

    with tf.gfile.GFile(os.path.join('my-icons-collection', 'png', image), 'rb') as fid:
        encoded_image = fid.read()

需要注意的是,tfrecord 裡的資訊是以一個 example 為單位,你可以把這個 example 想像成是一個 dictionary,裡面以(key ,value)方式儲存。

還有,每個 example 裡的 feature,一律使用 list 儲存,即使該 key 只存一個值,我們仍必須把該值以 list 型態包起來,因此我們先定義以下方法。

def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
    
def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
def float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

有了上面的方法之後,我們在包 example 就不必特別小心有沒有漏包 list 的狀況,我們把包 example 這件事情宣告成一個方法。

def data_to_example(name, encoded_image, age, height, prefer_prods):
   example = tf.train.Example(features=tf.train.Features(feature={
       'member/name': bytes_feature(name),
       'member/encoded': bytes_feature(encoded_image),
       'member/age': int64_feature(age),
       'member/height': float_feature(height),
       'member/prefer_prods': int64_list_feature(prefer_prods),
   }))
   return example

回到原本的for迴圈,我們可以透過剛剛宣告的方法將example由writer寫入檔案。

tf_example = data_to_example(member_name, encoded_image, age, height, prefer_prods)
writer.write(tf_example.SerializeToString())

這樣就完成囉,你會在專案的tfevent資料夾裡看到檔案。
https://ithelp.ithome.com.tw/upload/images/20190920/201072994bGV1d88Rd.png

製作 tfrecord 的部分就到這邊,比較要注意的點大概就是 example 的值要用 list 包起來。

github原始碼


上一篇
【11】從 tensorboard 來觀察:dropout 原理篇
下一篇
【13】tensorflow 資料集應用:讀取 tfrecord 篇
系列文
How to Train Your Model 訓模高手:我的 Tensorflow 個人使用經驗系列文31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言