FIERY Post-Processing: Principles and Code Walkthrough
This article walks through the original Python post-processing code of FIERY. A C++ version is also available in the open-source FIERY post-processing repository, for learning and reference purposes only.
The post-processing stage of FIERY combines the outputs of the decoder heads and, through a matching step, recovers each instance's trajectory and prediction results. Concretely, it first computes the instance segmentation and the instance centers for each frame independently. It then matches instances across frames so that the same instance keeps the same id over time. Finally, it obtains each instance's continuous motion and predicted trajectory. The full sequence contains 5 frames: 2 past, 1 present, and 2 predicted future frames.
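For reference, the snippets below assume a decoder output dict with the following shapes, inferred from the code in this article (batch size 1, a sequence of 5 frames, and a 200x200 BEV grid); the dummy tensors here are only a stand-in for illustration:

import torch

# Hypothetical stand-in for the decoder output; shapes inferred from the post-processing code below.
output = {
    'segmentation':    torch.randn(1, 5, 2, 200, 200),   # per-class logits: (background, vehicle)
    'instance_center': torch.rand(1, 5, 1, 200, 200),    # centerness heatmap
    'instance_offset': torch.randn(1, 5, 2, 200, 200),   # per-pixel vector pointing to the instance center
    'instance_flow':   torch.randn(1, 5, 2, 200, 200),   # per-pixel displacement from frame t to t+1
}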
Foreground mask
Compute the foreground mask from the segmentation output:
# imports used throughout the snippets in this article
from typing import Tuple

import numpy as np
import torch
from scipy.optimize import linear_sum_assignment

vehicles_id = 1  # index of the vehicle class in the 2-channel segmentation output (assumed)

preds = output['segmentation'].detach()  # 1*5*2*200*200
preds = torch.argmax(preds, dim=2, keepdims=True)
foreground_masks = preds.squeeze(2) == vehicles_id  # 1*5*200*200
Instance segmentation and instance centers
Compute the per-frame instance segmentation and the center of each instance from instance_center and instance_offset:
batch_size, seq_len = preds.shape[:2]
pred_inst = []
for b in range(batch_size):
    pred_inst_batch = []
    for t in range(seq_len):  # process frame by frame
        pred_instance_t, _ = get_instance_segmentation_and_centers(
            output['instance_center'][b, t].detach(),
            output['instance_offset'][b, t].detach(),
            foreground_masks[b, t].detach()
        )
        pred_inst_batch.append(pred_instance_t)
    pred_inst.append(torch.stack(pred_inst_batch, dim=0))
pred_inst = torch.stack(pred_inst).squeeze(2)
The implementation of get_instance_segmentation_and_centers:
def get_instance_segmentation_and_centers(
    center_predictions: torch.Tensor,
    offset_predictions: torch.Tensor,
    foreground_mask: torch.Tensor,
    conf_threshold: float = 0.1,
    nms_kernel_size: int = 3,
    max_n_instance_centers: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    width, height = center_predictions.shape[-2:]
    center_predictions = center_predictions.view(1, width, height)
    offset_predictions = offset_predictions.view(2, width, height)
    foreground_mask = foreground_mask.view(1, width, height)

    # find_instance_centers performs NMS with max pooling and returns the center coordinates of each vehicle
    centers = find_instance_centers(center_predictions, conf_threshold=conf_threshold, nms_kernel_size=nms_kernel_size)
    if not len(centers):
        return torch.zeros(center_predictions.shape, dtype=torch.int64, device=center_predictions.device), \
               torch.zeros((0, 2), device=centers.device)

    if len(centers) > max_n_instance_centers:
        print(f'There are a lot of detected instance centers: {centers.shape}')
        centers = centers[:max_n_instance_centers].clone()

    # offset_predictions is a vector field pointing at the instance centers;
    # assign each of the 200*200 grid cells to an instance based on its distance to each center
    instance_ids = group_pixels(centers, offset_predictions)
    # filter with the foreground mask computed earlier
    instance_seg = (instance_ids * foreground_mask.float()).long()
    # the previous step may leave gaps in the id numbering; this simply relabels the ids to be consecutive
    instance_seg = make_instance_seg_consecutive(instance_seg)
    return instance_seg.long(), centers
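find_instance_centers, group_pixels, and make_instance_seg_consecutive live in the FIERY repository and are not reproduced in this article. The following is a minimal sketch of the idea behind each one (max-pool NMS, nearest-center pixel grouping, and id relabelling), assuming a [1, H, W] heatmap and a [2, H, W] offset field; it is an illustration, not the verbatim implementation:

import torch
import torch.nn.functional as F

def find_instance_centers_sketch(center_prediction, conf_threshold=0.1, nms_kernel_size=3):
    # Max-pooling NMS: keep a pixel as a center only if it exceeds conf_threshold
    # and equals the local maximum of its nms_kernel_size neighbourhood.
    center_prediction = F.threshold(center_prediction, threshold=conf_threshold, value=-1.0)
    nms_padding = (nms_kernel_size - 1) // 2
    maxpooled = F.max_pool2d(center_prediction, kernel_size=nms_kernel_size, stride=1, padding=nms_padding)
    center_prediction[center_prediction != maxpooled] = -1.0
    centers = torch.nonzero(center_prediction > 0)  # [N, 3]: (channel, row, col)
    return centers[:, 1:]                           # [N, 2]: (row, col)

def group_pixels_sketch(centers, offset_predictions):
    # Each pixel votes for the location pixel + offset and is assigned the id
    # of the nearest detected center; id 0 is reserved for the background.
    h, w = offset_predictions.shape[-2:]
    rows = torch.arange(h, dtype=offset_predictions.dtype).view(1, h, 1).expand(1, h, w)
    cols = torch.arange(w, dtype=offset_predictions.dtype).view(1, 1, w).expand(1, h, w)
    pixel_grid = torch.cat((rows, cols), dim=0)                  # [2, H, W]
    voted = (pixel_grid + offset_predictions).view(2, -1).t()    # [H*W, 2]
    distances = torch.cdist(voted, centers.float())              # [H*W, N]
    return (torch.argmin(distances, dim=1) + 1).view(1, h, w)    # ids in 1..N

def make_instance_seg_consecutive_sketch(instance_seg):
    # Relabel ids so that they are consecutive (0..K), closing any gaps
    # left by the foreground-mask filtering.
    unique_ids = torch.unique(instance_seg)
    lookup = torch.zeros(int(unique_ids.max()) + 1, dtype=instance_seg.dtype)
    lookup[unique_ids] = torch.arange(len(unique_ids), dtype=instance_seg.dtype)
    return lookup[instance_seg]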
Keeping instance ids consistent across frames
In the previous step each frame was processed independently, so the same instance may carry different ids in different frames. This step matches instances across frames, using the information provided by instance_flow to unify instance ids over time and recover each instance's temporal motion.
consistent_instance_seg = []
for b in range(batch_size):
    consistent_instance_seg.append(
        # unify the instance ids across frames using the information provided by instance_flow
        make_instance_id_temporally_consistent(pred_inst[b:b+1],
                                               output['instance_flow'][b:b+1].detach())
    )
consistent_instance_seg = torch.cat(consistent_instance_seg, dim=0)
The implementation of make_instance_id_temporally_consistent:
def make_instance_id_temporally_consistent(pred_inst, future_flow, matching_threshold=3.0):
    """
    Parameters
        pred_inst: torch.Tensor (1, seq_len, h, w)
        future_flow: torch.Tensor (1, seq_len, 2, h, w)
        matching_threshold: distance threshold for a match to be valid.

    Returns
        consistent_instance_seg: torch.Tensor (1, seq_len, h, w)

    1. time t. Loop over all detected instances. Use flow to compute new centers at time t+1.
    2. Store those centers.
    3. time t+1. Re-identify instances by comparing the positions of the actual centers
       and the flow-warped centers. Make the labels at t+1 consistent with the matching.
    4. Repeat.
    """
    assert pred_inst.shape[0] == 1, 'Assumes batch size = 1'

    # Initialise instance segmentations with the prediction corresponding to the present
    consistent_instance_seg = [pred_inst[0, 0]]
    largest_instance_id = consistent_instance_seg[0].max().item()

    _, seq_len, h, w = pred_inst.shape
    device = pred_inst.device
    for t in range(seq_len - 1):
        # Compute predicted future instance means
        grid = torch.stack(torch.meshgrid(
            torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
        ))

        # add the predicted future flow to the predefined grid
        grid = grid + future_flow[0, t]  # [2, 200, 200], coordinates along the two axes
        warped_centers = []
        # all instance ids except the background
        t_instance_ids = torch.unique(consistent_instance_seg[-1])[1:].cpu().numpy()

        if len(t_instance_ids) == 0:
            # No instance so nothing to update
            consistent_instance_seg.append(pred_inst[0, t + 1])
            continue

        for instance_id in t_instance_ids:
            instance_mask = (consistent_instance_seg[-1] == instance_id)
            # use the mean over the instance mask as the center coordinate
            warped_centers.append(grid[:, instance_mask].mean(dim=1))
        warped_centers = torch.stack(warped_centers)

        # Compute actual future instance means
        centers = []
        grid = torch.stack(torch.meshgrid(
            torch.arange(h, dtype=torch.float, device=device), torch.arange(w, dtype=torch.float, device=device)
        ))
        n_instances = int(pred_inst[0, t + 1].max().item())

        if n_instances == 0:
            # No instance, so nothing to update.
            consistent_instance_seg.append(pred_inst[0, t + 1])
            continue

        for instance_id in range(1, n_instances + 1):
            instance_mask = (pred_inst[0, t + 1] == instance_id)
            centers.append(grid[:, instance_mask].mean(dim=1))
        centers = torch.stack(centers)

        # Compute the distance matrix between warped centers and actual centers
        distances = torch.norm(centers.unsqueeze(0) - warped_centers.unsqueeze(1), dim=-1).cpu().numpy()
        # outputs (row, col) with row: index in frame t, col: index in frame t+1
        # the missing ids in col must be added (they correspond to new instances)
        # Hungarian matching between the current frame and the next frame
        ids_t, ids_t_one = linear_sum_assignment(distances)
        matching_distances = distances[ids_t, ids_t_one]
        # Offset by one as id=0 is the background
        ids_t += 1
        ids_t_one += 1

        # swap ids_t with the real ids, as those ids correspond to positions in the distance matrix
        id_mapping = dict(zip(np.arange(1, len(t_instance_ids) + 1), t_instance_ids))
        ids_t = np.vectorize(id_mapping.__getitem__, otypes=[np.int64])(ids_t)

        # Filter out low-quality matches
        ids_t = ids_t[matching_distances < matching_threshold]
        ids_t_one = ids_t_one[matching_distances < matching_threshold]

        # Elements that are in t+1, but weren't matched
        remaining_ids = set(torch.unique(pred_inst[0, t + 1]).cpu().numpy()).difference(set(ids_t_one))
        # remove background
        remaining_ids.remove(0)
        # Assign each remaining id a new unique id
        for remaining_id in list(remaining_ids):
            largest_instance_id += 1
            ids_t = np.append(ids_t, largest_instance_id)
            ids_t_one = np.append(ids_t_one, remaining_id)

        consistent_instance_seg.append(update_instance_ids(pred_inst[0, t + 1], old_ids=ids_t_one, new_ids=ids_t))

    consistent_instance_seg = torch.stack(consistent_instance_seg).unsqueeze(0)
    return consistent_instance_seg
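update_instance_ids, called at the end of the loop, is likewise part of the repository. A minimal sketch consistent with its usage here (remap old_ids to new_ids, leaving all other ids, including the background, unchanged) could be:

import torch

def update_instance_ids_sketch(instance_seg, old_ids, new_ids):
    # Build a lookup table that maps every old id to its new id and leaves
    # unmentioned ids (e.g. the background, 0) untouched, then apply it to
    # the whole segmentation map with a single indexing operation.
    lookup = torch.arange(int(instance_seg.max()) + 1, device=instance_seg.device)
    for old_id, new_id in zip(old_ids, new_ids):
        lookup[int(old_id)] = int(new_id)
    return lookup[instance_seg.long()]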
Computing each instance's trajectory
assert batch_size == 1

# Generate trajectories
matched_centers = {}
_, seq_len, h, w = consistent_instance_seg.shape
grid = torch.stack(torch.meshgrid(
    torch.arange(h, dtype=torch.float, device=preds.device),
    torch.arange(w, dtype=torch.float, device=preds.device)
))

for instance_id in torch.unique(consistent_instance_seg[0, 0])[1:].cpu().numpy():
    # collect the position of each instance id in every frame
    for t in range(seq_len):
        instance_mask = consistent_instance_seg[0, t] == instance_id
        if instance_mask.sum() > 0:
            matched_centers[instance_id] = matched_centers.get(instance_id, []) + [
                grid[:, instance_mask].mean(dim=-1)]

for key, value in matched_centers.items():
    # stack per-frame centers into an array and reverse the coordinate order
    matched_centers[key] = torch.stack(value).cpu().numpy()[:, ::-1]

return consistent_instance_seg, matched_centers
With these results, a simple coordinate transform yields each obstacle's position and predicted trajectory in the ego-vehicle frame; velocity, acceleration, and other quantities can then be derived and passed to downstream modules.
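As an illustration of that last step (not part of the FIERY code), assuming the common FIERY BEV setting of a 100 m x 100 m grid at 0.5 m per cell centered on the ego vehicle, and 0.5 s between frames (nuScenes keyframes at 2 Hz), the conversion and a finite-difference velocity estimate could look like:

import numpy as np

BEV_RESOLUTION = 0.5   # metres per cell (assumed)
BEV_SIZE = 200         # cells per side (assumed)
FRAME_INTERVAL = 0.5   # seconds between frames (assumed)

def trajectory_to_ego_frame(centers_px: np.ndarray) -> np.ndarray:
    # Shift the origin to the grid center (the ego vehicle) and scale to metres.
    return (centers_px - BEV_SIZE / 2.0) * BEV_RESOLUTION

def estimate_velocity(traj_m: np.ndarray) -> np.ndarray:
    # Finite-difference velocity between consecutive frames, in m/s.
    return np.diff(traj_m, axis=0) / FRAME_INTERVAL

# e.g. for each instance trajectory produced above:
# for instance_id, centers in matched_centers.items():
#     traj = trajectory_to_ego_frame(centers)
#     vel = estimate_velocity(traj)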