您当前的位置: 首页 > 
  • 3浏览

    0关注

    417博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

姿态估计2-08:PVNet(6D姿态估计)-源码无死角解析(4)-RANSAC投票机制

江南才尽,年少无知! 发布时间:2020-07-19 11:39:57 ,浏览量:3

以下链接是个人关于PVNet(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。

姿态估计2-00:PVNet(6D姿态估计)-目录-史上最新无死角讲解

前言

通过前面的博客,我们可以知道主干网络为作者修改过的lib/networks/pvnet/resnet18.py,修改的内容在上篇博客中有具体提到,该网络结构简单,我就不做详细的介绍了,但是其中有个比较重要的地方,我们还是需要重点分析的,在lib/networks/pvnet/resnet18.py代码中,我们可以看到如下部分:

    def decode_keypoint(self, output):
        vertex = output['vertex'].permute(0, 2, 3, 1)
        # vn_2 = 9*2 =18
        b, h, w, vn_2 = vertex.shape
        # 把x,y值分离出来,分别占用一个纬度
        vertex = vertex.view(b, h, w, vn_2//2, 2)
        # 获得前景(目标物体)对应的mask
        mask = torch.argmax(output['seg'], 1)
        # 如果使用了不确定性的pnp
        if cfg.test.un_pnp:
            # 基于RANSAC进行投票选举关键点
            mean = ransac_voting_layer_v3(mask, vertex, 512, inlier_thresh=0.99)
            # 获得关键点的概率分布
            kpt_2d, var = estimate_voting_distribution_with_mean(mask, vertex, mean)
            output.update({'mask': mask, 'kpt_2d': kpt_2d, 'var': var})
        else:
            kpt_2d = ransac_voting_layer_v3(mask, vertex, 128, inlier_thresh=0.99, max_num=100)
            output.update({'mask': mask, 'kpt_2d': kpt_2d})

其中ransac_voting_layer_v3函数是比较重要的。

ransac_voting_layer_v3

对于该函数的注释如下:

def ransac_voting_layer_v3(mask, vertex, round_hyp_num, inlier_thresh=0.999, confidence=0.99, max_iter=20,
                           min_num=5, max_num=30000):
    '''
    :param mask:      [b,h,w]
    :param vertex:    [b,h,w,vn,2]
    :param round_hyp_num:
    :param inlier_thresh:
    :return: [b,vn,2]
    '''
    b, h, w, vn, _ = vertex.shape
    batch_win_pts = []
    # 分别对每张图片进行处理
    for bi in range(b):
        #
        hyp_num = 0
        # 获得当前图片对应的mask
        cur_mask = (mask[bi]).byte()
        # 计算前景mask的和
        foreground_num = torch.sum(cur_mask)

        # if too few points, just skip it
        # 如果其前景的像数太少,则设置win_pts为0,并且continue跳过该图像的处理
        if foreground_num  max_num:
            # 随机选取一定数目的像素点,得到新的mask
            selection = torch.zeros(cur_mask.shape, dtype=torch.float32, device=mask.device).uniform_(0, 1)
            selected_mask = (selection  max_iter:
                break

        # compute mean intersection again
        normal = torch.zeros_like(direct)   # [tn,vn,2]
        # x,y的坐标互换
        normal[:, :, 0] = direct[:, :, 1]
        normal[:, :, 1] = -direct[:, :, 0]
        # 0表示局外,1表示局内
        all_inlier = torch.zeros([1, vn, tn], dtype=torch.uint8, device=mask.device)
        all_win_pts = torch.unsqueeze(all_win_pts, 0)  # [1,vn,2]

        # 再一次假设交叉点(关键点)
        ransac_voting.voting_for_hypothesis(direct, coords, all_win_pts, all_inlier, inlier_thresh)  # [1,vn,tn]
        # 后续的操作本人不是很了解,感觉all_win_pts以及获得了关键点的位置,不知道为何还要有如下操作
        # 估计是为了剔除局外的投票者

        # coords [tn,2] normal [vn,tn,2]
        all_inlier = torch.squeeze(all_inlier.float(), 0)              # [vn,tn]
        # 矢量x,y的坐标互换
        normal = normal.permute(1, 0, 2)                                # [vn,tn,2]
        # 把局外投票者的方向全部清零
        normal = normal*torch.unsqueeze(all_inlier, 2)                 # [vn,tn,2] outlier is all zero
        # 注意,这里的coords表示投票者vector的坐标,normal表示其方向
        b = torch.sum(normal*torch.unsqueeze(coords, 0), 2)             # [vn,tn]


        # 剔除局外投票者,重新进行投票,获得精确的结果
        # 获得ATA矩阵,以及ATB矩阵
        ATA = torch.matmul(normal.permute(0, 2, 1), normal)              # [vn,2,2]
        ATb = torch.sum(normal*torch.unsqueeze(b, 2), 1)                # [vn,2]
        # try:
        # 根据ATA以及ATb矩阵求得坐标值
        # [vn,2,2] * [vn,2,1] = [vn,2,1]
        all_win_pts = torch.matmul(b_inv(ATA), torch.unsqueeze(ATb, 2)) # [vn,2,1]
        # except:
        #    __import__('ipdb').set_trace()
        batch_win_pts.append(all_win_pts[None,:,:, 0])

    batch_win_pts = torch.cat(batch_win_pts)
    return batch_win_pts
领读

对于RANSAC的理解,大家可以参考一下这篇博客: RANSAC算法理解:https://blog.csdn.net/robinhjwy/article/details/79174914 总的来说,主要步骤如下:

1.像素筛选:如果前景像素太少,则该张图像忽略,如果前景像素太多,则随机剔除部分像素
2.使用ransac_voting.generate_hypothesis获得假设,即关键点可能存在的位置。
3.随机抽取样本(direct)对假设出来的关键点位置进行投票。
4.循环执行2,3步骤,直到投票的置信度达到标准
5.使用循环迭代出来的最好模型(摒弃了局外样本),再一次去生成假设坐标,并进行投票。

总的来说,就是一直假设坐标,投票,刷新最高票数坐标,假设坐标,投票,刷新最高票数坐标…一直这样下去,知道票数达到标准才停止。

在这里插入图片描述

关注
打赏
1592542134
查看更多评论
立即登录/注册

微信扫码登录

0.0426s