MapTR Paper and Code Walkthrough


This post walks through the MapTR paper and its code.

MapTR's main improvement is a hierarchical instance-point design for the data structure of map elements: the model directly predicts the point set of each map element, rather than predicting a segmentation tensor as HDMapNet-style methods do. The data provided by the nuScenes map is itself in point-set form, so we first give a brief introduction to the nuScenes map data.

nuScenes Map

The nuScenes map has two kinds of layers, geometric layers and non-geometric layers. The geometric layers are polygon, line and node; the non-geometric layers are drivable_area, road_segment, road_block, lane, ped_crossing, walkway, stop_line, carpark_area, road_divider, lane_divider and traffic_light.
After downloading the nuScenes map expansion you can inspect the detailed description of each map element. Below we take expansion/boston-seaport.json as an example:

{
  "polygon": [
    {
      "token": "1b161e64-fe37-4f3f-96db-299edefe9f8c", 
      "exterior_node_tokens": [
        "ee85f1a6-38c4-4491-8a3b-93e1d5bf2911", 
        ...
      ], 
      "holes": []
    }
  ], 
  "line": [
    {
      "token": "7a8fcfed-9c66-475a-8dcf-9efe983acc98", 
      "node_tokens": [
        "fde59d81-e200-4a78-80e7-b8044d417b02",
        "0cffed4b-c5b7-4cd2-95ec-0d4252a1c896"
      ]
    }
  ], 
  "node": [
    {
      "token": "16af4f78-e195-4954-bf2a-889e9fa4d751",
      "x": 525.7599033618623,
      "y": 752.3132939140847
    }
  ], 
  ...
}
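
The geometric layers reference each other by token: a line stores an ordered list of node tokens, and a polygon stores the node tokens of its exterior ring (plus optional holes). As a minimal sketch (the file path is only an example), these tokens can be resolved to 2D coordinates straight from the JSON; in practice the nuscenes-devkit NuScenesMap API performs this lookup for you.

import json

# Index the node layer by token, then resolve line / polygon records into ordered 2D points.
# The path below is just an example location of the downloaded map expansion.
with open("expansion/boston-seaport.json") as f:
    map_json = json.load(f)

nodes = {rec["token"]: (rec["x"], rec["y"]) for rec in map_json["node"]}

# A line record is an ordered list of node tokens -> an ordered polyline.
polyline = [nodes[tok] for tok in map_json["line"][0]["node_tokens"]]

# A polygon works the same way through its exterior node tokens.
polygon = [nodes[tok] for tok in map_json["polygon"][0]["exterior_node_tokens"]]
print(len(polyline), len(polygon))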

MapTR

We present MapTR (Map TRansformer) for efficient online vectorized HD map construction. MapTR is a unified permutation-based modeling approach, i.e., modeling each map element as a point set with a group of equivalent permutations, which avoids the definition ambiguity of map elements and eases learning.

In short, MapTR is a permutation-based HD map construction method that treats each map element as a point set with a group of equivalent permutations, thereby avoiding ambiguity in the definition of map elements.

The definition ambiguity of map elements

Map elements can be abstracted into two types: polylines (lane dividers, etc.) and polygons (pedestrian crossings, etc.). For a polyline, either endpoint can serve as the starting point; for a polygon, starting from any point of the set and traversing in either direction is equally valid.

Figure 1. The definition ambiguity of map elements

Permutation-based modeling of MapTR

To close this gap, MapTR models each map element as $\mathcal{V}=(V,\Gamma)$, where $V=\{v_j\}_{j=0}^{N_v-1}$ is the point set of the element and $\Gamma=\{\gamma_k \}$ is the group of equivalent permutations of that point set, covering all valid orderings.
A polyline has two equivalent permutations, while a polygon has $2N_v$ equivalent permutations.

Figure 2. Illustration of permutation-based modeling of MapTR
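
A small code sketch of this permutation group (the function name equivalent_permutations is ours, for illustration): a polyline only admits the reversed ordering, while a polygon admits every start point traversed in either direction.

import numpy as np

def equivalent_permutations(pts, is_closed):
    """Enumerate the orderings of a point set that describe the same map element.
    pts: (Nv, 2) ordered points; is_closed: True for polygons, False for polylines."""
    if not is_closed:
        # A polyline can start from either endpoint: 2 equivalent permutations.
        return [pts, pts[::-1]]
    perms = []
    for shift in range(len(pts)):          # any point can be the start point
        rolled = np.roll(pts, shift, axis=0)
        perms.append(rolled)               # one traversal direction
        perms.append(rolled[::-1])         # the opposite direction
    return perms                           # 2 * Nv equivalent permutations

polyline = np.array([[0., 0.], [1., 0.], [2., 0.]])
polygon = np.array([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
assert len(equivalent_permutations(polyline, is_closed=False)) == 2
assert len(equivalent_permutations(polygon, is_closed=True)) == 2 * len(polygon)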

Hierarchical matching

On top of this, MapTR introduces hierarchical bipartite matching, which performs instance-level and then point-level matching.
Instance-level matching finds an optimal label assignment between predicted instances and ground-truth instances, using the Hungarian algorithm as in DETR.
Given the instance-level result, point-level matching then finds the optimal point-to-point assignment, using the Manhattan distance as the metric.
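
A minimal sketch of this two-level matching, assuming fixed-length point sets and keeping only the point-set distance in the instance-level cost (MapTR's actual matcher also folds in a focal classification cost and a bounding-box cost):

import numpy as np
from scipy.optimize import linear_sum_assignment

def point_set_cost(pred, gt_perms):
    """Manhattan distance between a predicted point set and the closest
    equivalent permutation of a ground-truth point set."""
    dists = [np.abs(pred - perm).sum() for perm in gt_perms]
    best = int(np.argmin(dists))
    return dists[best], best

def hierarchical_match(preds, gts):
    """preds: (num_queries, Nv, 2) predicted point sets;
    gts: one list of equivalent permutations (each (Nv, 2)) per GT element."""
    cost = np.zeros((len(preds), len(gts)))
    best_perm = np.zeros((len(preds), len(gts)), dtype=int)
    for i, pred in enumerate(preds):
        for j, gt_perms in enumerate(gts):
            cost[i, j], best_perm[i, j] = point_set_cost(pred, gt_perms)
    # Instance-level: Hungarian assignment between predicted and GT instances.
    row, col = linear_sum_assignment(cost)
    # Point-level: for each matched pair, keep the GT permutation with minimal distance.
    return row, col, [best_perm[i, j] for i, j in zip(row, col)]

preds = np.random.rand(3, 4, 2)      # 3 predicted elements, 4 points each
gt = np.random.rand(4, 2)
gts = [[gt, gt[::-1]]]               # one GT polyline with its two permutations
print(hierarchical_match(preds, gts))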

Overall architecture of MapTR

The overall architecture of MapTR is shown below.

Figure 3. The overall architecture of MapTR

Code

Code module breakdown

| Paper module | Code module | Transformer layer | Concrete type | Input shape | Output shape |
| --- | --- | --- | --- | --- | --- |
| MapEncoder | img_backbone | - | ResNet50 | [B, N, 3, 480, 800] | [B, N, 2048, 15, 25] |
| MapEncoder | img_neck | - | FPN | [B, N, 2048, 15, 25] | [B, N, 256, 15, 25] |
| MapTRHead | BEVFormerEncoder | BEVFormerLayer | BEVFormer | [N, 15*25, B, 256], [B, 20000, 256] | [B, 20000, 256] |
| MapDecoder | MapTRDecoder | DETRDecoderLayer | MultiHead Attention + Deformable Attention | [50*20, B, 256], [B, 20000, 256] | [N, 50*20, B, 256] |

BEVFormerEncoder

def forward(self, mlvl_feats, img_metas, prev_bev=None, only_bev=False, **kwargs):
    bs, num_cam, _, _, _ = mlvl_feats[0].shape  # [B, N, 256, 15, 25]

    pts_embeds = self.pts_embedding.weight.unsqueeze(0)  # [1, 20, 512]
    instance_embeds = self.instance_embedding.weight.unsqueeze(1)  # [50, 1, 512]
    object_query_embeds = (pts_embeds + instance_embeds).flatten(0, 1) # [1000, 512]
    
    bev_queries = self.bev_embedding.weight  # [200*100, 256]
    bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)  # [200*100, 2, 256]

    bev_mask = torch.zeros((bs, self.bev_h, self.bev_w), device=bev_queries.device)  # [B, 200, 100]
    bev_pos = self.positional_encoding(bev_mask).flatten(2).permute(2, 0, 1)  # [200*100, B, 256]

    # add CAN bus signals
    can_bus = bev_queries.new_tensor([each['can_bus'] for each in img_metas])  # [B, 18]
    can_bus = self.can_bus_mlp(can_bus)[None, :, :]  # [1, B, 256]
    bev_queries = bev_queries + can_bus * self.use_can_bus    # [200*100, 2, 256]

    feat_flatten = []
    spatial_shapes = []
    for lvl, feat in enumerate(mlvl_feats):
        bs, num_cam, c, h, w = feat.shape
        spatial_shape = (h, w)
        feat = feat.flatten(3).permute(1, 0, 3, 2)  # [N, B, 15*25, 256]
        if self.use_cams_embeds:  # camera的所有层级特征
            feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
        feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)
        spatial_shapes.append(spatial_shape)  # (15, 25)
        feat_flatten.append(feat)  # [N, B, 15*25, 256]

    feat_flatten = torch.cat(feat_flatten, 2)
    feat_flatten = feat_flatten.permute(0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)
    spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=bev_pos.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

    bev_embed = self.encoder(
            bev_queries,  # Q [200*100, 2, 256]
            feat_flatten, # K [N, 15*25, B, 256]
            feat_flatten, # V
            bev_h=self.bev_h,
            bev_w=self.bev_w,
            bev_pos=bev_pos,  # [200*100, B, 256]
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            prev_bev=prev_bev,
            shift=shift,  # BEV shift from ego motion (computed from can_bus in the full code, omitted here)
            **kwargs
        )
    return bev_embed  # [B, 20000, 256]


# BEVFormerEncoder.forward: stacks BEVFormerLayers to turn the camera features into BEV features
def forward(self,
            bev_query,
            key,
            value,
            *args,
            bev_h=None,
            bev_w=None,
            bev_pos=None,
            spatial_shapes=None,
            level_start_index=None,
            valid_ratios=None,
            prev_bev=None,
            shift=0.,
            **kwargs):
    output = bev_query  # [200*100, 2, 256]
    intermediate = []

    # reference point in 3D space, used in spatial cross-attention
    # BEVFormer design: each BEV query attends to an RoI around its reference points;
    # every query gets 4 points at different heights, and each point samples 4 feature locations.
    # Sampling only around the reference points cuts computation and speeds up convergence.
    ref_3d = self.get_reference_points(  
        bev_h, bev_w, self.pc_range[5]-self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1),  device=bev_query.device, dtype=bev_query.dtype)
    # reference point in 2D BEV space, used in temporal cross-attention
    ref_2d = self.get_reference_points(
        bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)

    reference_points_cam, bev_mask = self.point_sampling(ref_3d, self.pc_range, kwargs['img_metas'])
    # reference_points_cam [N, B, 20000, 4, 2]: image coordinates of the reference points (20000 queries x 4 pillar points)
    # bev_mask [N, B, 20000, 4]: whether each 3D point falls inside the corresponding camera image

    bev_query = bev_query.permute(1, 0, 2)
    bev_pos = bev_pos.permute(1, 0, 2)
    bs, len_bev, num_bev_level, _ = ref_2d.shape
    
    hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(bs*2, len_bev, num_bev_level, 2)

    for lid, layer in enumerate(self.layers):
        # BEVFormerLayer: temporal self-attention + spatial cross-attention
        output = layer(
            bev_query,
            key,
            value,
            *args,
            bev_pos=bev_pos,
            ref_2d=hybird_ref_2d,
            ref_3d=ref_3d,
            bev_h=bev_h,
            bev_w=bev_w,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            reference_points_cam=reference_points_cam,
            bev_mask=bev_mask,
            prev_bev=prev_bev,
            **kwargs)

        bev_query = output

    return output  # [B, 20000, 256]
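
As a rough sketch of the get_reference_points logic used above (simplified; the real BEVFormer implementation also handles dtype/device and the 2D case), every one of the 200*100 BEV queries gets num_points_in_pillar 3D points stacked along the height axis, normalized to [0, 1] so that point_sampling can project them into the cameras:

import torch

def reference_points_3d_sketch(bev_h, bev_w, z_range, num_points_in_pillar, bs):
    # Heights sampled along the pillar (cell centres), normalized by the pillar height.
    zs = torch.linspace(0.5, z_range - 0.5, num_points_in_pillar).view(-1, 1, 1)
    zs = zs.expand(num_points_in_pillar, bev_h, bev_w) / z_range
    # BEV cell centres, normalized to [0, 1].
    xs = torch.linspace(0.5, bev_w - 0.5, bev_w).view(1, 1, -1).expand(num_points_in_pillar, bev_h, bev_w) / bev_w
    ys = torch.linspace(0.5, bev_h - 0.5, bev_h).view(1, -1, 1).expand(num_points_in_pillar, bev_h, bev_w) / bev_h
    ref_3d = torch.stack((xs, ys, zs), -1).flatten(1, 2)    # [P, H*W, 3]
    return ref_3d[None].repeat(bs, 1, 1, 1)                 # [bs, P, H*W, 3]

print(reference_points_3d_sketch(200, 100, 8.0, 4, 2).shape)  # torch.Size([2, 4, 20000, 3])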

MapTRDecoder

def forward(self,
            mlvl_feats,
            bev_queries,
            object_query_embed,
            bev_h,
            bev_w,
            grid_length=[0.512, 0.512],
            bev_pos=None,
            reg_branches=None,
            cls_branches=None,
            prev_bev=None,
            **kwargs):
    bev_embed = self.get_bev_features(...)  # encoder output  [B, 20000, 256]

    bs = mlvl_feats[0].size(0)
    query_pos, query = torch.split(object_query_embed, self.embed_dims, dim=1)  # query_pos [1000, 256]
    query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
    query = query.unsqueeze(0).expand(bs, -1, -1)
    reference_points = self.reference_points(query_pos)
    reference_points = reference_points.sigmoid()
    init_reference_out = reference_points  # [B, 1000, 2]

    query = query.permute(1, 0, 2)
    query_pos = query_pos.permute(1, 0, 2)
    bev_embed = bev_embed.permute(1, 0, 2)

    inter_states, inter_references = self.decoder(
        query=query,  # [1000, B, 256]
        key=None,
        value=bev_embed,  # [20000, B, 256]
        query_pos=query_pos,  # [1000, B, 256]
        reference_points=reference_points,  # reference points of offset
        reg_branches=reg_branches,
        cls_branches=cls_branches,
        spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
        level_start_index=torch.tensor([0], device=query.device),
        **kwargs)

    inter_references_out = inter_references

    return bev_embed, inter_states, init_reference_out, inter_references_out
    # [20000, B, 256] [N, 1000, B, 256] [B, 1000, 2]    [N, B, 1000, 2]  (N is the number of decoder layers)


# MapTRDecoder.forward: cascaded decoder layers with iterative refinement of the reference points
def forward(self,
            query,
            *args,
            reference_points=None,
            reg_branches=None,
            key_padding_mask=None,
            **kwargs):
    output = query  # [1000, B, 256]
    intermediate = []
    intermediate_reference_points = []
    for lid, layer in enumerate(self.layers):
        reference_points_input = reference_points[..., :2].unsqueeze(2)  # BS NUM_QUERY NUM_LEVEL 2
        output = layer(
            output,
            *args,
            reference_points=reference_points_input,
            key_padding_mask=key_padding_mask,
            **kwargs)
        output = output.permute(1, 0, 2)  # [B, 1000, 256]

        # Update the reference points with the predicted offsets, providing refined
        # reference points for the next decoder layer in the cascade.
        if reg_branches is not None:
            tmp = reg_branches[lid](output)  # [B, 1000, 2]

            assert reference_points.shape[-1] == 2

            new_reference_points = torch.zeros_like(reference_points)
            new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points[..., :2])

            new_reference_points = new_reference_points.sigmoid()

            reference_points = new_reference_points.detach()  # [B, 1000, 2]

        output = output.permute(1, 0, 2)
        intermediate.append(output)
        intermediate_reference_points.append(reference_points)

    return torch.stack(intermediate), torch.stack(intermediate_reference_points)
    #               [N, 1000, B, 256]               [N, B, 1000, 2] 

Processing the inference results

Linear branches convert the decoder outputs into point-set form.

bev_embed, hs, init_reference, inter_references = outputs  # decoder outputs
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
outputs_pts_coords = []
for lvl in range(hs.shape[0]):  # 6
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)

    outputs_class = self.cls_branches[lvl](hs[lvl].view(bs,self.num_vec, self.num_pts_per_vec,-1).mean(2))  # linear classification branch, per-instance features averaged over points
    tmp = self.reg_branches[lvl](hs[lvl])  # linear regression branch

    assert reference.shape[-1] == 2
    tmp[..., 0:2] += reference[..., 0:2]
    tmp = tmp.sigmoid()

    outputs_coord, outputs_pts_coord = self.transform_box(tmp)
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)
    outputs_pts_coords.append(outputs_pts_coord)

outputs_classes = torch.stack(outputs_classes)
outputs_coords = torch.stack(outputs_coords)
outputs_pts_coords = torch.stack(outputs_pts_coords)
outs = {
    'bev_embed': bev_embed,  # [20000, B, 256]
    'all_cls_scores': outputs_classes,  # [N, B, 50, 3]   at most 50 instances
    'all_bbox_preds': outputs_coords,  # [N, B, 50, 4]  bounding-box form
    'all_pts_preds': outputs_pts_coords,  # [N, B, 50, 20, 2] point sets of all instances, each with a fixed number of points
    'enc_cls_scores': None,
    'enc_bbox_preds': None,
    'enc_pts_preds': None
}

return outs
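
transform_box is not shown above; roughly, it reshapes the flat regression output into per-instance point sets and derives an axis-aligned bounding box from each set. A simplified sketch (function name and details are ours, not the repo's exact implementation):

import torch

def transform_box_sketch(pts, num_vec=50, num_pts_per_vec=20):
    # Reshape [B, num_vec*num_pts, 2] predictions into per-instance point sets.
    pts = pts.view(pts.shape[0], num_vec, num_pts_per_vec, 2)   # [B, 50, 20, 2]
    xmin = pts[..., 0].min(dim=-1, keepdim=True)[0]
    ymin = pts[..., 1].min(dim=-1, keepdim=True)[0]
    xmax = pts[..., 0].max(dim=-1, keepdim=True)[0]
    ymax = pts[..., 1].max(dim=-1, keepdim=True)[0]
    bbox = torch.cat([xmin, ymin, xmax, ymax], dim=-1)          # [B, 50, 4]
    return bbox, pts

bbox, pts = transform_box_sketch(torch.rand(2, 1000, 2))
print(bbox.shape, pts.shape)   # torch.Size([2, 50, 4]) torch.Size([2, 50, 20, 2])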

Loss computation

MapTR's loss function consists of three parts: a classification loss, a point-to-point loss and a direction loss.

The classification loss uses focal loss.

The point-to-point loss uses the Manhattan distance between matched points.

The direction loss uses the cosine similarity between paired edges.

In the code the loss is actually made up of five parts (the decoder outputs six layers of results, but only the last layer's loss is taken):

| Loss | Type |
| --- | --- |
| loss_cls | Focal loss |
| loss_bbox | L1 loss |
| loss_iou | GIoU loss |
| loss_pts | Points L1 loss |
| loss_dir | Points direction cosine loss |
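
A minimal sketch of the point-to-point and direction terms, assuming the ground-truth point sets have already been reordered by the point-level matching (the actual implementation adds per-term weights and normalization, plus the focal, bbox L1 and GIoU losses listed above):

import torch
import torch.nn.functional as F

def pts_and_dir_loss_sketch(pred_pts, gt_pts):
    """pred_pts, gt_pts: [num_matched, num_pts, 2], gt_pts already in the matched point order."""
    # Point-to-point loss: Manhattan (L1) distance between matched points.
    loss_pts = torch.abs(pred_pts - gt_pts).sum(-1).mean()
    # Direction loss: cosine similarity between paired edges of prediction and ground truth.
    pred_dirs = pred_pts[:, 1:] - pred_pts[:, :-1]
    gt_dirs = gt_pts[:, 1:] - gt_pts[:, :-1]
    cos_sim = F.cosine_similarity(pred_dirs, gt_dirs, dim=-1)
    loss_dir = (1.0 - cos_sim).mean()    # aligned edges cost zero
    return loss_pts, loss_dir

print(pts_and_dir_loss_sketch(torch.rand(5, 20, 2), torch.rand(5, 20, 2)))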

References

MapTR Paper
MapTR Code

