以下链接是个人关于PVNet(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。
姿态估计2-00:PVNet(6D姿态估计)-目录-史上最新无死角讲解
前言通过前面的博客,我们可以知道主干网络为作者修改过的lib/networks/pvnet/resnet18.py,修改的内容在上篇博客中有具体提到,该网络结构简单,我就不做详细的介绍了,但是其中有个比较重要的地方,我们还是需要重点分析的,在lib/networks/pvnet/resnet18.py代码中,我们可以看到如下部分:
def decode_keypoint(self, output):
vertex = output['vertex'].permute(0, 2, 3, 1)
# vn_2 = 9*2 =18
b, h, w, vn_2 = vertex.shape
# 把x,y值分离出来,分别占用一个纬度
vertex = vertex.view(b, h, w, vn_2//2, 2)
# 获得前景(目标物体)对应的mask
mask = torch.argmax(output['seg'], 1)
# 如果使用了不确定性的pnp
if cfg.test.un_pnp:
# 基于RANSAC进行投票选举关键点
mean = ransac_voting_layer_v3(mask, vertex, 512, inlier_thresh=0.99)
# 获得关键点的概率分布
kpt_2d, var = estimate_voting_distribution_with_mean(mask, vertex, mean)
output.update({'mask': mask, 'kpt_2d': kpt_2d, 'var': var})
else:
kpt_2d = ransac_voting_layer_v3(mask, vertex, 128, inlier_thresh=0.99, max_num=100)
output.update({'mask': mask, 'kpt_2d': kpt_2d})
其中ransac_voting_layer_v3函数是比较重要的。
ransac_voting_layer_v3对于该函数的注释如下:
def ransac_voting_layer_v3(mask, vertex, round_hyp_num, inlier_thresh=0.999, confidence=0.99, max_iter=20,
min_num=5, max_num=30000):
'''
:param mask: [b,h,w]
:param vertex: [b,h,w,vn,2]
:param round_hyp_num:
:param inlier_thresh:
:return: [b,vn,2]
'''
b, h, w, vn, _ = vertex.shape
batch_win_pts = []
# 分别对每张图片进行处理
for bi in range(b):
#
hyp_num = 0
# 获得当前图片对应的mask
cur_mask = (mask[bi]).byte()
# 计算前景mask的和
foreground_num = torch.sum(cur_mask)
# if too few points, just skip it
# 如果其前景的像数太少,则设置win_pts为0,并且continue跳过该图像的处理
if foreground_num max_num:
# 随机选取一定数目的像素点,得到新的mask
selection = torch.zeros(cur_mask.shape, dtype=torch.float32, device=mask.device).uniform_(0, 1)
selected_mask = (selection max_iter:
break
# compute mean intersection again
normal = torch.zeros_like(direct) # [tn,vn,2]
# x,y的坐标互换
normal[:, :, 0] = direct[:, :, 1]
normal[:, :, 1] = -direct[:, :, 0]
# 0表示局外,1表示局内
all_inlier = torch.zeros([1, vn, tn], dtype=torch.uint8, device=mask.device)
all_win_pts = torch.unsqueeze(all_win_pts, 0) # [1,vn,2]
# 再一次假设交叉点(关键点)
ransac_voting.voting_for_hypothesis(direct, coords, all_win_pts, all_inlier, inlier_thresh) # [1,vn,tn]
# 后续的操作本人不是很了解,感觉all_win_pts以及获得了关键点的位置,不知道为何还要有如下操作
# 估计是为了剔除局外的投票者
# coords [tn,2] normal [vn,tn,2]
all_inlier = torch.squeeze(all_inlier.float(), 0) # [vn,tn]
# 矢量x,y的坐标互换
normal = normal.permute(1, 0, 2) # [vn,tn,2]
# 把局外投票者的方向全部清零
normal = normal*torch.unsqueeze(all_inlier, 2) # [vn,tn,2] outlier is all zero
# 注意,这里的coords表示投票者vector的坐标,normal表示其方向
b = torch.sum(normal*torch.unsqueeze(coords, 0), 2) # [vn,tn]
# 剔除局外投票者,重新进行投票,获得精确的结果
# 获得ATA矩阵,以及ATB矩阵
ATA = torch.matmul(normal.permute(0, 2, 1), normal) # [vn,2,2]
ATb = torch.sum(normal*torch.unsqueeze(b, 2), 1) # [vn,2]
# try:
# 根据ATA以及ATb矩阵求得坐标值
# [vn,2,2] * [vn,2,1] = [vn,2,1]
all_win_pts = torch.matmul(b_inv(ATA), torch.unsqueeze(ATb, 2)) # [vn,2,1]
# except:
# __import__('ipdb').set_trace()
batch_win_pts.append(all_win_pts[None,:,:, 0])
batch_win_pts = torch.cat(batch_win_pts)
return batch_win_pts
领读
对于RANSAC的理解,大家可以参考一下这篇博客: RANSAC算法理解:https://blog.csdn.net/robinhjwy/article/details/79174914 总的来说,主要步骤如下:
1.像素筛选:如果前景像素太少,则该张图像忽略,如果前景像素太多,则随机剔除部分像素
2.使用ransac_voting.generate_hypothesis获得假设,即关键点可能存在的位置。
3.随机抽取样本(direct)对假设出来的关键点位置进行投票。
4.循环执行2,3步骤,直到投票的置信度达到标准
5.使用循环迭代出来的最好模型(摒弃了局外样本),再一次去生成假设坐标,并进行投票。
总的来说,就是一直假设坐标,投票,刷新最高票数坐标,假设坐标,投票,刷新最高票数坐标…一直这样下去,知道票数达到标准才停止。