
Training MobileNetV2 (mmclassification)

Training code

https://github.com/wojiazaiyugang/mmclassification

Training environment

docker pull wojiazaiyugang/mmclassification:latest

Enter the mmclassification folder and install the package in editable mode:

pip install --no-cache-dir -e .

Dataset preparation

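The logic for loading images and labels is implemented in the dataset code itself, so the dataset does not need any particular format. In practice so far, each class has been stored in its own folder.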
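As a rough illustration of the one-folder-per-class convention, the sketch below scans a root directory and gives each subfolder an integer label in sorted name order. The helper name `scan_class_folders` and the `.jpg` filter are assumptions for illustration, not code from the repository. With the folder names used in the dataset class, `fu` and `zheng`, sorted order happens to yield labels 0 and 1.

```python
import os


def scan_class_folders(root, exts=(".jpg",)):
    """Hypothetical helper: map each subfolder of `root` to an integer label.

    Subfolders are sorted by name so labels stay stable across runs.
    Returns (class_names, samples), where samples is a list of
    (image_path, label) tuples. Files directly under `root` are ignored.
    """
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    samples = []
    for label, name in enumerate(class_names):
        folder = os.path.join(root, name)
        for file in sorted(os.listdir(folder)):
            # Keep only files whose (lower-cased) extension is accepted
            if file.lower().endswith(exts):
                samples.append((os.path.join(folder, file), label))
    return class_names, samples
```

Calling `scan_class_folders("path/to/data")` and counting samples per label is a quick way to spot an empty or badly imbalanced class folder before training.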

Configuration files

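For details on the configuration files, see the README in the code repository. The steps below show how to add a custom dataset.

1. Create the custom dataset in mmcls/datasets/goal_classification.py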

import os
import numpy as np

from .builder import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class GoalClassification(BaseDataset):
    """
    Goal classification dataset: decides whether a goal was scored.
    """

    def load_annotations(self):
        """
        Label 0 is a negative sample: no goal (images in the "fu" folder).
        Label 1 is a positive sample: goal scored (images in the "zheng" folder).
        Returns:
            list[dict]: one annotation dict per .jpg image, with the keys
            img_prefix, img_info and gt_label used by the mmcls pipeline.
        """
        annotations = []
        for folder, label in (("fu", 0), ("zheng", 1)):
            img_prefix = os.path.join(self.data_prefix, folder)
            # sorted() keeps the sample order identical across runs
            for file in sorted(os.listdir(img_prefix)):
                # Only .jpg images are used; other files are ignored
                if not file.endswith(".jpg"):
                    continue
                annotations.append(dict(img_prefix=img_prefix,
                                        img_info={'filename': file},
                                        gt_label=np.array(label, dtype=np.int64)))
        return annotations

2. Register the new dataset in mmcls/datasets/__init__.py

# Copyright (c) OpenMMLab. All rights reserved.
from .base_dataset import BaseDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cifar import CIFAR10, CIFAR100
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler
from .voc import VOC
from .goal_classification import GoalClassification
from .shoot_classification import ShootClassification

__all__ = [
    'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
    'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
    'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'GoalClassification', 'ShootClassification'
]
3. Create the config file configs/goal_classification/goal_classification.py. The base configs it inherits from are in the code repository.

_base_ = [
    '../_base_/models/goal_classification.py',
    '../_base_/datasets/goal_classification.py',
    '../_base_/schedules/imagenet_bs256_epochstep.py',
    '../_base_/default_runtime.py'
]
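The dataset base config referenced above lives in the repository and is not reproduced here. As a hedged sketch of what such a file might contain, the fragment below follows the layout of the mmcls 0.x ImageNet dataset configs; the data paths, batch size, augmentations and the top-1-only evaluation setting are placeholders and assumptions, not the author's settings. One detail worth noting: mmcls's `LoadImageFromFile` joins `img_prefix` with `img_info['filename']` from each annotation, which is why the dataset class returns those keys.

```python
# Hypothetical sketch of configs/_base_/datasets/goal_classification.py.
# Paths and hyperparameters are placeholders; check the repository for the real file.
dataset_type = 'GoalClassification'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    # Each data_prefix is assumed to contain the "fu" and "zheng" folders
    train=dict(type=dataset_type, data_prefix='data/goal/train',
               pipeline=train_pipeline),
    val=dict(type=dataset_type, data_prefix='data/goal/val',
             pipeline=test_pipeline),
    test=dict(type=dataset_type, data_prefix='data/goal/val',
              pipeline=test_pipeline))
# Assumption: with only two classes, report top-1 accuracy only
evaluation = dict(interval=1, metric='accuracy', metric_options={'topk': (1, )})
```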

Training from scratch

python tools/train.py configs/goal_classification/goal_classification.py

Resuming training

python tools/train.py configs/goal_classification/goal_classification.py --resume-from work_dirs/goal_classification/goal_classification_mmcls_mobilenetv2_20211007+20211103_top197.8_20211104.pth
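When choosing a checkpoint for `--resume-from`, the highest-numbered `epoch_N.pth` in the work directory is usually the one to pick. A minimal sketch, assuming checkpoints follow the `epoch_{N}.pth` naming seen in this document; `find_latest_epoch_checkpoint` is a hypothetical helper, not part of mmcls. It compares epoch numbers as integers, because a plain string sort would put `epoch_9.pth` after `epoch_10.pth`.

```python
import os
import re

# Matches names such as "epoch_264.pth" and captures the epoch number
_EPOCH_RE = re.compile(r"^epoch_(\d+)\.pth$")


def find_latest_epoch_checkpoint(work_dir):
    """Hypothetical helper: return the path of the highest-numbered
    epoch_N.pth in work_dir, or None if there is none."""
    best_epoch, best_path = -1, None
    for name in os.listdir(work_dir):
        match = _EPOCH_RE.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(work_dir, name)
    return best_path
```

Files that do not match the pattern, such as `latest.pth`, are ignored by design.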

Testing

python tools/test.py configs/goal_classification/goal_classification.py work_dirs/goal_classification/latest.pth --metrics accuracy
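For reference, top-1 accuracy is the percentage of samples whose highest-scoring class equals the ground-truth label; mmcls reports it as a percentage. The pure-Python sketch below only illustrates that definition and is independent of mmcls's own implementation; `top1_accuracy` is a hypothetical name.

```python
def top1_accuracy(scores, labels):
    """Percentage of samples whose argmax score matches the label.

    scores: list of per-class score lists, one per sample
    labels: list of integer ground-truth labels
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for sample_scores, label in zip(scores, labels):
        # Index of the highest score (the first one wins on ties)
        pred = max(range(len(sample_scores)), key=sample_scores.__getitem__)
        correct += int(pred == label)
    return 100.0 * correct / len(labels)
```

For example, four samples with predictions 0, 1, 0, 1 against labels 0, 1, 1, 1 give 75.0.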

Analysis and plotting

See https://github.com/wojiazaiyugang/mmclassification/blob/master/docs/zh_CN/tools/analysis.md for the analysis and plotting tools.
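Besides those tools, the JSON training log in the work directory can be scanned directly to find the best validation epoch, for example when choosing which checkpoint to export. The sketch below assumes a line-delimited `*.log.json` file in which validation records carry `"mode": "val"`, an `"epoch"` field and an `"accuracy_top-1"` key; these key names are an assumption to verify against an actual log, and `best_val_epoch` is a hypothetical helper.

```python
import json


def best_val_epoch(log_path, key="accuracy_top-1"):
    """Return (epoch, value) of the validation record with the highest `key`,
    or None if no such record exists.

    Assumes one JSON object per line; lines that are not validation records
    containing `key` (for example environment info or training losses)
    are skipped.
    """
    best = None
    with open(log_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            if record.get("mode") != "val" or key not in record:
                continue
            if best is None or record[key] > best[1]:
                best = (record["epoch"], record[key])
    return best
```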

Converting the model to ONNX

python tools/deployment/pytorch2onnx.py configs/goal_classification/goal_classification.py --simplify --checkpoint work_dirs/goal_classification/epoch_264.pth --output-file goal_classification_mmcls_mobilenetv2_20211007+20211103_top197.8_20211104.onnx
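To sanity-check the exported file, it can be loaded with onnxruntime and run on a single image. This is a sketch under several assumptions: a 224x224 RGB input resized directly (a simplification of the usual resize plus center crop), ImageNet mean/std normalization, and a single output of shape (1, num_classes). Check these against the real config. It needs `numpy`, `onnxruntime` and `Pillow`, which are imported lazily inside the hypothetical `run_onnx` function.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def run_onnx(onnx_path, image_path, size=224):
    """Run one image through the exported model; returns (label, scores).

    Assumptions: RGB input resized straight to size x size, ImageNet
    mean/std normalization, and a single output of shape (1, num_classes).
    """
    # Lazy imports so this module can be imported without these packages
    import numpy as np
    import onnxruntime as ort
    from PIL import Image

    img = Image.open(image_path).convert("RGB").resize((size, size))
    x = np.asarray(img, dtype=np.float32)
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    # Normalize per channel, then HWC -> NCHW with a batch dimension of 1
    x = np.ascontiguousarray(((x - mean) / std).transpose(2, 0, 1)[None])

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    scores = session.run(None, {input_name: x})[0][0].tolist()
    return argmax(scores), scores
```

Following the dataset's labels, a returned label of 0 means no goal and 1 means a goal was scored.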
