以上圖出自 Hugging Face 官方
範例資料集我就直接全部使用 Day21 使用到的 wikiann
的中文數據集
from datasets import load_dataset
wikiann_datasets = load_dataset('wikiann', 'zh', split='train')
wikiann_datasets['train']
wikiann_dataset[:3]['tokens']
Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 20000
})
[['2', '0', '0', '9', '年', ':', '李', '民', '基', '《', 'E', 't', 'e', 'r', 'n', 'a', 'l', '#', 'S', 'u', 'm', 'm', 'e', 'r', '》'],
['#', '澳', '門', '大', '學', '田', '家', '炳', '教', '育', '研', '究', '所'],
['#', '大', '维', '多', '利', '亚', '沙', '漠']]
dataset_shuffle = wikiann_dataset.shuffle(seed=42)
print(dataset_shuffle[:3]['tokens'])
[['#', '李', '相', '秀', '#', '朴', '英', '淑'],
["'", "'", "'", '威', '爾', '·', '普', '爾', '特', "'", "'", "'"],
["'", "'", "'", '丁', '戈', "'", "'", "'", '(', '配', '音', '員', ':', '山', '口', '真', '弓', '(', '日', '本', ')', ')']]
train
和test
,我們除了自己提前定義好之外也可以透過內建方法直接做分割data_split = wikiann_datasets.train_test_split(test_size=0.1)
print(data_split)
train_test_split
方法,給定要分割的比例 0.1
,也就是 9:1 的分割方式DatasetDict({
train: Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 18000
})
test: Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 2000
})
})
train data
和 2000 的 test data
train
、validation
、test
三種資料集data_split_2 = data_split["train"].train_test_split(train_size=0.8, seed=42)
data_split_2["validation"] = data_split_2.pop("test")
data_split_2["test"] = data_split["test"]
print(data_split_2)
train
的部分再拆成 train
和test
test
更改為validation
test
加到新的資料集DatasetDict({
train: Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 14400
})
validation: Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 3600
})
test: Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 2000
})
})
# 第一種
data_select = wikiann_datasets.select(range(1000))
print(data_select)
# 第二種
indices = [100, 200, 300, 400, 500]
data_select = wikiann_datasets.select(indices)
print(data_select)
Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 1000
})
Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 5
})
data = wikiann_datasets.filter(lambda x: len(x["tokens"]) > 20)
print(data)
(對於 lambda 想多了解的人可以再去自己找)
也可以寫成這樣
def longer(x):
return len(x["tokens"]) > 20
data = wikiann_datasets.filter(longer)
Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 4789
})
data = wikiann_datasets.filter(lambda x: x["tokens"][0] == '#')
print(data)
Dataset({
features: ['tokens', 'ner_tags', 'langs', 'spans'],
num_rows: 6566
})