在重新開始執行pretrain.py
前,針對先前的亂碼問題,本來我以為問題是出在tokenizer,不過仔細看了一下印出來的亂碼後,我回去比對了medical_book_zh.json
的原始文本後,發現我印出的亂碼是源自於這個檔案,而且裡面其實亂碼不少;然而目前我沒有想到特別好的辦法去自動的篩選出亂碼,因為目前我打算還是繼續使用這些文本的資料;鑑於原作者一樣是使用相同的資料與網路架構訓練出正常的model,所以理論上就算不去過濾這些資料應該還是可以訓練出一個相對正常的模型。
不過在pretrain.py
中,我發現了一個原本沒注意到的問題,在產生PretrainDataset
的時候,程式碼中原本是設定memmap=True
的
train_ds = PretrainDataset(data_path_list, max_length=max_seq_len,memmap=True)
一旦memmap=True
時,PretrainDataset
這邊hardcode只會使用data_path_list[0]
這個檔案,然而我在先前訓練時也是採用輸入多個datapath的方式,這導致了實際上我只使用到了data_path_list
中的第一個檔案,而剛好第一個檔案非常的小
class PretrainDataset(Dataset):
def __init__(self,data_path_lst,max_length=256,memmap=False):
......
if memmap:
with open(data_path_lst[0],'r') as f:
nbytes = f.seek(0,2)
flen = f.tell() // np.dtype('uint16').itemsize
self.data = np.memmap(data_path_lst[0],dtype=np.dtype('uint16'),shape=(flen//max_length,max_length))
......
要避開這個問題,可以
data_process.py
中把所有的.bin
檔融合成一個'./data/pretrain_data.bin',並使用這一個檔案來做pretrain
memmap=False
由於我的RAM不太大,因此我這邊直接設定memmap=False
來避過這個問題,之後重新執行pretrain.py
,訓練資料大小為memmap:False train data.shape:(8336256, 512)
。
if __name__=="__main__":
......
max_epoch = 2
......
batch_size = 32
max_seq_len = 512
dim = 512
n_layers = 8
n_heads = 8
multiple_of = 32
......
#-----init dataloader------
data_path_list=[
# './data/pretrain_data.bin',
'./data/medical_book.bin',
'./data/medical_encyclopedia.bin',
'./data/wiki.bin',
'./data/baidubaike_563w.bin',
]
train_ds = PretrainDataset(data_path_list, max_length=max_seq_len,memmap=False)