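Following the previous post's introduction to the TUM RGB-D Dataset, this post processes the dataset in Python. The first step is to build a Python dataset class so the data can be handled as an object.
rgb.txt and depth.txt list the RGB and depth images, and groundtruth.txt holds the camera poses; every entry carries a timestamp. Note that the three files do not correspond one to one: for example, there is an RGB image 1311868164.363181.png, but depth.txt or groundtruth.txt may have no entry with that exact timestamp. We therefore convert each timestamp to an integer (milliseconds) with parse_timestamp and store it for the matching step later.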
from pathlib import Path
import numpy as np
from PIL import Image
from scipy.spatial.transform import Rotation
def parse_timestamp(timestamp):
    # Convert to integer milliseconds: keep only the first 3 digits after the dot
    # Example: 1311868263.185529 -> 1311868263 * 1000 + 185
    return int(timestamp.split(".")[0]) * 1000 + int(timestamp.split(".")[1][:3])
class TUMRGBD:
    def __init__(
        self,
        root_dir,
    ):
        root_dir = Path(root_dir)
        self.root_dir = root_dir
        self.rgb_dir = root_dir / "rgb"
        self.depth_dir = root_dir / "depth"
        self.rgb_file = root_dir / "rgb.txt"
        self.depth_file = root_dir / "depth.txt"
        self.groundtruth_file = root_dir / "groundtruth.txt"
        self.rgb_files = []
        self.rgb_timestamps = []
        with open(self.rgb_file) as f:
            # Ignore the first three lines
            for _ in range(3):
                f.readline()
                
            for line in f:
                parts = line.strip().split(" ")
                timestamp = parse_timestamp(parts[0])
                self.rgb_timestamps.append(timestamp)
                self.rgb_files.append(parts[1].split("/")[-1])
        print(f"Loaded {len(self.rgb_timestamps)} RGB frames")
        
        self.depth_files = []
        self.depth_timestamps = []
        with open(self.depth_file) as f:
            # Ignore the first three lines
            for _ in range(3):
                f.readline()
                
            for line in f:
                parts = line.strip().split(" ")
                timestamp = parse_timestamp(parts[0])
                self.depth_timestamps.append(timestamp)
                self.depth_files.append(parts[1].split("/")[-1])
        
        print(f"Loaded {len(self.depth_timestamps)} depth frames")
        
        self.pose_timestamps = []
        self.translations = []
        self.rotations = []
        with open(self.groundtruth_file) as f:
            # Ignore the first three lines
            for _ in range(3):
                f.readline()
                
            self.groundtruth = f.readlines()
            
        for line in self.groundtruth:
            parts = line.split(" ")
            timestamp = parse_timestamp(parts[0])
            self.pose_timestamps.append(timestamp)
            self.translations.append([float(x) for x in parts[1:4]])
            # The rotation is in quaternion format
            qx, qy, qz, qw = [float(x) for x in parts[4:8]]
            R = Rotation.from_quat([qx, qy, qz, qw])
            self.rotations.append(R.as_matrix())
        print(f"Loaded {len(self.pose_timestamps)} poses")
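The three loading loops above follow the same pattern, so they could be folded into one helper. The following is a minimal sketch under a few assumptions, not the article's implementation: `read_index` is a hypothetical name, comment lines are detected by a leading `#` instead of skipping a fixed three lines, and this variant of `parse_timestamp` pads short fractions (so `.36` becomes 360 ms), which the original version does not do. The sample lines are made up for illustration.

```python
def parse_timestamp(timestamp):
    # Variant of the article's function: integer milliseconds, padding
    # fractions shorter than three digits (an assumption, see lead-in).
    seconds, _, fraction = timestamp.partition(".")
    return int(seconds) * 1000 + int(fraction[:3].ljust(3, "0"))


def read_index(lines):
    """Parse the lines of rgb.txt, depth.txt or groundtruth.txt.

    Returns (timestamps_ms, fields), where fields[i] holds the remaining
    whitespace-separated columns of entry i as strings.
    """
    timestamps, fields = [], []
    for line in lines:
        line = line.strip()
        # Skip blank lines and header comments
        if not line or line.startswith("#"):
            continue
        parts = line.split()
        timestamps.append(parse_timestamp(parts[0]))
        fields.append(parts[1:])
    return timestamps, fields


# Illustrative input only; a real file would be passed as open(path)
sample = [
    "# header comment",
    "1311868164.363181 rgb/1311868164.363181.png",
    "1311868164.399026 rgb/1311868164.399026.png",
]
timestamps, fields = read_index(sample)
print(timestamps)
print(fields[0])
```

With such a helper, the filename would be `fields[i][0]` for the image lists, and the seven pose values would be `fields[i][0:7]` for groundtruth.txt.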
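Next, we implement Python's special methods __len__ and __getitem__ so that len(dataset) returns the dataset size (here, the number of RGB images) and dataset[x] returns the information for that frame.
In __getitem__, we use binary search to find the depth frame and the pose whose timestamps are closest to the RGB frame, then check that both fall within a time tolerance. If either does not, we return None so the caller can simply skip that frame.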
    def __len__(self):
        return len(self.rgb_timestamps)
    @staticmethod
    def _nearest_index(timestamps, target):
        # np.searchsorted returns the insertion index, which is not necessarily
        # the closest entry and can equal len(timestamps), so compare both neighbours.
        i = int(np.searchsorted(timestamps, target))
        if i == 0:
            return 0
        if i == len(timestamps):
            return i - 1
        return i if timestamps[i] - target < target - timestamps[i - 1] else i - 1

    def __getitem__(self, idx: int):
        # Frames are indexed by the RGB sequence
        timestamp = self.rgb_timestamps[idx]
        rgb_file = self.rgb_files[idx]
        # Find the depth frame and the pose closest in time (binary search)
        depth_idx = self._nearest_index(self.depth_timestamps, timestamp)
        pose_idx = self._nearest_index(self.pose_timestamps, timestamp)

        # Timestamps are integer milliseconds, so the tolerance is 100 ms
        MAX_TIME_TOLERANCE = 100
        if (
            abs(timestamp - self.pose_timestamps[pose_idx]) > MAX_TIME_TOLERANCE
            or abs(timestamp - self.depth_timestamps[depth_idx]) > MAX_TIME_TOLERANCE
        ):
            # No depth frame or pose within the time tolerance: skip this frame
            return None

        depth_file = self.depth_files[depth_idx]
        R = self.rotations[pose_idx]
        t = self.translations[pose_idx]
        
        return {
            "rgb_path": self.rgb_dir / rgb_file,
            "depth_path": self.depth_dir / depth_file,
            "rotation": R,
            "translation": t,
        }
def main():
    dataset = TUMRGBD("data/rgbd_dataset_freiburg2_desk")
    print("Dataset size:", len(dataset))
    x = dataset[0]
    if x is None:
        return
    print(x.keys())
    
if __name__ == "__main__":
    main()
Running the script produces the following output:
Loaded 2965 RGB frames
Loaded 2964 depth frames
Loaded 20957 poses
Dataset size: 2965
dict_keys(['rgb_path', 'depth_path', 'rotation', 'translation'])
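The matching step described earlier (closest timestamp plus a tolerance check) can also be tried in isolation and applied to a whole sequence at once. The sketch below is one possible vectorized variant, not the article's code: `match_nearest` is a hypothetical name, it assumes the reference timestamps are sorted and contain at least two entries, and the sample values are made up. Note that `np.searchsorted` only returns an insertion index, so the neighbours on both sides have to be compared to get the nearest entry.

```python
import numpy as np


def match_nearest(query, reference, tolerance):
    """For each query timestamp, return the index of the nearest reference
    timestamp, or -1 if the gap exceeds `tolerance`.

    Assumes `reference` is sorted and has at least two entries.
    """
    query = np.asarray(query, dtype=np.int64)
    reference = np.asarray(reference, dtype=np.int64)
    # Insertion index, clipped so that both right and right - 1 are valid
    right = np.clip(np.searchsorted(reference, query), 1, len(reference) - 1)
    left = right - 1
    # Keep whichever neighbour is closer in time
    pick_left = (query - reference[left]) <= (reference[right] - query)
    nearest = np.where(pick_left, left, right)
    gap = np.abs(reference[nearest] - query)
    return np.where(gap <= tolerance, nearest, -1)


# Made-up timestamps in integer milliseconds
rgb_ts = [1000, 1033, 1500, 2000]
depth_ts = [990, 1040, 1900]
matches = match_nearest(rgb_ts, depth_ts, tolerance=100)
print(matches.tolist())  # [0, 1, -1, 2]
```

Here the third RGB frame has no depth frame within 100 ms, so it is marked with -1, which plays the same role as the None returned by `__getitem__`.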