今天要介紹給大家的資料集是cifar10,資料集內含10個類別的圖片,分別是飛機、汽車、鳥、貓、鹿、狗、青蛙、馬、船、卡車,其與mnist的主要不同之處在於維度,cifar10是彩色圖片的資料集,有三個channel,mnist的黑白圖片僅有一個channel,我們先透過keras載入這個資料集,並看看它的shape如何。
from tensorflow.keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(r'x_train.shape = ', x_train.shape)
print(r'y_train.shape = ', y_train.shape)
print(r'x_test.shape = ', x_test.shape)
print(r'y_test.shape = ', y_test.shape)
x_train.shape = (50000, 32, 32, 3)
y_train.shape = (50000, 1)
x_test.shape = (10000, 32, 32, 3)
y_test.shape = (10000, 1)
一共有60000張32X32大小的彩色圖片,被分成50000張的訓練集跟10000張的測試集,每個類別有6000張圖片,已經被依照比例分配至訓練集和測試集。
需要注意的是channel的先後問題,如果你有自己製作或在網路下載別人的資料集,需要注意資料集的shape,分為channels_first與channels_last兩種,tensorflow預設為channels_last,以cifar10作為例子,為大家展示這兩種shape的不同。
channels_first = (60000, 3, 32, 32)
channels_last = (60000, 32, 32, 3)
依照慣例上個圖片給大家看,如果沒有字體而報錯的,請在fontproperties自行更換字體名稱。
import matplotlib.pyplot as plt # pip install matplotlib
from random import randrange
text = ['飛機', '汽車', '鳥' ,'貓', '鹿', '狗', '青蛙', '馬', '船', '卡車']
plt.figure(figsize=(16,10),facecolor='w')
for i in range(5):
for j in range(8):
index = randrange(0, 50000)
plt.subplot(5, 8, i*8+j+1)
plt.title("label: {}".format(text[y_train[index][0]]), fontproperties="Microsoft YaHei")
plt.imshow(x_train[index])
plt.axis('off')
plt.show()
我們明天來開始對付這個資料集,它還有一個cifar100的兄弟,大家可以去搜尋看看。