Preface
From the previous post, we know that data loading mainly involves two classes: class RawframeDataset(BaseDataset) in mmaction/datasets/rawframe_dataset.py, or class VideoDataset(BaseDataset) in mmaction/datasets/video_dataset.py. BaseDataset was already covered in detail in the previous post, so this post analyzes RawframeDataset and VideoDataset. The annotated code follows.
RawframeDataset
class RawframeDataset(BaseDataset) in mmaction/datasets/rawframe_dataset.py. A summary follows at the end of this post.
import copy
import os.path as osp
import torch
from mmcv.utils import print_log
from ..core import mean_average_precision, mean_class_accuracy, top_k_accuracy
from .base import BaseDataset
from .registry import DATASETS
# register this class into the DATASETS registry
@DATASETS.register_module()
class RawframeDataset(BaseDataset):
"""Rawframe dataset for action recognition.
数据集加载原始帧并应用指定的转换,然后返回包含帧张量和其他信息的字典
The dataset loads raw frames and apply specified transforms to return a
dict containing the frame tensors and other information.
注释文件文件存在多行,每行标识了视频帧的存放目录,视频帧的总和,
以及视频对应的标签。他们都是使用空格隔开的,
The ann_file is a text file with multiple lines, and each line indicates
the directory to frames of a video, total frames of the video and
the label of a video, which are split with a whitespace.
下面是一个注释文件的例子
Example of a annotation file:
.. code-block:: txt
some/directory-1 163 1
some/directory-2 122 1
some/directory-3 258 2
some/directory-4 234 2
some/directory-5 295 3
some/directory-6 121 3
如果一个视频存在多个标签文件,注释如下
Example of a multi-class annotation file:
.. code-block:: txt
some/directory-1 163 1 3 5
some/directory-2 122 1 2
some/directory-3 258 2
some/directory-4 234 2 4 6 8
some/directory-5 295 3
some/directory-6 121 3
    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transforms.
        data_prefix (str): Path to a directory where videos are held.
            Default: None.
        test_mode (bool): Store True when building test or validation dataset.
            Default: False.
        filename_tmpl (str): Template for each filename.
            Default: 'img_{:05}.jpg'.
        multi_class (bool): Determines whether it is a multi-class
            recognition dataset. Default: False.
        num_classes (int): Number of classes in the dataset. Default: None.
        modality (str): Modality of data. Support 'RGB', 'Flow'.
            Default: 'RGB'.
    """

    def __init__(self,
                 ann_file,                       # path to the annotation file
                 pipeline,                       # sequence of data transforms
                 data_prefix=None,               # directory where the videos are held
                 test_mode=False,                # set to True when building the test or validation dataset
                 filename_tmpl='img_{:05}.jpg',  # template for each frame filename
                 multi_class=False,              # whether this is multi-label training/testing
                 num_classes=None,               # number of classes in the dataset
                 modality='RGB'):                # data modality, 'RGB' by default
        # call the parent class initializer
        super().__init__(ann_file, pipeline, data_prefix, test_mode,
                         multi_class, num_classes, modality)
        # template for each frame filename
        self.filename_tmpl = filename_tmpl

    def load_annotations(self):
        """Load annotation file to get video information."""
        video_infos = []
        with open(self.ann_file, 'r') as fin:
            # iterate over every line of the annotation file
            for line in fin:
                # strip leading/trailing whitespace, then split on whitespace
                line_split = line.strip().split()
                # multi-label training
                if self.multi_class:
                    # self.num_classes must be set for multi-label annotations
                    assert self.num_classes is not None
                    # directory of the frames, total frame count, and the label classes
                    frame_dir, total_frames, label = (line_split[0],
                                                      line_split[1],
                                                      line_split[2:])
                    # convert each label to int
                    label = list(map(int, label))
                    # convert the labels to one-hot format
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
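                    # e.g. with num_classes=5 and label=[1, 3], onehot becomes
                    # tensor([0., 1., 0., 1., 0.])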
                # single-label training
                else:
                    # directory of the frames, total frame count, and the video label
                    frame_dir, total_frames, label = line_split
                    label = int(label)
                # prepend the data prefix to the frame directory
                if self.data_prefix is not None:
                    frame_dir = osp.join(self.data_prefix, frame_dir)
                # append the parsed annotation of this line to video_infos
                video_infos.append(
                    dict(
                        frame_dir=frame_dir,
                        total_frames=int(total_frames),
                        label=onehot if self.multi_class else label))
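                # e.g. given the annotation example above and data_prefix='data/ucf101/rawframes',
                # video_infos[0] is {'frame_dir': 'data/ucf101/rawframes/some/directory-1',
                #                    'total_frames': 163, 'label': 1}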
        return video_infos

    def prepare_train_frames(self, idx):
        """Prepare the frames for training given the index.
        Applies the transform sequence to the training sample at idx."""
        results = copy.deepcopy(self.video_infos[idx])
        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        return self.pipeline(results)

    def prepare_test_frames(self, idx):
        """Prepare the frames for testing given the index.
        Applies the transform sequence to the test sample at idx."""
        results = copy.deepcopy(self.video_infos[idx])
        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        return self.pipeline(results)

    # override the evaluation function
    def evaluate(self,
                 results,                   # inference results of the network
                 metrics='top_k_accuracy',  # metric(s) used to measure accuracy
                 topk=(1, 5),               # a prediction counts as correct if the ground truth is among the top k scores
                 logger=None):
        """Evaluation in rawframe dataset.

        Args:
            results (list): Output results.
            metrics (str | sequence[str]): Metrics to be performed.
                Defaults: 'top_k_accuracy'.
            topk (int | tuple[int]): K value for top_k_accuracy metric.
                Defaults: (1, 5).
            logger (logging.Logger | None): Logger for recording.
                Default: None.

        Returns:
            dict: Evaluation results dict.
        """
        # raise an error if results is not a list
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            f'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')
        # raise an error if topk is neither an int nor a tuple of ints
        if not isinstance(topk, (int, tuple)):
            raise TypeError(
                f'topk must be int or tuple of int, but got {type(topk)}')
        # normalize a single int to a 1-tuple
        if isinstance(topk, int):
            topk = (topk, )
        # normalize metrics to a list
        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
        allowed_metrics = [
            'top_k_accuracy', 'mean_class_accuracy', 'mean_average_precision'
        ]
        # raise an error if a requested metric is not in allowed_metrics
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')
        # dict storing the evaluation results
        eval_results = {}
        # ground-truth labels parsed from the annotation file
        gt_labels = [ann['label'] for ann in self.video_infos]
        for metric in metrics:
            # log which metric is being evaluated
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)
            # top_k_accuracy metric
            if metric == 'top_k_accuracy':
                top_k_acc = top_k_accuracy(results, gt_labels, topk)
                log_msg = []
                for k, acc in zip(topk, top_k_acc):
                    eval_results[f'top{k}_acc'] = acc
                    log_msg.append(f'\ntop{k}_acc\t{acc:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue
            # mean_class_accuracy metric
            if metric == 'mean_class_accuracy':
                mean_acc = mean_class_accuracy(results, gt_labels)
                eval_results['mean_class_accuracy'] = mean_acc
                log_msg = f'\nmean_acc\t{mean_acc:.4f}'
                print_log(log_msg, logger=logger)
                continue
            # mean_average_precision metric (for multi-label evaluation)
            if metric == 'mean_average_precision':
                gt_labels = [label.cpu().numpy() for label in gt_labels]
                mAP = mean_average_precision(results, gt_labels)
                eval_results['mean_average_precision'] = mAP
                log_msg = f'\nmean_average_precision\t{mAP:.4f}'
                print_log(log_msg, logger=logger)
                continue
        return eval_results
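To tie the @DATASETS.register_module() decorator and load_annotations together, here is a minimal sketch of building this dataset from a config dict, which is how the dataset_type = 'RawframeDataset' string in the cfg file gets resolved. The two-line annotation file is made up for illustration, and the empty pipeline is assumed to pass samples through untouched:

import tempfile
from mmaction.datasets import build_dataset

# a tiny made-up annotation file: frame_dir total_frames label
ann = tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False)
ann.write('some/directory-1 163 1\nsome/directory-2 122 1\n')
ann.close()

# 'type' is looked up in the DATASETS registry that the decorator filled
dataset = build_dataset(dict(
    type='RawframeDataset',
    ann_file=ann.name,
    pipeline=[],
    data_prefix='data/ucf101/rawframes'))
print(len(dataset))            # 2
print(dataset.video_infos[0])  # {'frame_dir': 'data/ucf101/rawframes/some/directory-1',
                               #  'total_frames': 163, 'label': 1}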
VideoDataset
class VideoDataset(BaseDataset) in mmaction/datasets/video_dataset.py:
import os.path as osp
import torch
from mmcv.utils import print_log
from ..core import mean_class_accuracy, top_k_accuracy
from .base import BaseDataset
from .registry import DATASETS
@DATASETS.register_module()
class VideoDataset(BaseDataset):
"""Video dataset for action recognition.
直接加载视频源数据,经过指定转换之后返回一个包含多个frame tensors的字典,其中还包含了一些其他的信息
The dataset loads raw videos and apply specified transforms to return a
dict containing the frame tensors and other information.
注释文件文件存在多行,每行标识了视频存放的路径,以及视频对应的标签。他们都是使用空格隔开的,
The ann_file is a text file with multiple lines, and each line indicates
a sample video with the filepath and label, which are split with a
whitespace. Example of a annotation file:
.. code-block:: txt
some/path/000.mp4 1
some/path/001.mp4 1
some/path/002.mp4 2
some/path/003.mp4 2
some/path/004.mp4 3
some/path/005.mp4 3
"""

    def load_annotations(self):
        """Load annotation file to get video information."""
        video_infos = []
        with open(self.ann_file, 'r') as fin:
            # process every line of the annotation file
            for line in fin:
                # strip leading/trailing whitespace, then split on whitespace
                line_split = line.strip().split()
                # multi-label training
                if self.multi_class:
                    # self.num_classes must be set for multi-label annotations
                    assert self.num_classes is not None
                    # video filepath and its label classes
                    filename, label = line_split[0], line_split[1:]
                    # convert each label to int, then to one-hot format
                    label = list(map(int, label))
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                # single-label training
                else:
                    # video filepath and its label
                    filename, label = line_split
                    label = int(label)
                # prepend the data prefix to the video filepath
                if self.data_prefix is not None:
                    filename = osp.join(self.data_prefix, filename)
                # append the parsed video info to video_infos
                video_infos.append(
                    dict(
                        filename=filename,
                        label=onehot if self.multi_class else label))
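                # e.g. given the annotation example above and data_prefix='data/ucf101/videos',
                # video_infos[0] is {'filename': 'data/ucf101/videos/some/path/000.mp4', 'label': 1}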
        return video_infos

    def evaluate(self,
                 results,                   # inference results of the network
                 metrics='top_k_accuracy',  # metric(s) used to measure accuracy
                 topk=(1, 5),               # a prediction counts as correct if the ground truth is among the top k scores
                 logger=None):
        """Evaluation in video dataset.

        Args:
            results (list): Output results.
            metrics (str | sequence[str]): Metrics to be performed.
                Defaults: 'top_k_accuracy'.
            topk (tuple[int]): K value for top_k_accuracy metric.
                Defaults: (1, 5).
            logger (logging.Logger | None): Logger for recording.
                Default: None.

        Returns:
            dict: Evaluation results dict.
        """
        # raise an error if results is not a list
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            f'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')
        # raise an error if topk is neither an int nor a tuple of ints
        if not isinstance(topk, (int, tuple)):
            raise TypeError(
                f'topk must be int or tuple of int, but got {type(topk)}')
        # normalize metrics to a list; raise an error if a requested metric is not allowed
        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
        allowed_metrics = ['top_k_accuracy', 'mean_class_accuracy']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')
        # dict storing the evaluation results
        eval_results = {}
        # ground-truth labels parsed from the annotation file
        gt_labels = [ann['label'] for ann in self.video_infos]
        for metric in metrics:
            # log which metric is being evaluated
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)
            # top_k_accuracy metric
            if metric == 'top_k_accuracy':
                top_k_acc = top_k_accuracy(results, gt_labels, topk)
                log_msg = []
                for k, acc in zip(topk, top_k_acc):
                    eval_results[f'top{k}_acc'] = acc
                    log_msg.append(f'\ntop{k}_acc\t{acc:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue
            # mean_class_accuracy metric
            if metric == 'mean_class_accuracy':
                mean_acc = mean_class_accuracy(results, gt_labels)
                eval_results['mean_class_accuracy'] = mean_acc
                log_msg = f'\nmean_acc\t{mean_acc:.4f}'
                print_log(log_msg, logger=logger)
                continue
        return eval_results
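Both evaluate() methods delegate the actual metric computation to helpers in mmaction.core, imported at the top of each file. Below is a minimal sketch of the top-k semantics with made-up scores for a 3-sample, 3-class toy problem; the list-of-accuracies return format of top_k_accuracy is assumed from the way it is consumed above:

import numpy as np
from mmaction.core import top_k_accuracy

# each row holds the predicted scores over the 3 classes for one sample
results = [np.array([0.1, 0.8, 0.1]),   # top-1 prediction: class 1 (correct)
           np.array([0.5, 0.3, 0.2]),   # top-1 prediction: class 0 (wrong, gt is 1)
           np.array([0.2, 0.2, 0.6])]   # top-1 prediction: class 2 (correct)
gt_labels = [1, 1, 2]

top1, top2 = top_k_accuracy(results, gt_labels, topk=(1, 2))
print(top1, top2)  # ~0.6667 1.0 -- sample 2's ground truth falls inside its top-2 scores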
Summary
No matter whether my_slowfast_r50_4x16x1_256e_ucf101_rgb.py is configured as follows (loading pre-extracted video frames):
dataset_type = 'RawframeDataset'
data_root = 'data/ucf101/rawframes'
data_root_val = 'data/ucf101/rawframes'
ann_file_train = 'data/ucf101/ucf101_train_split_1_rawframes.txt'
ann_file_val = 'data/ucf101/ucf101_val_split_1_rawframes.txt'
ann_file_test = 'data/ucf101/ucf101_val_split_1_rawframes.txt'
or as follows (loading the source videos directly):
dataset_type = 'VideoDataset'
data_root = 'data/ucf101/videos'
data_root_val = 'data/ucf101/videos'
ann_file_train = 'data/ucf101/ucf101_train_split_1_videos.txt'
ann_file_val = 'data/ucf101/ucf101_val_split_1_videos.txt'
ann_file_test = 'data/ucf101/ucf101_val_split_1_videos.txt'
train_pipeline = [
    dict(type='DecordInit'),
    # dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
    dict(type='SampleFrames', clip_len=16, frame_interval=2, num_clips=1),
    # dict(type='FrameSelector'),
    # ...
]
val_pipeline = [
    dict(type='DecordInit'),
    # dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1, test_mode=True),
    dict(type='SampleFrames', clip_len=16, frame_interval=2, num_clips=1, test_mode=True),
    # dict(type='FrameSelector'),
    # ...
]
test_pipeline = [
    dict(type='DecordInit'),
    # dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1, test_mode=True),
    dict(type='SampleFrames', clip_len=16, frame_interval=2, num_clips=1, test_mode=True),
    # dict(type='FrameSelector'),
    dict(type='DecordDecode'),
    # ...
]
the data it yields is in both cases the return value of the following method of the BaseDataset class:
    def __getitem__(self, idx):
        """Get the sample for either training or testing given index.
        Dispatches to a different transform path depending on whether
        we are in training or testing mode."""
        if self.test_mode:
            return self.prepare_test_frames(idx)
        else:
            return self.prepare_train_frames(idx)
So what exactly does it output? From my debugger screenshot: the imgs tensor has shape NCTHW = [1, 3, 16, 224, 224], and label has shape [1].
1. The 224x224 is the spatial resolution of each frame.
2. The 16 comes from clip_len=16 set in our cfg file, meaning 16 frames are sampled per clip.
3. That leaves the 3, which is the channel dimension C of the NCTHW layout, i.e. the three RGB channels of each sampled frame.
That's it. By this point you should have a very clear picture of what the data iterator actually yields.
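As a quick sanity check, here is a hedged sketch of verifying those dimensions on one fetched sample. It assumes a dataset built from the video-loading config above with a full pipeline (including the usual FormatShape/Collect steps that the excerpts omit), so that each sample is a dict holding an 'imgs' tensor:

sample = dataset[0]     # dispatched through __getitem__ as shown above
imgs = sample['imgs']   # tensor in NCTHW layout
n, c, t, h, w = imgs.shape
assert (c, t, h, w) == (3, 16, 224, 224)  # RGB channels, clip_len=16, 224x224 crops
print(sample['label'])  # the class label of this clip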