OpenCV机器学习(10)训练数据的一个核心类cv::ml::TrainData
- 电脑硬件
- 2025-08-25 20:57:02

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C++11 算法描述
cv::ml::TrainData 类是 OpenCV 机器学习模块中用于表示训练数据的一个核心类。它封装了样本数据、响应(标签)、样本权重等信息,并提供了多种方法来创建和操作这些数据,以适应不同的机器学习算法需求。
主要功能 数据准备:允许你从原始数据创建训练数据对象。支持多种任务:无论是分类、回归还是其他类型的任务,都可以使用 TrainData 来组织你的数据。灵活的数据输入:支持直接从矩阵输入数据,也支持加载来自文件的数据。数据分割:可以将数据集分割为训练集和测试集。 常用成员函数 创建 TrainData 对象 static Ptr create(InputArray samples, int layout, InputArray responses, InputArray varIdx=noArray(), InputArray sampleIdx=noArray(), InputArray sampleWeights=noArray(), InputArray varType=noArray()): 从给定的样本、响应和其他可选参数创建一个 TrainData 对象。 samples:样本数据矩阵,每一行代表一个样本。layout:样本布局,可以是 ROW_SAMPLE 或 COL_SAMPLE,表示每个样本是按行还是按列存储。responses:每个样本对应的响应向量或矩阵。 获取数据信息 int getNTrainSamples() const:获取训练样本的数量。int getNVars() const:获取变量(特征)的数量。Mat getSamples() const:返回所有样本。Mat getResponses() const:返回所有响应。Mat getSampleWeights() const:返回样本权重。Mat getTrainSampleWeights() const:返回训练集的样本权重。 数据分割 void setTrainTestSplit(int count, bool shuffle=true):根据指定的训练样本数量将数据集划分为训练集和测试集。void setTrainTestSplitRatio(double ratio, bool shuffle=true):根据比例将数据集划分为训练集和测试集。Mat getTrainSamples() const:返回训练集的样本。Mat getTrainResponses() const:返回训练集的响应。Mat getTestSamples() const:返回测试集的样本。Mat getTestResponses() const:返回测试集的响应。 代码示例 #include <iostream> #include <opencv2/ml.hpp> #include <opencv2/opencv.hpp> using namespace cv; using namespace cv::ml; using namespace std; int main() { // 准备训练数据 Mat samples = ( Mat_< float >( 4, 2 ) << 0.5, 1.0, 1.0, 1.5, 2.0, 0.5, 1.5, 0.0 ); Mat responses = ( Mat_< int >( 4, 1 ) << 0, 0, 1, 1 ); // 使用TrainData创建训练数据对象 Ptr< TrainData > trainData = TrainData::create( samples, ROW_SAMPLE, responses ); // 打印样本数量和变量数量 cout << "Number of training samples: " << trainData->getNTrainSamples() << endl; cout << "Number of variables: " << trainData->getNVars() << endl; // 分割数据集为训练集和测试集 trainData->setTrainTestSplitRatio( 0.75, true ); // 按75%比例分割,shuffle=true表示随机打乱 // 获取训练样本和响应 Mat trainSamples = trainData->getTrainSamples(); Mat trainResponses = trainData->getTrainResponses(); // 获取测试样本和响应 Mat testSamples = trainData->getTestSamples(); Mat testResponses = trainData->getTestResponses(); // 训练一个简单的SVM模型作为示例 Ptr< SVM > svm_model = SVM::create(); svm_model->setType( SVM::C_SVC ); svm_model->setKernel( SVM::RBF ); svm_model->setC( 1 ); svm_model->setGamma( 0.5 ); bool ok = svm_model->train( trainData ); if ( ok ) { // 对测试集中的样本进行预测 float response = svm_model->predict( testSamples ); cout << "The predicted response for the test sample is: " << response << endl; } else { cerr << "Training failed!" << endl; } return 0; } 运行结果 Number of training samples: 4 Number of variables: 2 The predicted response for the test sample is: 1OpenCV机器学习(10)训练数据的一个核心类cv::ml::TrainData由讯客互联电脑硬件栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“OpenCV机器学习(10)训练数据的一个核心类cv::ml::TrainData”