训练 MobileNetV2(mmclassification)
训练代码
https://github.com/wojiazaiyugang/mmclassification
训练环境
拉取镜像:docker pull wojiazaiyugang/mmclassification:latest
进入 mmclassification 文件夹后,以可编辑模式安装:pip install --no-cache-dir -e .
数据集准备
数据和标签的加载逻辑由代码自行实现,因此数据集没有特殊的格式要求。以往的做法是每一类数据放在一个单独的文件夹中,例如下文的进球判定数据集:负样本放在 fu 文件夹,正样本放在 zheng 文件夹。
配置文件
配置文件的说明详见代码仓库中的 README,下面介绍如何新建一个自定义数据集。

1. 自定义数据集:mmcls/datasets/goal_classification.py
import os
import numpy as np
from .builder import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class GoalClassification(BaseDataset):
    """Goal-judgement (did the ball go in?) binary classification dataset.

    ``data_prefix`` is expected to contain two sub-directories of ``.jpg``
    images:

    * ``fu``    -- negative samples, label 0 (no goal)
    * ``zheng`` -- positive samples, label 1 (goal scored)
    """

    # (sub-directory name, integer label) pairs scanned by load_annotations,
    # in the order their samples are appended.
    _SUBDIR_LABELS = (("fu", 0), ("zheng", 1))

    def load_annotations(self):
        """Build the annotation list by scanning the labelled sub-directories.

        Label 0 is the negative class (no goal); label 1 is the positive
        class (goal scored). Files not ending in ``.jpg`` are skipped.

        Returns:
            list[dict]: one dict per image with keys ``img_prefix`` (the
            sub-directory path), ``img_info`` (``{'filename': <file name>}``)
            and ``gt_label`` (a 0-d ``np.int64`` array holding the label).
        """
        annotations = []
        for subdir, label in self._SUBDIR_LABELS:
            img_prefix = os.path.join(self.data_prefix, subdir)
            # os.listdir returns entries in arbitrary order; sort so the
            # dataset order (and anything indexed by it) is reproducible.
            for file in sorted(os.listdir(img_prefix)):
                if not file.endswith(".jpg"):
                    continue
                annotations.append(dict(img_prefix=img_prefix,
                                        img_info={'filename': file},
                                        gt_label=np.array(label, dtype=np.int64)))
        return annotations
在 mmcls/datasets/__init__.py 中导入并注册新增的数据集:
# Copyright (c) OpenMMLab. All rights reserved.
from .base_dataset import BaseDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cifar import CIFAR10, CIFAR100
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler
from .voc import VOC
from .goal_classification import GoalClassification
from .shoot_classification import ShootClassification
# Names exported by ``from mmcls.datasets import *``. ImageNet21k is imported
# above, so it is listed here too; the project-specific datasets come last.
__all__ = [
    'BaseDataset',
    'ImageNet',
    'CIFAR10',
    'CIFAR100',
    'MNIST',
    'FashionMNIST',
    'VOC',
    'MultiLabelDataset',
    'build_dataloader',
    'build_dataset',
    'DistributedSampler',
    'ConcatDataset',
    'RepeatDataset',
    'ClassBalancedDataset',
    'DATASETS',
    'PIPELINES',
    'ImageNet21k',
    'GoalClassification',
    'ShootClassification',
]
新建配置文件 configs/goal_classification/goal_classification.py。它通过 _base_ 继承的基础配置文件内容见代码仓库:
# Base configs this file inherits from (model, dataset, schedule and runtime,
# judging by their directory names). mmcv resolves each path relative to this
# config file's directory; keep it a list, which is what Config expects.
_base_ = [
    "../_base_/models/goal_classification.py",
    "../_base_/datasets/goal_classification.py",
    "../_base_/schedules/imagenet_bs256_epochstep.py",
    "../_base_/default_runtime.py",
]
从头训练
python tools/train.py configs/goal_classification/goal_classification.py
继续训练
--resume-from 会同时恢复模型权重、优化器状态和已训练的 epoch 数。如果只想加载权重后重新开始训练(微调),应改为在配置文件中设置 load_from。
python tools/train.py configs/goal_classification/goal_classification.py --resume-from work_dirs/goal_classification/goal_classification_mmcls_mobilenetv2_20211007+20211103_top197.8_20211104.pth
测试
python tools/test.py configs/goal_classification/goal_classification.py work_dirs/goal_classification/latest.pth --metrics accuracy
分析及绘图
参考 https://github.com/wojiazaiyugang/mmclassification/blob/master/docs/zh_CN/tools/analysis.md
模型转 ONNX
python tools/deployment/pytorch2onnx.py configs/goal_classification/goal_classification.py --simplify --checkpoint work_dirs/goal_classification/epoch_264.pth --output-file goal_classification_mmcls_mobilenetv2_20211007+20211103_top197.8_20211104.onnx