狀況描述:
建立模擬資料data_set時processed_data放入兩筆SketchData(stroke_idx均為9),再執行next()時先進入__len__再進入__getitem__,這個時候有個神奇的SketchDatasetBatch出現,到SketchData的__init__中,且所有參數都為空值(stroke_idx也變成None)。
導致後續max(stroke_idx)跳出異常訊息。
很不明白為何在放入processed_data時候明明有資料,再進入到__init__卻都是None
若有人知道在懇請告知,感恩讚嘆您
版本資訊:
錯誤訊息:
程式碼:
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
import torch
class SketchData(Data):
def __init__(self, stroke_idx=None, x=None, edge_index=None, edge_attr=None, y=None,
pos=None, norm=None, face=None, **kwargs):
super().__init__(x, edge_index, edge_attr, y, pos, **kwargs)
self.stroke_idx = stroke_idx
self.stroke_num = max(stroke_idx) + 1
class SketchDataset(torch.utils.data.Dataset):
def __init__(self):
# 節點特徵
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1],
[1, 0],
[1, 2],
[2, 1]], dtype=torch.long)
self.processed_data = []
sketch_data = SketchData(x=x,
edge_index=edge_index.t().contiguous(),
stroke_idx=[2,4,6,8])
self.processed_data.append(sketch_data)
self.processed_data.append(sketch_data)
def __len__(self):
return len(self.processed_data)
def __getitem__(self, index):
print(type(self.processed_data[index]))
return self.processed_data[index]
data_set = SketchDataset()
data_loader = DataLoader(data_set, batch_size=1, shuffle=False)
batch = next(iter(data_loader))
print(batch)
我不認識 pytroch,但因為你用到一個語法:
dataiter = iter(data_loader)
batch = dataiter.next()
讓我有點好奇而去查了下,感覺你的問題可能出在 "torch_geometric.loader.dynamic_batch_sampler" 的 def __iter__
部分。
後續我就懶了,你應該能自行判斷、處理。
另,我看到 "Advanced Mini-Batching"
這裡是用 batch = next(iter(loader))
請注意差異。
還有,請善用 markdown 語法與[預覽]功能,因為我很懶得猜原文是啥鬼~
(如果你是用手機就當我多言~)
感謝告知,已修正next與markdown語法。
後續有嘗試往自己建立Sampler方式,但是我本機沒有辦法from到,卡住了。
https://pytorch-geometric.readthedocs.io/en/latest/modules/sampler.html
參考練習下方連結自訂Data使用[]包起來next(iter(loader))取值,是沒有問題的。
https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
但如果使用自訂Data(或是Object)使用Dataset包起來,在Batch時經過Sampler就會無法辨識自訂的class。
錯誤訊息
Exception has occurred: TypeError
DataLoader found invalid type: <class '__main__.MyData'>
程式碼
import torch
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
class MyData(object):
def __init__(self,num) -> None:
self.num =num
class MyDataset(Dataset):
def __init__(self):
self.data = []
for x in range(10):
self.data.append(MyData(x))
def __getitem__(self, index):
print(index)
return MyData(self.data[index].num)
def __len__(self):
return len(self.data)
dataset = MyDataset()
loader = DataLoader(dataset)
for _,data in enumerate(loader):
print(data)
我用 Google:"pytorch" "DataLoader found invalid type" site:pytorch-geometric.readthedocs.io/en/latest
時,
在這裡看到 DataLoader
在初始化時 ,collate_fn
參數的預設值是 Collater
的實例,而 Collater
初始化時會確認 batch 的首元素之類別;若找無適當對應則觸發型別錯誤例外。
所以你想要讓自定義的 MyData
通過 DataLoader
的初始化,大概能從 collate_fn
下手,不然就是另尋資料綁定等其他方法了。
但是,設定 collate_fn
參數規避型別檢查,得自負後續意外的處理就是了(畢竟預設行為有做型別檢查啊)。
透過自訂collate_fn 可以排除我的疑問,已更了解相關機制,感謝。
DatasetBatch內是有你要的index的,可以用batch.store
找到