您当前的位置: 首页 >  架构
  • 2浏览

    0关注

    417博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

动作识别0-07:mmaction2(SlowFast)-源码无死角解析(3)-训练架构总览-2

江南才尽,年少无知! 发布时间:2020-08-05 12:51:18 ,浏览量:2

以下链接是个人关于mmaction2(SlowFast-动作识别) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。

动作识别0-00:mmaction2(SlowFast)-目录-史上最新无死角讲解

极度推荐的商业级项目: \color{red}{极度推荐的商业级项目:} 极度推荐的商业级项目:这是本人落地的行为分析项目,主要包含(1.行人检测,2.行人追踪,3.行为识别三大模块):行为分析(商用级别)00-目录-史上最新无死角讲解

前言

我们在训练模型的时候,是执行如下指令:

python tools/train.py configs/recognition/slowfast/my_slowfast_r50_4x16x1_256e_ucf101_rgb.py   --work-dir work_dirs/my_slowfast_r50_4x16x1_256e_ucf101_rgb    --validate --seed 0 --deterministic

前面的博客我们对 tools/train.py 进行了分析,可以知道其会调用到项目根目录下的 mmaction/apis/train.py 中的 def train_model 函数,该函数会调用以下函数:

    # 创建模型训练的类
    runner = EpochBasedRunner(model,optimizer=optimizer,work_dir=cfg.work_dir,logger=logger,meta=meta)
	# 获得时间戳
    runner.timestamp = timestamp

    # register hooks,注册训练模型的相关组件,如学习率,优化器,预训练模型等
    runner.register_training_hooks(cfg.lr_config, optimizer_config,cfg.checkpoint_config, cfg.log_config,cfg.get('momentum_config', None))

   # 对模型进行评估验证
   if validate:
		# 注册训练的钩子
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # 加载预训练模型
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    # 正式开始模型训练
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)

从上面的总结,可以看到其核心都在于EpochBasedRunner这个类创建的对象runner。对于EpochBasedRunner的注释,本人如下。

EpochBasedRunner
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings

import torch

import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info


class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def train(self, data_loader, **kwargs):
        """
        :param data_loader:训练数据迭代器
        :param kwargs:模型模型的一些相关参数
        :return:
        """
        # 设置模型为训练模式
        self.model.train()
        self.mode = 'train'
        # 赋值训练数据迭代器
        self.data_loader = data_loader
        # 获得每个epoch最大的迭代次数
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        # 防止转型期可能出现的僵局
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # 循环迭代数据进行训练
        for i, data_batch in enumerate(data_loader):
            # 记录迭代次数
            self._inner_iter = i
            self.call_hook('before_train_iter')

            # 如果不需要预处理,则直接训练当前batch的数据
            if self.batch_processor is None:
                outputs = self.model.train_step(data_batch, self.optimizer,**kwargs)
            # 如果需要预处理,则先进行数据预处理
            else:
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=True, **kwargs)

            # 如果输出结果不是一个字典则报错
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.train_step()"'
                                ' must return a dict')
            # 进行log打印
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            # 进行迭代后的处理,如反向传播
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1
        # 迭代一个epoch的后处理
        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        # 配置为验证模式
        self.model.eval()
        self.mode = 'val'
        # 加载验证数据迭代器
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # 循环迭代数据进行训练
        for i, data_batch in enumerate(data_loader):
            # 记录迭代次数
            self._inner_iter = i
            self.call_hook('before_val_iter')
            # 设置不进行梯度传播
            with torch.no_grad():
                # 如果需要进行预处理再进行反向传播,则先进行预处理
                if self.batch_processor is None:
                    outputs = self.model.val_step(data_batch, self.optimizer,
                                                  **kwargs)
                else:
                    outputs = self.batch_processor(
                        self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.val_step()"'
                                ' must return a dict')
            # 打印los日志
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])

            self.outputs = outputs
            # 进行验证后的处理,如反向传播
            self.call_hook('after_val_iter')
        # 迭代一个epoch的后处理
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.该为一个列表,其中可以包含了训练数据迭代器,
                以及验证数据迭代器。

            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.该为包含了多个(phase, epochs)元组形式的数组,如
                [('train', 2), ('val', 1)]表示训练两个epoch之后进行一次验证

            max_epochs (int): Total training epochs.表示最大的迭代次数
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        # 对输入数据进行判断,看是否符合规范
        self._max_epochs = max_epochs
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        # 进行loger信息打印
        #work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        # 如果当前的self.epoch小于最大的max_epochs,则继续训练
        while self.epoch             
关注
打赏
1592542134
查看更多评论
0.0420s