MMCV组件Runner
Runner负责所有框架的训练过程调度。
配合各类的Hook,对外提供灵活的扩展能力。
1 2 3 4 5 6 7 8 9 10 11 12
| def val(self, data_loader, **kwargs): self.model.eval() self.mode = 'val' self.data_loader = data_loader self.call_hook('before_val_epoch') time.sleep(2) for i, data_batch in enumerate(self.data_loader): self._inner_iter = i self.call_hook('before_val_iter') self.run_iter(data_batch, train_mode=False) self.call_hook('after_val_iter') self.call_hook('after_val_epoch')
|
在测是流程中,call_hook函数来按照优先级执行hook的不同阶段(e.g. after_val_epoch)的功能。
1 2 3
| def call_hook(self, fn_name): for hook in self._hooks: getattr(hook, fn_name)(self)
|
如何修改mmdetection中的验证流程代码
在配置中默认validate变量为true
注册eval hooks,Hook类为EvalHook
在runner的流程中,在一个epoch后会执行EvalHook的after_train_epoch代码,其调用_do_evaluate
1 2 3 4 5
| def _do_evaluate(self, runner): """perform evaluation and save ckpt.""" results = self.test_fn(runner.model, self.dataloader) runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) key_score = self.evaluate(runner, results)
|
之后调用runner类的evaluate函数,调用datasets的evaluate函数,并将结果写入到runner.logger中,其是一个OrderedDict()字典结构(按照key插入顺序输出),并将ready状态设置为True
修改自己数据集的evalute函数,如imagenet数据集的evaluate函数