以下链接是个人关于mmaction2(SlowFast-动作识别) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。
动作识别0-00:mmaction2(SlowFast)-目录-史上最新无死角讲解
极度推荐的商业级项目: \color{red}{极度推荐的商业级项目:} 极度推荐的商业级项目:这是本人落地的行为分析项目,主要包含(1.行人检测,2.行人追踪,3.行为识别三大模块):行为分析(商用级别)00-目录-史上最新无死角讲解
前言我们在训练模型的时候,是执行如下指令:
python tools/train.py configs/recognition/slowfast/my_slowfast_r50_4x16x1_256e_ucf101_rgb.py --work-dir work_dirs/my_slowfast_r50_4x16x1_256e_ucf101_rgb --validate --seed 0 --deterministic
前面的博客我们对 tools/train.py 进行了分析,可以知道其会调用到项目根目录下的 mmaction/apis/train.py 中的 def train_model 函数,该函数会调用以下函数:
# 创建模型训练的类
runner = EpochBasedRunner(model,optimizer=optimizer,work_dir=cfg.work_dir,logger=logger,meta=meta)
# 获得时间戳
runner.timestamp = timestamp
# register hooks,注册训练模型的相关组件,如学习率,优化器,预训练模型等
runner.register_training_hooks(cfg.lr_config, optimizer_config,cfg.checkpoint_config, cfg.log_config,cfg.get('momentum_config', None))
# 对模型进行评估验证
if validate:
# 注册训练的钩子
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# 加载预训练模型
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
# 正式开始模型训练
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
从上面的总结,可以看到其核心都在于EpochBasedRunner这个类创建的对象runner。对于EpochBasedRunner的注释,本人如下。
EpochBasedRunner# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings
import torch
import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner train models epoch by epoch.
"""
def train(self, data_loader, **kwargs):
"""
:param data_loader:训练数据迭代器
:param kwargs:模型模型的一些相关参数
:return:
"""
# 设置模型为训练模式
self.model.train()
self.mode = 'train'
# 赋值训练数据迭代器
self.data_loader = data_loader
# 获得每个epoch最大的迭代次数
self._max_iters = self._max_epochs * len(data_loader)
self.call_hook('before_train_epoch')
# 防止转型期可能出现的僵局
time.sleep(2) # Prevent possible deadlock during epoch transition
# 循环迭代数据进行训练
for i, data_batch in enumerate(data_loader):
# 记录迭代次数
self._inner_iter = i
self.call_hook('before_train_iter')
# 如果不需要预处理,则直接训练当前batch的数据
if self.batch_processor is None:
outputs = self.model.train_step(data_batch, self.optimizer,**kwargs)
# 如果需要预处理,则先进行数据预处理
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
# 如果输出结果不是一个字典则报错
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
' must return a dict')
# 进行log打印
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
# 进行迭代后的处理,如反向传播
self.outputs = outputs
self.call_hook('after_train_iter')
self._iter += 1
# 迭代一个epoch的后处理
self.call_hook('after_train_epoch')
self._epoch += 1
def val(self, data_loader, **kwargs):
# 配置为验证模式
self.model.eval()
self.mode = 'val'
# 加载验证数据迭代器
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
# 循环迭代数据进行训练
for i, data_batch in enumerate(data_loader):
# 记录迭代次数
self._inner_iter = i
self.call_hook('before_val_iter')
# 设置不进行梯度传播
with torch.no_grad():
# 如果需要进行预处理再进行反向传播,则先进行预处理
if self.batch_processor is None:
outputs = self.model.val_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.batch_processor(
self.model, data_batch, train_mode=False, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.val_step()"'
' must return a dict')
# 打印los日志
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'],
outputs['num_samples'])
self.outputs = outputs
# 进行验证后的处理,如反向传播
self.call_hook('after_val_iter')
# 迭代一个epoch的后处理
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs, **kwargs):
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.该为一个列表,其中可以包含了训练数据迭代器,
以及验证数据迭代器。
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g, [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.该为包含了多个(phase, epochs)元组形式的数组,如
[('train', 2), ('val', 1)]表示训练两个epoch之后进行一次验证
max_epochs (int): Total training epochs.表示最大的迭代次数
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
# 对输入数据进行判断,看是否符合规范
self._max_epochs = max_epochs
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
# 进行loger信息打印
#work_dir = self.work_dir if self.work_dir is not None else 'NONE'
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
self.call_hook('before_run')
# 如果当前的self.epoch小于最大的max_epochs,则继续训练
while self.epoch
关注
打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?