行人重识别02-05:fast-reid(BoT)-pytorch编程规范(fast-reid为例)2-DefaultTrainer解析-程序员宅基地

技术标签: # 目标追踪  Bot  ReID  pytorch  # 行人重识别  fast-reid  行人重识别  

以下链接是个人关于fast-reid(BoT行人重识别) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号 海量资源。 \color{blue}{ 海量资源}。 海量资源

行人重识别02-00:fast-reid(BoT)-目录-史上最新无死角讲解

极度推荐的商业级项目: \color{red}{极度推荐的商业级项目:} 极度推荐的商业级项目:这是本人落地的行为分析项目,主要包含(1.行人检测,2.行人追踪,3.行为识别三大模块):行为分析(商用级别)00-目录-史上最新无死角讲解

前言

通过上一篇博客。我们已经知道继承于 HookBase 的类,都存在以下几个方法:

    def before_train(self):    # 在第一次迭代之前调用
    def after_train(self):      # 在最后一次迭代之后调用
    def before_step(self):   # 在每次迭代之前调用
    def after_step(self):      # 在每次迭代之后调用

并且已经知道他是在什么时候被带哦用,同时知道了训练的大致过程。但是hooks是如何创建的,我们需要那些hooks,我们不是很清楚,接下来我们会为大家进行讲解

DefaultTrainer

在 fastreid\engine\defaults.py 文件中,我们可以看到如下源码:

class DefaultTrainer(SimpleTrainer):
    """
    具有默认训练逻辑的培训师,继承于SimpleTrainer.主要包含了以下逻辑
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:

    # 根据配置文件创建optimizer, scheduler, dataloader
    1. Create model, optimizer, scheduler, dataloader from the given config.

    # 如果指定了模型权重文件,则加载模型权重
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.

    # 注册一些通用的hooks
    3. Register a few common hooks.

    #这是一个标准的简单训练模型流程,可以减少只需要标准培训工作流程的用户的代码样板,
    这意味着这门课对你的训练逻辑做了很多假设,这些假设在新的研究中很容易变得无效
    事实上,任何超出班级:‘SimpleTrainer’太多了,不适合研究。这个类的代码已经注释了它所产生的限制性假设
    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.
    The code of this class has been annotated about restrictive assumptions it mades.
    When they do not work for you, you're encouraged to:
    # 覆盖类方法
    1. Overwrite methods of this class, OR:
    # 用法:class:`SimpleTrainer`,它只进行最小的SGD培训,而不进行其他任何操作。如果需要,可以添加自己的钩子。或者:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    #编写类似`tools/plain_train_net.py`的训练循环`.
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    # 还要注意这个类的属性,就像这个文件中的其他函数/类一样,他是不稳定的。因为它是用来表示
    “常见的默认行为”。它只能保证与fastreid中的标准模型和培训工作流一起工作。为了获得更稳定的行为,
    可以使用其他公共api编写自己的训练逻辑。
    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in fastreid.
    To obtain more stable behavior, write your own training logic with other public APIs.
    Attributes:
        scheduler: # 学习策略
        checkpointer (DetectionCheckpointer): # 模型参数检测加载
        cfg (CfgNode):cfg配置文件
    Examples:
    .. code-block:: python
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        # 创建记录答应日志的类对象
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        # 创建训练数据及迭代器
        data_loader = self.build_train_loader(cfg)
        # 自动计算一些配置参数,如共迭代多少次
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        # 根据配置参数构建模型
        model = self.build_model(cfg)
        # 根据配置构建优化器
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap with DDP. But don't need this for inference.
        # 对于培训,用DDP包装。但不需要这个来推断。
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )

        super().__init__(model, data_loader, optimizer)

        # 设置学习率衰减策略
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        # 加载指定的模型参数
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        # 初始化迭代次数
        self.start_iter = 0
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        # 创建hooks,并且注册hooks
        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        如果resume==True表示接着之前的迭代次数训练,否则从0开始训练
        If `resume==True`, and last checkpoint exists, resume from it.
        Otherwise, load a model specified by the config.
        Args:
            resume (bool): whether to do resume or not
        """
		......
    def build_hooks(self):
        """
        构建一个默认的hooks列表,包含了timing,checkpointing, lr scheduling, precise BN, writing events
        可以理解为把这些类,或者函数放入到一个容器中,需要他的时候再把他取出来进行调用
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
		......

        return ret

    def build_writers(self):
        """
        主要用于写入log日志等等
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        # 调用父类的训练函数
        super().train(self.start_iter, self.max_iter)
        # 等待训练完成之后进行,返回最后一次的评估结果
        if comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            # verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        根据配置信息cfg构建模型
        Returns:
            torch.nn.Module:
        It now calls :func:`fastreid.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        # logger = logging.getLogger(__name__)
        # logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        根据配置参数构建优化器
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        根据配置参数指定学习率衰减策略
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        构建一个训练数据迭代器
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger(__name__)
        logger.info("Prepare training set")
        return build_reid_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        构建一个测试数据迭代器
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, num_query, output_dir=None):
        """
        构建评估器
        """
        return ReidEvaluator(cfg, num_query, output_dir)

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        对模型进行评估
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.
        Returns:
            dict: a dict of result metrics
        """
        # 用于log日志的保存
        logger = logging.getLogger(__name__)
        # 检测是evaluators是否为正确的评估器
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]

        # 如果evaluators不为none,则对evaluators的长度进行检测
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )
        # 创建一个字典,用于结果保存
        results = OrderedDict()
        
        # 对多个数据集进行评估 
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            # 进行log打印,并且创建评估数据迭代器
            logger.info("Prepare testing set")
            data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
            
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, num_query)
                except NotImplementedError:
                    logger.warn(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {
    }
                    continue
            # 对单个评估数据集进行推断,并且获得推断结果
            results_i = inference_on_dataset(model, data_loader, evaluator)
            # 保存数据集对应的推断结果
            results[dataset_name] = results_i

        # 如果为主进程,则返回一个评估之后的字典
        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results
            )
            # 使用csv的格式打印评估结果
            print_csv_format(results)

        if len(results) == 1: results = list(results.values())[0]

        return results

    @staticmethod
    def auto_scale_hyperparams(cfg, data_loader):
        r"""
        根据传入的cfg,推算出一些cfg配置参数,如总迭代次数等等
        This is used for auto-computation actual training iterations,
        because some hyper-param, such as MAX_ITER, means training epochs rather than iters,
        so we need to convert specific hyper-param to training iterations.
        """

        cfg = cfg.clone()
		......
        return cfg

代码领读

从上面代码的注释中,我们可以看到如下:

        # Assume these objects must be constructed in this order.
        # 创建训练数据及迭代器
        data_loader = self.build_train_loader(cfg)
        # 自动计算一些配置参数,如共迭代多少次
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        # 根据配置参数构建模型
        model = self.build_model(cfg)
        # 根据配置构建优化器
        optimizer = self.build_optimizer(cfg, model)

这里就是创建数据迭代器,优化器,以及模型的过程。对于 def build_hooks(self) 函数,其会去构建所有的 hooks,如 timing,checkpointing, lr scheduling, precise BN, writing events 等等。这些都是hooks,继承于HookBase。

在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_43013761/article/details/108051799

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签