iT邦幫忙

2023 iThome 鐵人賽

DAY 4
0
AI & Data

30天把AI知識傳授給女友系列 第 4

Day 4 透過 Pytorch 中的 torchvision 讀取資料

  • 分享至 

  • xImage
  •  

Pytorch 的流程

流程大致上可以分成以下六個步驟:

  1. 資料準備

  2. 建立模型

  3. 擬合模型到準備好的資料(Train)

  4. 評估模型(Evaluate)

  5. 做出預測(inference)

  6. 儲存模型

https://ithelp.ithome.com.tw/upload/images/20230909/20153503jfP0uIFxw8.png

今天會介紹如何下載和讀取官方提供的 dataset。

讀取準備好的資料

Pytorch提供兩種方式來讀取資料:

  1. torch.utils.data.DataLoader

  2. torch.utils.data.Dataset

Pytorch 有根據不同領域提供相關的library,例如TorchText(文字)、 TorchVision(視覺)和TorchAudio(影音),本偏會使用TorchVision的dataset當作範例。

torchvision.datasets 模組包含 Dataset 許多現實世界視覺數據的對象,如 CIFAR、COCO(有興趣的可以點這裡看更多)。這次會使用 FashionMNIST 數據集來做範例。

FashionMNIST

此資料集有十個類別如下,目的是讓機器學習學會分類:

  1. T-shirt/top

  2. Trouser

  3. Pullover

  4. Dress

  5. Coat

  6. Sandal

  7. Shirt

  8. Sneaker

  9. Bag

  10. Ankle boot

https://ithelp.ithome.com.tw/upload/images/20230909/20153503RAlKziaCWJ.png

實戰片段

首先我們要先將會使用到的 library 引入:

import torch
from torch import nn # 用來創建neural network的函式
from torch.utils.data import DataLoader # 用來讀取資料
from torchvision import datasets # 用來存取視覺相關資料集
from torchvision.transforms import ToTensor # 用來將資料轉換為pytorch格式

接著從透過 torchvision 下載公開資料集 FashionMNIST:

# 下載訓練資料
# train設定為true代表訓練資料
# root 設定為 "data" 代表會將資料下載到data資料夾裡面
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# 下載測試資料
# train設定為false代表訓練資料
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

https://ithelp.ithome.com.tw/upload/images/20230909/20153503hFoBWkvv9Z.png

使用 DataLoader 讀取資料(此動作會使CPU將存在硬碟的資料搬到記憶體中),batch_size 代表一次要放幾張圖片,這邊設定為64。

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

可以看到X的shape為64張,Channel為1(黑白為1、彩色為3)、長寬分別28:

https://ithelp.ithome.com.tw/upload/images/20230909/201535035xpX6Ay3VD.png

可是可是…我沒看到圖片阿! (optional)

引入matplotlib.pyplot套件,並將剛剛上面最後一次讀64張影像中的第一張顯示出來

import matplotlib.pyplot as plt
print(y[0]) # 順便秀出式第幾個類別
plt.imshow(X[0].permute(1, 2, 0))

可以看到類別是第九類,代表Ankle boot,你也可以嘗試修改X[改這裡]、y[改這裡],來看看其他圖案和對應的類別正不正確

https://ithelp.ithome.com.tw/upload/images/20230909/20153503fkJRR7aNaL.png

結語

今天介紹深度學習中的第一個環節讀取資料,本文使用官方提供的資料集,後續在實戰演練會在教大家怎麼讀取自己的資料。


上一篇
Day3 在巨人的肩膀上深度學習
下一篇
Day 5 使用Pytorch建立模型
系列文
30天把AI知識傳授給女友30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言