a subject <predictate> object
。例如, person riding bicycle
, “person” 和 “bicycle” 分別是主詞和受詞, “riding” 是關係動詞。model.py
遇到pandas.DataFrame.as_matrix()
是舊的語法,也提供修正方式pandas.DataFrame.values
。我們現在編寫標記函數來檢測邊界框對之間存在什麼關係。為此,我們可以將各種直覺編碼到標記函數中:
分類直覺:關於這些關係中通常涉及的主詞和受詞類別的知識(例如,person通常是謂詞 RIDE 和的主詞 CARRY)
空間直覺:關於主詞和受詞的相對位置的知識(例如,主詞通常高於動詞的受詞RIDE)
RIDE = 0
CARRY = 1
OTHER = 2
ABSTAIN = -1
我們從編碼分類直覺的標記函數開始:我們使用關於共同的主題-客體類別對的知識 RIDE,CARRY 以及關於哪些主題或客體不太可能涉及這兩種關係的知識。
from snorkel.labeling import labeling_function
# Category-based LFs
@labeling_function()
def lf_ride_object(x):
if x.subject_category == "person":
if x.object_category in [
"bike",
"snowboard",
"motorcycle",
"horse",
"bus",
"truck",
"elephant",
]:
return RIDE
return ABSTAIN
@labeling_function()
def lf_carry_object(x):
if x.subject_category == "person":
if x.object_category in ["bag", "surfboard", "skis"]:
return CARRY
return ABSTAIN
@labeling_function()
def lf_carry_subject(x):
if x.object_category == "person":
if x.subject_category in ["chair", "bike", "snowboard", "motorcycle", "horse"]:
return CARRY
return ABSTAIN
@labeling_function()
def lf_not_person(x):
if x.subject_category != "person":
return OTHER
return ABSTAIN
現在編碼空間直覺,其中包括測量邊界框之間的距離並比較它們的相對區域。
YMIN = 0
YMAX = 1
XMIN = 2
XMAX = 3
import numpy as np
# Distance-based LFs
@labeling_function()
def lf_ydist(x):
if x.subject_bbox[XMAX] < x.object_bbox[XMAX]:
return OTHER
return ABSTAIN
@labeling_function()
def lf_dist(x):
if np.linalg.norm(np.array(x.subject_bbox) - np.array(x.object_bbox)) <= 1000:
return OTHER
return ABSTAIN
def area(bbox):
return (bbox[YMAX] - bbox[YMIN]) * (bbox[XMAX] - bbox[XMIN])
# Size-based LF
@labeling_function()
def lf_area(x):
if area(x.subject_bbox) / area(x.object_bbox) <= 0.5:
return OTHER
return ABSTAIN
標記函數具有不同的經驗準確性和覆蓋範圍。由於我們選擇的關係中的類別不平衡,標記 OTHER 的標記函數比RIDE或CARRY的標記函數具有更高的覆蓋率。這也反映了數據集中類的分佈。
訓練 LabelModel
來為未標記的訓練集分配訓練標籤。
from snorkel.labeling.model import LabelModel
label_model = LabelModel(cardinality=3, verbose=True)
label_model.fit(
L_train,
seed=123,
lr=0.01,
log_freq=10,
n_epochs=100
)
現在,您可以使用這些訓練標籤來訓練任何標準判別模型,例如現成的 ResNet,它應該學會在我們開發的 LF 之外進行泛化。
from snorkel.classification import DictDataLoader
from model import SceneGraphDataset, create_model
df_train["labels"] = label_model.predict(L_train)
if sample:
TRAIN_DIR = "data/VRD/sg_dataset/samples"
else:
TRAIN_DIR = "data/VRD/sg_dataset/sg_train_images"
dl_train = DictDataLoader(
SceneGraphDataset("train_dataset", "train", TRAIN_DIR, df_train),
batch_size=16,
shuffle=True,
)
dl_valid = DictDataLoader(
SceneGraphDataset("valid_dataset", "valid", TRAIN_DIR, df_valid),
batch_size=16,
shuffle=False,
)
定義模型架構。
import torchvision.models as models
# initialize pretrained feature extractor
cnn = models.resnet18(pretrained=True)
model = create_model(cnn)
from snorkel.classification import Trainer
trainer = Trainer(
n_epochs=1, # increase for improved performance
lr=1e-3,
checkpointing=True,
checkpointer_config={"checkpoint_dir": "checkpoint"},
)
trainer.fit(model, [dl_train])
model.score([dl_valid])
# {'visual_relation_task/valid_dataset/valid/f1_micro':
# 0.34615384615384615}
我們已經成功訓練了一個視覺關係檢測模型!使用關於視覺關係中的對像如何相互作用的分類和空間直覺,我們能夠在多類分類設置中為 VRD 數據集中的對像對分配高質量的訓練標籤。
有關 Snorkel 如何用於視覺關係任務的更多信息,請參閱該團隊 ICCV 2019 論文。