opencv 调用 pytorch训练的resnet模型

使用OpenCV的DNN模块调用pytorch训练的分类模型,这里记录一下中间的流程,主要分为模型训练,模型转换和OpenCV调用三步。

一、训练二分类模型

准备二分类数据,直接使用torchvision.models中的resnet18网络,主要编写的地方是自定义数据类中的__getitem__,和网络最后一层。

  • __getitem__
    将同类数据放在不同文件夹下,编写Mydataset类,在__getitem__函数中增加数据增强。
class Mydataset(Dataset):
    ......
    def __getitem__(self, idx):
        # idx-[0->len(images)]
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert(‘RGB‘),
            transforms.Resize((int(self.resize), int(self.resize))),
            # transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            # transforms.RandomRotation(15),
            # transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
             std=[0.229, 0.224, 0.225])
        ])

        img = tf(img)
        label = torch.tensor(label)
        return img, label
    ......
  • 修改网络最后一层
    依据类别,修改最后一层的输出,主要代码如下:
model = resnet18(pretrained=True)  # 比较好的 model
model = nn.Sequential(*list(model.children())[:-1],  # [b, 512, 1, 1] -> 接全连接层
  # [b, 512, 1, 1] -> [b, 512]
  torch.nn.Flatten(),
  nn.Linear(512, 2)).to(device)  # 添加全连接层

# x = torch.randn(2, 3, 224, 224)
# print(model(x).shape)
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义迭代参数的算法
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

二、Pytorch模型转为ONNX模型

直接调用torch.onnx接口可将模型导出为ONNX格式,这里主要介绍验证导出模型是否正确

import torch
from torchvision import transforms
from PIL import Image
from torchvision.models import resnet18
import torch.nn as nn
import torch.onnx
import onnx
import onnxruntime
import numpy as np

torch_model = "./resnet18-2Class.pkl"
onnx_save_path = "./resnet18-2Class.onnx"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = torch.randn(1, 3, 224, 224, dtype=torch.float, device=device)
model = resnet18(pretrained=True)
model = nn.Sequential(*list(model.children())[:-1],  # [b, 512, 1, 1] -> 接全连接层
  nn.Flatten(),  # [b, 512, 1, 1] -> [b, 512]
  nn.Linear(512, 2)).to(device)
model.load_state_dict(torch.load(torch_model))
model.eval()

print("Start convert model to onnx...")
torch.onnx.export(model,
                  data,
                  onnx_save_path,
                  opset_version=10,
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names=["input"],  # 输入名
                  output_names=["output"],  # 输出名
                  dynamic_axes={"input": {0: "batch_size"},  # 批处理变量
            "output": {0: "batch_size"}}
)

print("convert onnx is Done!")


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def get_test_transform():
    tf = transforms.Compose([
        lambda x: Image.open(x).convert(‘RGB‘),
        transforms.Resize((224, 224)),
        # transforms.CenterCrop(self.resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
         std=[0.229, 0.224, 0.225])
    ])

    return tf


img_path = "./1.png"
img = get_test_transform()(img_path)
img = img.unsqueeze(0)  # --> NCHW
print("input img mean {} and std {}".format(img.mean(), img.std()))

torch_out = model(img.to(device))
print("torch predict: ", torch_out)

# onnx
resnet_session = onnxruntime.InferenceSession(onnx_save_path)
inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
onnx_out = resnet_session.run(None, inputs)[0]
print("onnx predict: ", onnx_out)

三、OpenCV调用ONNX模型进行分类

这里主要工作是对数据进行预处理,在第一部分中的__getitem__函数的增强部分,转为openCV图像处理如下,其他直接调用dnn模块下的readNetFromONNX(modelPath)即可。

cv::Mat img = cv::imread(imgPath);
img.convertTo(img, CV_32FC3);
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
cv::resize(img, img, cv::Size(224, 224));
img = img / 255.0;
std::vector<float> mean_value{ 0.485, 0.456, 0.406 };
std::vector<float> std_value{ 0.229, 0.224, 0.225 };
cv::Mat dst;
std::vector<cv::Mat> rgbChannels(3);
cv::split(img, rgbChannels);
for (auto i = 0; i < rgbChannels.size(); i++)
{
    rgbChannels[i] = (rgbChannels[i] - mean_value[i]) / std_value[i];
}
cv::merge(rgbChannels, dst);

其中有一个注意点,就是同一张图片用torchvision.transforms中的Resize()和OpenCV的resize()函数处理的结果会有一点差别,这是因为transforms中默认使用的PIL的resize进行处理,除了默认的双线性插值,还会进行antialiasing,不过这个对于分类任务影响并不太大。

opencv 调用 pytorch训练的resnet模型

[db:回答]

以上是opencv 调用 pytorch训练的resnet模型的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>