博客
关于我
opencv SVM分类Demo
阅读量:791 次
发布时间:2023-02-23

本文共 3788 字,大约阅读时间需要 12 分钟。

使用SVM进行图像分类训练与预测

项目背景

本项目利用支持向量机(SVM)进行图像分类任务,旨在训练一个能够区分不同类别的模型,并对新数据进行分类预测。以下是实现过程及相关代码解释。

1. 训练阶段(train.cpp)

1.1 代码结构

#include 
#include
#include
#include
#include
#include
#define _CRT_SECURE_NO_WARNINGSusing namespace std;using namespace cv;void getFiles(string path, vector
& files) { // 该函数用于获取指定路径下的所有文件名,支持子目录递归 // 详细实现见文档}void get_1(Mat& trainingImages, vector
& trainingLabels) { // 获取类别1的训练数据 // 1. 定义文件路径 char * filePath = "data\\train_image\\1"; // 2. 获取所有图片文件 vector
files; getFiles(filePath, files); // 3. 遍历每一张图片 int number = files.size(); for (int i = 0; i < number; i++) { // 读取图片并调整大小 Mat SrcImage = imread(files[i].c_str()); resize(SrcImage, SrcImage, cv::Size(60, 256), (0, 0), (0, 0), cv::INTER_LINEAR); // 调整为单通道矩阵 SrcImage = SrcImage.reshape(1, 1); // 添加到训练数据中 trainingImages.push_back(SrcImage); trainingLabels.push_back(1); }}void get_0(Mat& trainingImages, vector
& trainingLabels) { // 获取类别0的训练数据 // 代码逻辑与get_1类似 char * filePath = "data\\train_image\\0"; vector
files; getFiles(filePath, files); int number = files.size(); for (int i = 0; i < number; i++) { Mat SrcImage = imread(files[i].c_str()); resize(SrcImage, SrcImage, cv::Size(60, 256), (0, 0), (0, 0), cv::INTER_LINEAR); SrcImage = SrcImage.reshape(1, 1); trainingImages.push_back(SrcImage); trainingLabels.push_back(0); }}int main() { // 1. 定义训练数据矩阵 Mat classes; Mat trainingData; // 2. 获取训练数据 vector
trainingLabels; get_1(trainingImages, trainingLabels); get_0(trainingImages, trainingLabels); // 3. 将训练数据转换为OpenCV矩阵格式 trainingData = trainingImages; trainingData.convertTo(CV_32FC1); // 4. 拼接标签数据 classes = trainingLabels; // 5. 配置SVM参数 CvSVMParams SVM_params; SVM_params.svm_type = CvSVM::C_SVC; SVM_params.kernel_type = CvSVM::LINEAR; SVM_params.degree = 0; SVM_params.gamma = 1; SVM_params.coef0 = 0; SVM_params.C = 1; SVM_params.nu = 0; SVM_params.p = 0; SVM_params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01); // 6. 训练SVM模型 CvSVM svm; svm.train(trainingData, classes, Mat(), Mat(), SVM_params); // 7. 保存训练结果 svm.save("svm.xml"); cout << "训练好了!!!" << endl; getchar(); return 0;}

1.2 代码解释

  • getFiles 函数:用于递归获取指定路径下的所有文件名,支持子目录遍历。
  • get_1get_0 函数:分别用于获取类别1和类别0的训练数据,读取图片文件并调整大小后添加到训练数据集中。
  • main 函数:负责整个训练流程的配置与执行,包括数据准备、模型训练及结果保存。

2. 预测阶段(test.cpp)

2.1 代码结构

#include 
#include
#include
#include
#include
#include
#include
#define _CRT_SECURE_NO_WARNINGSusing namespace std;using namespace cv;void getFiles(string path, vector
& files) { // 该函数与训练阶段中的getFiles函数一致 // 详细实现见文档}int main() { // 1. 定义文件路径 char * filePath = "data\\test_image\\0"; vector
files; getFiles(filePath, files); int number = files.size(); // 2. 初始化SVM模型 CvSVM svm; svm.clear(); string modelpath = "svm.xml"; // 3. 加载预训练模型 FileStorage svm_fs(modelpath, FileStorage::READ); if (!svm_fs.isOpened()) { cout << "导入模型错误" << endl; return 0; } svm.load(modelpath.c_str()); // 4. 遍历测试图片 for (int i = 0; i < number; i++) { Mat inMat = imread(files[i].c_str()); resize(inMat, inMat, cv::Size(60, 256), (0, 0), (0, 0), cv::INTER_LINEAR); Mat p = inMat.reshape(1, 1); p.convertTo(p, CV_32FC1); int response = (int)svm.predict(p); if (response == 0) { result++; } } // 5. 输出预测结果 cout << "识别个数:" << result << endl; cout << "识别率:"; cout << fixed << setprecision(2) << (double)result / (double)number << endl; getchar(); return 0;}

2.2 代码解释

  • getFiles 函数:与训练阶段中的函数一致,用于获取指定路径下的所有文件名。
  • main 函数:负责整个预测流程的执行,包括模型加载、文件遍历及结果输出。
  • 输出结果:显示测试图片的识别个数及识别率,方便用户验证模型性能。

3. 注意事项

  • 图片尺寸一致性:在训练与预测阶段,图片大小必须保持一致,否则会导致OpenCV函数错误。
  • 文件路径调整:确保训练与测试数据集的文件路径正确,避免路径错误。
  • 模型优化:根据实际任务需求,调整SVM的超参数(如Cgamma等)以优化分类性能。

4. 总结

本项目通过OpenCV+SVM实现了图像分类的基本功能,代码结构清晰,易于维护与扩展。训练阶段完成模型训练,预测阶段实现了对新数据的分类识别。

转载地址:http://flsfk.baihongyu.com/

你可能感兴趣的文章
onScrollStateChanged无效
查看>>
onTouchEvent构造器
查看>>
on_member_join 和删除不起作用.如何让它发挥作用?
查看>>
oobbs开发手记
查看>>
OOM怎么办,教你生成dump文件以及查看(IT枫斗者)
查看>>
OOP
查看>>
OOP之单例模式
查看>>
OOP向AOP思想的延伸
查看>>
Vue element 动态添加表单验证
查看>>
OO第一次blog
查看>>
OO第四单元总结
查看>>
OO第四次博客作业
查看>>
OO面向对象编程:第三单元总结
查看>>
Opacity多浏览器透明度兼容处理
查看>>
OPC在工控上位机中的应用
查看>>
VSCode在终端中使用yarn命令
查看>>
OPEN CASCADE Curve Continuity
查看>>
Open Graph Protocol(开放内容协议)
查看>>
Open vSwitch实验常用命令
查看>>
Open WebUI 忘了登入密码怎么办?
查看>>