iT邦幫忙

2024 iThome 鐵人賽

DAY 4
1
AI/ ML & Data

【AI筆記】30天從論文入門到 Pytorch 實戰系列 第 4

【AI筆記】30天從論文入門到 Pytorch 實戰:資料集讀取技巧與最佳實踐 Day 3

  • 分享至 

  • xImage
  •  

了解資料結構

Kaggle資料集
範例為貓狗辨識的例子,要確認以下幾點:

  1. 檔案路徑
  2. 標籤位子
  3. data檔案格式
  4. 訓練和測試資料集是否分開?
  5. 如果混在一起的話,就要自己切割,用套件或是自己切資料

資料集結構

  • /dataset
    • /train_set
      • /cats
        • cat.4001.jpg
        • cat.4002.jpg
        • cat.4003.jpg
        • cat.4004.jpg
        • ...
      • /dogs
        • ...jpg
    • /test_set
      • /cats
        • cat.4001.jpg
        • cat.4002.jpg
        • cat.4003.jpg
        • cat.4004.jpg
      • /dogs
        • ...jpg

根據資料的特性選擇合適的資料結構。例如: 對於結構化資料,可以使用 DataFrame(如 pandas);對於圖像資料,可以使用 NumPy 數組或 PyTorch 的 Tensor。這邊採用Tensor。

train_dir = '/kaggle/working/dataset/train_set/'
test_dir  = '/kaggle/working/dataset/test_set/'
class dataset(torch.utils.data.Dataset):
    def __init__(self, file_path, transform = None):
        self.file_path = file_path
        self.file_list = glob.glob(file_path+'*/*.jpg')
        self.transform = transform
    
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self,idx):
        img_path = self.file_list[idx]
        image = Image.open(img_path)
        if self.transform:
            img = self.transform(image)
        
        label = img_path.split('/')[-1].split('.')[0]
        if label == 'dog':
            label = 1
        elif label == 'cat':
            label = 0
        
        return img, label
train_dataset = dataset(train_dir, transform)
test_dataset = dataset(test_dir, transform)

dataset 要回傳一組 (img, label)

比如現在idx=50,看到 getitem 的 function

  1. img_path = self.file_list[idx]:他會抓 file_list中第51個資料
  2. 將圖片打開 存到 image 變數中
  3. transform(image) 轉成 Tensor格式
  4. label的部分因為 img_path 中有 label,所以直接擷取 [dog/cat]
  5. 因為 model 只能接收數字,所以要額外將文字的 label 轉成 [1/0]

資料預處理

在讀取資料時進行預處理(如標準化、正則化等),可以確保資料的一致性和質量。這可以在自定義的 Dataset 類中實現。

transform

  1. Tensor 格式:PyTorch 的模型需要吃 Tensor 格式的數據,這是因為 Tensor 是 PyTorch 的基本數據結構,支持自動微分和 GPU 加速計算。將數據轉換為 Tensor 格式後,才能進行梯度計算和參數更新。
  2. Normalize:對數據進行標準化(Normalize)可以讓模型更容易訓練。標準化的過程通常是將數據的每個通道減去均值並除以標準差,這樣可以使數據的分佈更均勻,從而加速收斂並提高模型的性能。
  3. Resize:由於每張圖片的大小可能不同,而模型只能接受固定大小的輸入,因此需要將圖片調整為相同的大小。這樣可以確保所有輸入數據的形狀一致,便於批量處理和模型訓練。
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean = [0.4883, 0.4551, 0.4170], std = [0.2276, 0.2230, 0.2233]),
            transforms.Resize([64, 64])
        ])

資料增強 Data Augmentation

為什麼要用到 Data Augmentation?

  1. 增加資料量
    資料增強可以有效地增加訓練資料的數量,這對於小規模資料集特別有用。通過生成更多的資料樣本,模型可以學習到更多的特徵,從而提高其性能。

  2. 提高模型的泛化能力
    資料增強可以幫助模型學習到更廣泛的特徵,從而提高其在未見過的資料上的表現。這是因為資料增強引入了資料的多樣性,使模型不僅僅依賴於訓練資料中的特定模式。

  3. 減少過擬合
    過擬合是指模型在訓練資料上表現良好,但在測試資料上表現不佳。資料增強通過引入隨機變換,可以使模型更難記住訓練資料中的具體樣本,從而減少過擬合的風險。

  4. 模擬真實世界的變化
    在實際應用中,資料可能會受到各種變化的影響,例如旋轉、縮放、平移等。資料增強可以模擬這些變化,使模型在面對真實世界資料時更加穩健。

在明天章節會詳細介紹各種 Data Augmentation 方法

# 定義資料增強
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.4883, 0.4551, 0.4170], std = [0.2276, 0.2230, 0.2233]),
    transforms.Resize([64, 64])
])

資料分割

將資料集分割為訓練集、驗證集和測試集,以便在不同階段評估模型的性能。可以使用 sklearn 的 train_test_split 函數來實現,也可以使用torch的。

train_set, val_set = torch.utils.data.random_split(train_dataset, [7000, 1000])
# 7000筆,1000筆

常見的分割比例

  1. 小規模資料集(幾萬量級):
  • 訓練集:60%
  • 驗證集:20%
  • 測試集:20%
  1. 大規模資料集(百萬量級以上):
  • 訓練集:98%
  • 驗證集:1%
  • 測試集:1%

分割比例的考量因素

  1. 資料集大小:
    對於小規模資料集,通常需要更多的資料來訓練模型,因此訓練集的比例會較大。
    對於大規模資料集,驗證集和測試集的數量只需要足夠大以進行可靠的評估即可。
  2. 模型的複雜度:
    如果模型的超參數較多或需要頻繁調整,則需要較大的驗證集來進行調參。
    如果模型較為簡單,則可以減少驗證集的比例。
  3. 交叉驗證:
    使用交叉驗證(如 K-fold)可以更有效地利用資料,特別是在資料量較少的情況下。這種方法可以降低資料劃分帶來的影響,並提供更穩定的模型評估結果。

紀錄

import tqdm

常用在訓練時,可以幫忙看大概一個 epoch 耗時需要用多久時間。方便計算大概要花多久時間才能完成訓練。

pytorch 要自己寫才會產生進度條,如果是 pytorch lightning 就有預設可以直接使用(包含log file)。

from tqdm import tqdm
import time

epoches = 50
for epoch in tqdm(range(epoches)):
    time.sleep(50)

import logging

可以用 logger.info 紀錄每個 iteration 的重要參數和指標。

# 設置logging模組
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename='data_reading.log', filemode='w')

logger = logging.getLogger("logger_name")
logger = logging.StreamHandler()
logger.setLevel(logging.DEBUG)
logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base LearningRate: {:.2e}".format(epoch, (n_iter + 1), len(train_loader), loss.avg, acc.avg, base_lr))

小結

  • 訓練集:用於訓練模型,通常佔據資料集的大部分。
  • 驗證集:用於調整模型的超參數和選擇最佳模型。
  • 測試集:用於最終評估模型的性能,確保模型在未見過的資料上表現良好。
  • 資料增強:在讀取數據時進行資料增強(如旋轉、翻轉、裁剪等),可以提高模型的泛化能力。這可以使用像 albumentations 這樣的庫來實現。
  • 記錄和監控:記錄數據讀取過程中的關鍵信息(如讀取時間、資料大小等),並進行監控,以便及時發現和解決問題。

Reference


上一篇
【AI筆記】30天從論文入門到 Pytorch 實戰:Pytorch 訓練流程全解析 Day 2
下一篇
【AI筆記】30天從論文入門到 Pytorch 實戰:資料預處理的關鍵步驟 Day 4
系列文
【AI筆記】30天從論文入門到 Pytorch 實戰26
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言