我從網路上抓了一個 text-detection 的專案,想用自己的訓練資料(train data)來訓練。
不過過程中遇到了 cpp 的問題,相關的程式碼附在下方。
程式執行到 `from .pse import pse_cpp, get_points, get_num` 這一行時,會跳出下面的錯誤訊息:
    from .pse import pse_cpp, get_points, get_num
ModuleNotFoundError: No module named 'post_processing.pse'
此外該資料夾如下
另外,我所抓的專案 GitHub 網址為:https://github.com/WenmuZhou/PAN.pytorch
def decode(preds, scale=1, threshold=0.7311, min_area=5):
    """
    Convert raw network output into a labelled text map and text boxes.

    The first two channels of ``preds`` are passed through a sigmoid to get
    confidences and thresholded to separate text from background: channel 0
    gives the text mask (and the per-pixel score), channel 1 gives the kernel
    mask. The remaining channels are handed to ``pse_cpp`` as per-pixel
    similarity vectors. ``preds`` itself is left unmodified, so the same
    tensor can safely be decoded more than once.

    :param preds: network output, a 3-D torch tensor indexed as
        (channel, row, column) with at least two channels
    :param scale: scale of the network; regions with fewer than
        ``100 / scale**2`` points are discarded
    :param threshold: threshold applied to the sigmoid confidences
    :param min_area: kernel components whose ``get_num`` value is below this
        are ignored (presumably a pixel count -- confirm against ``get_num``)
    :return: tuple ``(pred, boxes)`` where ``pred`` is the ``pse_cpp`` result
        reshaped to the text-mask shape and ``boxes`` is ``np.array`` of the
        kept boxes, four corner points each (an empty array if none are kept)
    """
    # Imported at call time, so the compiled `pse` package is only required
    # when decode() actually runs.
    from .pse import pse_cpp, get_points, get_num

    preds = preds.detach()
    # Apply the sigmoid out of place: writing the result back into `preds`
    # would modify the caller's tensor and make a second decode() of the same
    # tensor apply the sigmoid twice.
    probs = torch.sigmoid(preds[:2, :, :]).cpu().numpy()
    # (C, H, W) -> (H, W, C) for the similarity channels.
    similarity_vectors = preds[2:].cpu().numpy().transpose((1, 2, 0))

    score = probs[0].astype(np.float32)
    text = probs[0] > threshold  # text mask
    kernel = (probs[1] > threshold) * text  # kernel mask, restricted to text

    label_num, label = cv2.connectedComponents(kernel.astype(np.uint8), connectivity=4)

    # Labels of kernel components that are large enough to keep; label 0 is
    # skipped. A set gives O(1) membership tests in the loop below.
    label_sum = get_num(label, label_num)
    kept_labels = {
        label_idx
        for label_idx in range(1, label_num)
        if label_sum[label_idx] >= min_area
    }

    # NOTE(review): the meaning of the trailing 0.8 is defined by pse_cpp,
    # which is not visible here -- confirm before tuning it.
    pred = pse_cpp(text.astype(np.uint8), similarity_vectors, label, label_num, 0.8)
    pred = pred.reshape(text.shape)

    # Loop-invariant minimum number of points a region must contain.
    min_points = 100 / (scale * scale)

    bbox_list = []
    label_points = get_points(pred, score, label_num)
    for label_value, label_point in label_points.items():
        if label_value not in kept_labels:
            continue

        # Entry 0 is the region score; entries from index 2 onwards are
        # flattened coordinate pairs. Entry 1 is not used here.
        score_i = label_point[0]
        points = np.array(label_point[2:], dtype=int).reshape(-1, 2)

        if points.shape[0] < min_points:
            continue
        if score_i < 0.93:
            continue

        rect = cv2.minAreaRect(points)
        bbox = cv2.boxPoints(rect)
        # Rotate the corner order returned by boxPoints by one position.
        bbox_list.append([bbox[1], bbox[2], bbox[3], bbox[0]])
    return pred, np.array(bbox_list)