iT邦幫忙

0

torch_geometric自建DataSet與Data在使用DataLoader用next時錯誤問題

  • 分享至 

  • xImage

狀況描述:
   建立模擬資料data_set時processed_data放入兩筆SketchData(stroke_idx均為9),再執行next()時先進入__len__再進入__getitem__,這個時候有個神奇的SketchDatasetBatch出現,到SketchData的__init__中,且所有參數都為空值(stroke_idx也變成None)。
   導致後續max(stroke_idx)跳出異常訊息。
   很不明白為何在放入processed_data時候明明有資料,再進入到__init__卻都是None/images/emoticon/emoticon02.gif
   若有人知道在懇請告知,感恩讚嘆您 /images/emoticon/emoticon41.gif

版本資訊:

  • python 3.9.15
  • pytroch 1.12.1

錯誤訊息:

  • 'NoneType' object is not iterable

程式碼:

from torch_geometric.data import Data
from torch_geometric.data import DataLoader
import torch

class SketchData(Data):
    def __init__(self, stroke_idx=None, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, norm=None, face=None, **kwargs):
        super().__init__(x, edge_index, edge_attr, y, pos, **kwargs)
        self.stroke_idx = stroke_idx
        self.stroke_num = max(stroke_idx) + 1

class SketchDataset(torch.utils.data.Dataset):
    def __init__(self):
        # 節點特徵
        x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
        edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
        self.processed_data = []
        sketch_data = SketchData(x=x,
                                edge_index=edge_index.t().contiguous(),
                                stroke_idx=[2,4,6,8])
        self.processed_data.append(sketch_data)
        self.processed_data.append(sketch_data)

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, index):
        print(type(self.processed_data[index]))
        return self.processed_data[index]

data_set = SketchDataset()
data_loader = DataLoader(data_set, batch_size=1, shuffle=False)
batch  = next(iter(data_loader))
print(batch)
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

2 個回答

0
re.Zero
iT邦研究生 5 級 ‧ 2022-12-06 21:39:26
最佳解答

我不認識 pytroch,但因為你用到一個語法:

dataiter = iter(data_loader)
batch = dataiter.next()

讓我有點好奇而去查了下,感覺你的問題可能出在 "torch_geometric.loader.dynamic_batch_sampler" 的 def __iter__ 部分。
後續我就懶了,你應該能自行判斷、處理。

另,我看到 "Advanced Mini-Batching"
這裡是用 batch = next(iter(loader)) 請注意差異。

還有,請善用 markdown 語法與[預覽]功能,因為我很懶得猜原文是啥鬼~
(如果你是用手機就當我多言~)

zzssc061u iT邦新手 5 級 ‧ 2022-12-11 14:32:49 檢舉

感謝告知,已修正next與markdown語法。
後續有嘗試往自己建立Sampler方式,但是我本機沒有辦法from到,卡住了/images/emoticon/emoticon06.gif
https://pytorch-geometric.readthedocs.io/en/latest/modules/sampler.html

參考練習下方連結自訂Data使用[]包起來next(iter(loader))取值,是沒有問題的。
https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
但如果使用自訂Data(或是Object)使用Dataset包起來,在Batch時經過Sampler就會無法辨識自訂的class。
錯誤訊息

Exception has occurred: TypeError
DataLoader found invalid type: <class '__main__.MyData'>

程式碼

import torch
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

class MyData(object):
    def __init__(self,num) -> None:
        self.num =num

class MyDataset(Dataset):
    def __init__(self):
        self.data = []
        for x in range(10):
            self.data.append(MyData(x))

    def __getitem__(self, index):
        print(index)
        return MyData(self.data[index].num)

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(dataset)

for _,data in enumerate(loader):
    print(data)
re.Zero iT邦研究生 5 級 ‧ 2022-12-11 16:02:38 檢舉

我用 Google:"pytorch" "DataLoader found invalid type" site:pytorch-geometric.readthedocs.io/en/latest時,
這裡看到 DataLoader 在初始化時 ,collate_fn 參數的預設值是 Collater 的實例,而 Collater 初始化時會確認 batch 的首元素之類別;若找無適當對應則觸發型別錯誤例外。
所以你想要讓自定義的 MyData 通過 DataLoader 的初始化,大概能從 collate_fn 下手,不然就是另尋資料綁定等其他方法了。
但是,設定 collate_fn 參數規避型別檢查,得自負後續意外的處理就是了(畢竟預設行為有做型別檢查啊)。

zzssc061u iT邦新手 5 級 ‧ 2022-12-18 11:58:29 檢舉

透過自訂collate_fn 可以排除我的疑問,已更了解相關機制,感謝。

0
增廣建文
iT邦研究生 5 級 ‧ 2022-12-06 21:02:07

DatasetBatch內是有你要的index的,可以用batch.store找到

zzssc061u iT邦新手 5 級 ‧ 2022-12-10 19:56:36 檢舉

batch 還沒有建立成功,就出現異常訊息了 /images/emoticon/emoticon02.gif

增廣建文 iT邦研究生 5 級 ‧ 2022-12-11 01:37:28 檢舉

有阿 你上一版code是可以執行到print(batch)

增廣建文 iT邦研究生 5 級 ‧ 2022-12-11 01:52:45 檢舉

新版最簡單的解法就是在max前面加上if去判斷不是None
batch會拿到SketchDataBatch(x=[3, 1], edge_index=[2, 4], stroke_idx=[1], stroke_num=[1], batch=[3], ptr=[2])

我要發表回答

立即登入回答