準備好資料集後(Day10),接下來我們要使用 torchvision
中的 datasets
建立資料集,首先要匯入相關套件:
from torchvision import datasets
import torch
接著建立路徑的變數(此處使用相對路徑):
train_path = './cat-vs-rabbit/train-cat-rabbit'
test_path = './cat-vs-rabbit/test-images'
val_path = './cat-vs-rabbit/val-cat-rabbit'
ImageFolder 方法用來讀取資料夾底下的圖片,且會自動將資料夾的名稱當作類別,可以參考官方的範例:
dataset = datasets.ImageFolder(train_path)
print(dataset.class_to_idx)
print(len(dataset))
透過 class_to_idx
顯示類別;len()
來查看資料集的數量:
要確認資料讀近來對不對最快的方法就是顯示出來,程式碼如下:
from matplotlib import pyplot as plt
class_name = ['cat', 'rabbit']
# 從資料集取出第 100 張
img, label = dataset[100]
# 顯示圖片的類別
print(class_name[label])
plt.imshow(img)
plt.show()
可以看到輸出的類別貓咪且對應的圖片也是貓咪:
因為前800張是貓咪,所以我們取第 1000 張看看結果:
# 從資料集取出第 1000 張
img, label = dataset[1000]
# 顯示圖片的類別
print(class_name[label])
plt.imshow(img)
plt.show()
輸出的類別兔子且對應的圖片也是兔子:
今天透過 ImageFolder
的方法將圖片路徑和對應的類別建立成物件,明天會介紹如何利用這個資訊來建立 Dataloader
。