FIERY 后处理原理与代码解读
本文主要介绍 FIERY 原始 Python 版本的后处理代码。此外，可以在 FIERY 后处理仓库中查看 C++ 版本的开源代码，仅供学习交流使用。
# Excerpt from the enclosing prediction/post-processing function: run instance
# segmentation independently on every frame. `output`, `vehicles_id` and the
# helper functions come from the enclosing scope, which is not shown here.
preds = output['segmentation'].detach()  # e.g. 1*5*2*200*200 in this article's setup
# Per-pixel class index; keepdim=True keeps a singleton class axis at dim 2.
preds = torch.argmax(preds, dim=2, keepdim=True)
# Boolean mask of pixels predicted as the vehicle class: (batch, seq_len, h, w).
foreground_masks = preds.squeeze(2) == vehicles_id

batch_size, seq_len = preds.shape[:2]
pred_inst = []
for b in range(batch_size):
    pred_inst_batch = []
    for t in range(seq_len):  # process frame by frame
        pred_instance_t, _ = get_instance_segmentation_and_centers(
            output['instance_center'][b, t].detach(),
            output['instance_offset'][b, t].detach(),
            foreground_masks[b, t].detach(),
        )
        # Store the per-frame result; without this, torch.stack below would
        # receive an empty list and fail.
        pred_inst_batch.append(pred_instance_t)
    pred_inst.append(torch.stack(pred_inst_batch, dim=0))

# Each per-frame map carries a leading singleton axis; drop it so the result is
# (batch, seq_len, h, w).
pred_inst = torch.stack(pred_inst).squeeze(2)
def get_instance_segmentation_and_centers(
    center_predictions: torch.Tensor,
    offset_predictions: torch.Tensor,
    foreground_mask: torch.Tensor,
    conf_threshold: float = 0.1,
    nms_kernel_size: int = 3,
    max_n_instance_centers: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build a single frame's instance-id map from center and offset predictions.

    Args:
        center_predictions: center heatmap for one frame. Its last two dims are
            (h, w); it is reshaped to (1, h, w).
        offset_predictions: offset field for the same frame, reshaped to (2, h, w).
        foreground_mask: foreground mask for the frame, reshaped to (1, h, w).
        conf_threshold: forwarded to ``find_instance_centers``.
        nms_kernel_size: forwarded to ``find_instance_centers``. Per the article
            it is the max-pooling window used for NMS, hence an integer.
        max_n_instance_centers: keep at most this many centers (the first ones,
            in the order returned by ``find_instance_centers``).

    Returns:
        ``(instance_seg, centers)``.

        ``instance_seg`` is an int64 id map. It is (1, h, w) in the no-center
        case, and presumably (1, h, w) otherwise (depends on ``group_pixels``;
        confirm). Pixels outside ``foreground_mask`` are set to 0 before
        ``make_instance_seg_consecutive`` renumbers the ids.

        ``centers`` is the (possibly truncated) output of
        ``find_instance_centers``, or an empty (0, 2) tensor when no center was
        found.
    """
    # shape[-2:] is (h, w); the view calls below keep that order.
    height, width = center_predictions.shape[-2:]
    center_predictions = center_predictions.view(1, height, width)
    offset_predictions = offset_predictions.view(2, height, width)
    foreground_mask = foreground_mask.view(1, height, width)

    # Locate instance centers on the heatmap (max-pooling based NMS per the article).
    centers = find_instance_centers(center_predictions, conf_threshold=conf_threshold, nms_kernel_size=nms_kernel_size)
    if not len(centers):
        # No instance in this frame: all-background map and no centers.
        return (
            torch.zeros(center_predictions.shape, dtype=torch.int64, device=center_predictions.device),
            torch.zeros((0, 2), device=centers.device),
        )

    if len(centers) > max_n_instance_centers:
        print(f'There are a lot of detected instance centers: {centers.shape}')
        centers = centers[:max_n_instance_centers].clone()

    # The offset field is a vector field pointing towards instance centers;
    # group_pixels assigns each pixel of the h x w grid to one of the centers.
    instance_ids = group_pixels(centers, offset_predictions)
    # Keep ids only on foreground pixels; everything else becomes 0.
    instance_seg = (instance_ids * foreground_mask.float()).long()
    # Masking can leave gaps in the id sequence; renumber the ids.
    instance_seg = make_instance_seg_consecutive(instance_seg)
    return instance_seg.long(), centers
保持多帧instance id的一致性
# Excerpt from the enclosing prediction function. Make instance ids consistent
# across frames one sample at a time, because
# make_instance_id_temporally_consistent asserts a batch size of 1.
consistent_instance_seg = []
for b in range(batch_size):
    # Unify the instance ids of all frames using the predicted instance flow.
    consistent_instance_seg.append(
        make_instance_id_temporally_consistent(
            pred_inst[b:b + 1], output['instance_flow'][b:b + 1].detach()
        )
    )
consistent_instance_seg = torch.cat(consistent_instance_seg, dim=0)
def _non_background_ids(seg):
    """Return the sorted ids present in ``seg`` as an int64 NumPy array, excluding 0 (background)."""
    ids = torch.unique(seg).cpu().numpy()
    return ids[ids != 0].astype(np.int64)


def _instance_centers(grid, seg, instance_ids):
    """Return the mean ``grid`` coordinate of each ``seg == id`` mask, stacked to (len(instance_ids), grid.shape[0])."""
    return torch.stack([grid[:, seg == instance_id].mean(dim=1) for instance_id in instance_ids])


def make_instance_id_temporally_consistent(pred_inst, future_flow, matching_threshold=3.0):
    """Relabel per-frame instance ids so that matched instances keep one id over time.

    Parameters
    ----------
    pred_inst: torch.Tensor (1, seq_len, h, w)
        Per-frame instance ids, 0 = background.
    future_flow: torch.Tensor (1, seq_len, 2, h, w)
    matching_threshold: distance threshold for a match to be valid.

    Returns
    -------
    consistent_instance_seg: torch.Tensor (1, seq_len, h, w)

    Algorithm
    ---------
    1. At time t, loop over all detected instances and use the flow to compute
       their new centers at time t+1.
    2. Store those centers.
    3. At time t+1, re-identify instances by comparing the positions of the
       actual centers with the flow-warped centers (Hungarian matching). Make the
       labels at t+1 consistent with the matching; unmatched instances get
       fresh ids.
    4. Repeat.
    """
    assert pred_inst.shape[0] == 1, 'Assumes batch size = 1'

    _, seq_len, h, w = pred_inst.shape
    device = pred_inst.device
    # (2, h, w) grid of pixel coordinates. It is the same for every frame, so
    # it is built once instead of twice per iteration.
    grid = torch.stack(torch.meshgrid(
        torch.arange(h, dtype=torch.float, device=device),
        torch.arange(w, dtype=torch.float, device=device),
    ))

    # Initialise instance segmentations with prediction corresponding to the present
    consistent_instance_seg = [pred_inst[0, 0]]
    largest_instance_id = int(consistent_instance_seg[0].max().item())

    for t in range(seq_len - 1):
        prev_seg = consistent_instance_seg[-1]
        next_seg = pred_inst[0, t + 1]
        t_instance_ids = _non_background_ids(prev_seg)
        t_one_instance_ids = _non_background_ids(next_seg)

        if len(t_instance_ids) == 0 or len(t_one_instance_ids) == 0:
            # Nothing to match: keep frame t+1 as predicted. Its ids are now in
            # use, so move largest_instance_id past them; otherwise fresh ids
            # handed out in later frames could collide with these carried-over ids.
            consistent_instance_seg.append(next_seg)
            if len(t_one_instance_ids) > 0:
                largest_instance_id = max(largest_instance_id, int(t_one_instance_ids.max()))
            continue

        # Centers of the frame-t instances after shifting every pixel by the
        # predicted flow (mean coordinate of each instance mask).
        warped_centers = _instance_centers(grid + future_flow[0, t], prev_seg, t_instance_ids)
        # Actual centers of the frame-(t+1) instances.
        centers = _instance_centers(grid, next_seg, t_one_instance_ids)

        # distances[i, j]: warped center of t_instance_ids[i] vs. actual center
        # of t_one_instance_ids[j].
        distances = torch.norm(centers.unsqueeze(0) - warped_centers.unsqueeze(1), dim=-1).cpu().numpy()
        # Hungarian matching between frame t (rows) and frame t+1 (cols).
        rows, cols = linear_sum_assignment(distances)
        # Filter low quality matches.
        valid = distances[rows, cols] < matching_threshold
        ids_t = t_instance_ids[rows[valid]]
        ids_t_one = t_one_instance_ids[cols[valid]]

        # Instances in t+1 that were not matched are new: give them fresh ids.
        remaining_ids = np.array(
            sorted(set(t_one_instance_ids.tolist()) - set(ids_t_one.tolist())), dtype=np.int64
        )
        fresh_ids = np.arange(
            largest_instance_id + 1, largest_instance_id + 1 + len(remaining_ids), dtype=np.int64
        )
        largest_instance_id += len(remaining_ids)

        consistent_instance_seg.append(update_instance_ids(
            next_seg,
            old_ids=np.concatenate([ids_t_one, remaining_ids]),
            new_ids=np.concatenate([ids_t, fresh_ids]),
        ))

    consistent_instance_seg = torch.stack(consistent_instance_seg).unsqueeze(0)
    return consistent_instance_seg
assert batch_size == 1
# Generate trajectories
matched_centers = {}
_, seq_len, h, w = consistent_instance_seg.shape
grid = torch.stack(torch.meshgrid(
torch.arange(h, dtype=torch.float, device=preds.device),
torch.arange(w, dtype=torch.float, device=preds.device)
for instance_id in torch.unique(consistent_instance_seg[0, 0])[1:].cpu().numpy():
# 得到每个instance id在各帧的位置
for t in range(seq_len):
instance_mask = consistent_instance_seg[0, t] == instance_id
if instance_mask.sum() > 0:
matched_centers[instance_id] = matched_centers.get(instance_id, []) + [
grid[:, instance_mask].mean(dim=-1)]
for key, value in matched_centers.items():
matched_centers[key] = torch.stack(value).cpu().numpy()[:, ::-1]
return consistent_instance_seg, matched_centers