FIERY Post-Processing: Principles and Code Walkthrough

This post walks through the original Python implementation of FIERY's post-processing. A C++ version of the code is also available in the FIERY post-processing repository, for learning and reference only.

FIERY's post-processing combines the outputs of the decoder heads and, through a matching step, recovers each instance's track and predicted trajectory. Concretely, it first computes the instance segmentation and instance centers for every frame independently. It then matches instances across frames so that the same instance keeps the same id throughout the sequence. Finally, it extracts each instance's continuous motion and predicted trajectory. The full sequence is 5 frames: 2 past, 1 present, and 2 predicted future frames.

Foreground mask

Compute the foreground mask from the segmentation head output:

preds = output['segmentation'].detach()  # (1, 5, 2, 200, 200): batch, time, class, height, width
preds = torch.argmax(preds, dim=2, keepdim=True)
# vehicles_id is the index of the vehicle class in the segmentation output
foreground_masks = preds.squeeze(2) == vehicles_id  # (1, 5, 200, 200)

Instance segmentation and instance centers

Compute the instance segmentation and instance centers from instance_center and instance_offset:

batch_size, seq_len = preds.shape[:2]
pred_inst = []
for b in range(batch_size):
    pred_inst_batch = []
    for t in range(seq_len):  # process each frame independently
        pred_instance_t, _ = get_instance_segmentation_and_centers(
            output['instance_center'][b, t].detach(),
            output['instance_offset'][b, t].detach(),
            foreground_masks[b, t].detach()
        )
        pred_inst_batch.append(pred_instance_t)
    pred_inst.append(torch.stack(pred_inst_batch, dim=0))
pred_inst = torch.stack(pred_inst).squeeze(2)

The implementation details of get_instance_segmentation_and_centers are as follows:

from typing import Tuple

import torch


def get_instance_segmentation_and_centers(
    center_predictions: torch.Tensor,
    offset_predictions: torch.Tensor,
    foreground_mask: torch.Tensor,
    conf_threshold: float = 0.1,
    nms_kernel_size: int = 3,
    max_n_instance_centers: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    width, height = center_predictions.shape[-2:]
    center_predictions = center_predictions.view(1, width, height)
    offset_predictions = offset_predictions.view(2, width, height)
    foreground_mask = foreground_mask.view(1, width, height)

    # find_instance_centers performs NMS via max pooling and returns
    # the center coordinates of each vehicle
    centers = find_instance_centers(center_predictions, conf_threshold=conf_threshold, nms_kernel_size=nms_kernel_size)
    if not len(centers):
        return torch.zeros(center_predictions.shape, dtype=torch.int64, device=center_predictions.device), \
               torch.zeros((0, 2), device=centers.device)

    if len(centers) > max_n_instance_centers:
        print(f'There are a lot of detected instance centers: {centers.shape}')
        centers = centers[:max_n_instance_centers].clone()

    # offset_predictions is a vector field pointing towards the instance centers.
    # Each cell of the 200x200 grid is assigned to the instance whose center
    # it is closest to after applying the offset.
    instance_ids = group_pixels(centers, offset_predictions)
    # Filter with the foreground mask computed earlier
    instance_seg = (instance_ids * foreground_mask.float()).long()

    # The previous step may leave gaps in the id numbering;
    # this step simply relabels the ids so they are consecutive
    instance_seg = make_instance_seg_consecutive(instance_seg)

    return instance_seg.long(), centers
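
The helpers find_instance_centers and group_pixels are not shown above. The minimal sketches below follow the public FIERY repository; treat them as illustrations rather than a verbatim copy:

import torch
import torch.nn.functional as F


def find_instance_centers(center_prediction: torch.Tensor,
                          conf_threshold: float = 0.1,
                          nms_kernel_size: int = 3) -> torch.Tensor:
    assert len(center_prediction.shape) == 3
    # Suppress low-confidence responses, then keep only pixels that are
    # the maximum of their neighbourhood (NMS via max pooling)
    center_prediction = F.threshold(center_prediction, threshold=conf_threshold, value=-1)
    nms_padding = (nms_kernel_size - 1) // 2
    maxpooled = F.max_pool2d(center_prediction, kernel_size=nms_kernel_size,
                             stride=1, padding=nms_padding)
    center_prediction[center_prediction != maxpooled] = -1
    # Surviving positive pixels are the instance centers; drop the channel index
    return torch.nonzero(center_prediction > 0)[:, 1:]


def group_pixels(centers: torch.Tensor, offset_predictions: torch.Tensor) -> torch.Tensor:
    width, height = offset_predictions.shape[-2:]
    x_grid = (torch.arange(width, dtype=offset_predictions.dtype,
                           device=offset_predictions.device)
              .view(1, width, 1).expand(1, width, height))
    y_grid = (torch.arange(height, dtype=offset_predictions.dtype,
                           device=offset_predictions.device)
              .view(1, 1, height).expand(1, width, height))
    pixel_grid = torch.cat((x_grid, y_grid), dim=0)

    # Every pixel votes for its instance center by adding its predicted offset
    center_locations = (pixel_grid + offset_predictions).view(2, width * height, 1).permute(2, 1, 0)
    centers = centers.view(-1, 1, 2)

    # Assign each pixel to the nearest voted center; id 0 is reserved for background
    distances = torch.norm(centers - center_locations, dim=-1)
    instance_id = torch.argmin(distances, dim=0).reshape(1, width, height) + 1
    return instance_id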

Keeping instance ids consistent across frames

The previous step processes each frame independently, so the same instance may receive different ids in different frames. This step matches instances across frames and uses the information provided by instance_flow to unify instance ids over time, yielding each instance's temporal motion.
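
The core of the matching is Hungarian assignment (scipy's linear_sum_assignment) on the distance matrix between flow-warped centers from frame t and detected centers in frame t+1. A toy example with made-up numbers illustrates the idea:

import numpy as np
from scipy.optimize import linear_sum_assignment

# Hypothetical distances: 2 flow-warped centers from frame t (rows)
# vs 3 detected centers in frame t+1 (columns)
distances = np.array([
    [0.5, 7.0, 9.0],
    [6.0, 1.2, 8.0],
])
rows, cols = linear_sum_assignment(distances)
# rows = [0, 1], cols = [0, 1]: warped center 0 matches detection 0,
# warped center 1 matches detection 1; detection 2 is unmatched and
# later receives a brand-new instance id.

The driver loop over the batch: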

consistent_instance_seg = []
for b in range(batch_size):
    consistent_instance_seg.append(
        # Unify instance ids across frames using the information from instance_flow
        make_instance_id_temporally_consistent(pred_inst[b:b+1],
                                                output['instance_flow'][b:b+1].detach())
    )
consistent_instance_seg = torch.cat(consistent_instance_seg, dim=0)

The implementation details of make_instance_id_temporally_consistent:

import numpy as np
import torch
from scipy.optimize import linear_sum_assignment


def make_instance_id_temporally_consistent(pred_inst, future_flow, matching_threshold=3.0):
    """
    Parameters
        pred_inst: torch.Tensor (1, seq_len, h, w)
        future_flow: torch.Tensor(1, seq_len, 2, h, w)
        matching_threshold: distance threshold for a match to be valid.

    Returns
        consistent_instance_seg: torch.Tensor (1, seq_len, h, w)

    1. time t. Loop over all detected instances. Use flow to compute new centers at time t+1.
    2. Store those centers
    3. time t+1. Re-identify instances by comparing position of actual centers, and flow-warped centers.
        Make the labels at t+1 consistent with the matching
    4. Repeat
    """
    assert pred_inst.shape[0] == 1, 'Assumes batch size = 1'

    # Initialise instance segmentations with prediction corresponding to the present
    consistent_instance_seg = [pred_inst[0, 0]]
    largest_instance_id = consistent_instance_seg[0].max().item()

    _, seq_len, h, w = pred_inst.shape
    device = pred_inst.device
    for t in range(seq_len - 1):
        # Compute predicted future instance means
        grid = torch.stack(torch.meshgrid(
            torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
        ))

        # Add the predicted future flow to the predefined grid of pixel coordinates
        grid = grid + future_flow[0, t]  # [2, 200, 200], coordinates along the two axes
        warped_centers = []
        # Get all instance ids except the background (id 0)
        t_instance_ids = torch.unique(consistent_instance_seg[-1])[1:].cpu().numpy()

        if len(t_instance_ids) == 0:
            # No instance so nothing to update
            consistent_instance_seg.append(pred_inst[0, t + 1])
            continue

        for instance_id in t_instance_ids:
            instance_mask = (consistent_instance_seg[-1] == instance_id)
            # Use the mean position over the instance mask as the center coordinates
            warped_centers.append(grid[:, instance_mask].mean(dim=1))
        warped_centers = torch.stack(warped_centers)

        # Compute actual future instance means
        centers = []
        grid = torch.stack(torch.meshgrid(
            torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
        ))
        n_instances = int(pred_inst[0, t + 1].max().item())

        if n_instances == 0:
            # No instance, so nothing to update.
            consistent_instance_seg.append(pred_inst[0, t + 1])
            continue

        for instance_id in range(1, n_instances + 1):
            instance_mask = (pred_inst[0, t + 1] == instance_id)
            centers.append(grid[:, instance_mask].mean(dim=1))
        centers = torch.stack(centers)

        # Compute distance matrix between warped centers and actual centers
        distances = torch.norm(centers.unsqueeze(0) - warped_centers.unsqueeze(1), dim=-1).cpu().numpy()
        # outputs (row, col) with row: index in frame t, col: index in frame t+1
        # the missing ids in col must be added (correspond to new instances)
        # Hungarian matching between the current frame and the next frame
        ids_t, ids_t_one = linear_sum_assignment(distances)
        matching_distances = distances[ids_t, ids_t_one]
        # Offset by one as id=0 is the background
        ids_t += 1
        ids_t_one += 1

        # Map ids_t back to the real ids, since these indices are positions in the distance matrix
        id_mapping = dict(zip(np.arange(1, len(t_instance_ids) + 1), t_instance_ids))
        ids_t = np.vectorize(id_mapping.__getitem__, otypes=[np.int64])(ids_t)

        # Filter low quality match
        ids_t = ids_t[matching_distances < matching_threshold]
        ids_t_one = ids_t_one[matching_distances < matching_threshold]

        # Elements that are in t+1, but weren't matched
        remaining_ids = set(torch.unique(pred_inst[0, t + 1]).cpu().numpy()).difference(set(ids_t_one))
        # remove background
        remaining_ids.remove(0)
        #  Set remaining_ids to a new unique id
        for remaining_id in list(remaining_ids):
            largest_instance_id += 1
            ids_t = np.append(ids_t, largest_instance_id)
            ids_t_one = np.append(ids_t_one, remaining_id)

        consistent_instance_seg.append(update_instance_ids(pred_inst[0, t + 1], old_ids=ids_t_one, new_ids=ids_t))

    consistent_instance_seg = torch.stack(consistent_instance_seg).unsqueeze(0)
    return consistent_instance_seg
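
For completeness, the two small relabelling helpers used above, update_instance_ids (called at the end of the matching loop) and make_instance_seg_consecutive (called in get_instance_segmentation_and_centers), can be sketched as follows, again closely following the public FIERY repository:

import torch


def update_instance_ids(instance_seg, old_ids, new_ids):
    """Relabel instance_seg by mapping each id in old_ids to the aligned id in new_ids."""
    indices = torch.arange(instance_seg.max().item() + 1, device=instance_seg.device)
    for old_id, new_id in zip(old_ids, new_ids):
        indices[old_id] = new_id
    # Index the lookup table with the segmentation map to relabel every pixel at once
    return indices[instance_seg].long()


def make_instance_seg_consecutive(instance_seg):
    # Collect the ids actually present (including background 0) and
    # map them onto 0, 1, 2, ... so the numbering has no gaps
    unique_ids = torch.unique(instance_seg)
    new_ids = torch.arange(len(unique_ids), device=instance_seg.device)
    return update_instance_ids(instance_seg, unique_ids, new_ids)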

Computing each instance's trajectory

assert batch_size == 1
# Generate trajectories
matched_centers = {}
_, seq_len, h, w = consistent_instance_seg.shape
grid = torch.stack(torch.meshgrid(
    torch.arange(h, dtype=torch.float, device=preds.device),
    torch.arange(w, dtype=torch.float, device=preds.device)
))

for instance_id in torch.unique(consistent_instance_seg[0, 0])[1:].cpu().numpy():
    # Collect each instance id's center position in every frame
    for t in range(seq_len):
        instance_mask = consistent_instance_seg[0, t] == instance_id
        if instance_mask.sum() > 0:
            matched_centers[instance_id] = matched_centers.get(instance_id, []) + [
                grid[:, instance_mask].mean(dim=-1)]

for key, value in matched_centers.items():
    # Flip (row, col) order into (x, y) order
    matched_centers[key] = torch.stack(value).cpu().numpy()[:, ::-1]

return consistent_instance_seg, matched_centers

With these results, a simple coordinate transform yields each obstacle's position and predicted trajectory in the ego-vehicle frame; velocity, acceleration, and other quantities can then be derived and passed to downstream modules.
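
As a hedged sketch of that last step: the constants below (grid resolution, origin, and frame interval) are illustrative assumptions, not values read from the FIERY config:

import numpy as np

BEV_RESOLUTION = 0.5  # metres per grid cell (assumed)
BEV_START = -50.0     # metres, ego-frame coordinate of grid cell 0 (assumed)
DT = 0.5              # seconds between consecutive frames (assumed)

def grid_to_ego(traj_cells: np.ndarray) -> np.ndarray:
    """Convert a (T, 2) trajectory from grid cells to ego-frame metres."""
    return traj_cells * BEV_RESOLUTION + BEV_START

# matched_centers comes from the trajectory step above
for instance_id, traj_cells in matched_centers.items():
    traj_m = grid_to_ego(traj_cells)
    velocity = np.diff(traj_m, axis=0) / DT        # (T-1, 2) in m/s
    acceleration = np.diff(velocity, axis=0) / DT  # (T-2, 2) in m/s^2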

