代码拉取完成，页面将自动刷新
修改文件：EasyPR/src/train/svm_train.cpp
修改函数：Svm::train
修改注释：add by sunjc for svm file output 2015-12-1
修改后的代码如下：
/**add by sunjc for svm file output 2015-12-1*/
#define OUT_SVM_PATH_LEN 128
/**end add by sunjc for svm file output 2015-12-1*/
void Svm::train(bool divide /* = true */, float divide_percentage /* = 0.7 */,
bool train /* = true */,
const char* out_svm_path /* = NULL */) {
/**add by sunjc for svm file output 2015-12-1*/
char out_svm_file[OUT_SVM_PATH_LEN]={0};
/**end add by sunjc for svm file output 2015-12-1*/
if (out_svm_path == NULL) {
out_svm_path = "resources/model";
}
/**add by sunjc for svm file output 2015-12-1*/
sprintf(out_svm_file,"%s/svm.xml",out_svm_path);
/**end add by sunjc for svm file output 2015-12-1*/
if (divide) {
std::cout << "Dividing data to be trained and tested..." << std::endl;
this->divide(forward_, divide_percentage);
this->divide(inverse_, divide_percentage);
}
CvSVM svm;
// 70% training procedure
if (train) {
this->get_train();
if (!this->classes_.empty() && !this->trainingData_.empty()) {
// need to be trained first
CvSVMParams SVM_params;
SVM_params.svm_type = CvSVM::C_SVC;
// SVM_params.kernel_type = CvSVM::LINEAR; //CvSVM::LINEAR;
// 线型,也就是无核
SVM_params.kernel_type =
CvSVM::RBF; // CvSVM::RBF 径向基函数,也就是高斯核
SVM_params.degree = 0.1;
SVM_params.gamma = 1;
SVM_params.coef0 = 0.1;
SVM_params.C = 1;
SVM_params.nu = 0.1;
SVM_params.p = 0.1;
SVM_params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100000, 0.0001);
std::cout << "Generating svm model file, please wait..." << std::endl;
try {
// CvSVM svm(trainingData, classes, cv::Mat(), cv::Mat(), SVM_params);
svm.train_auto(this->trainingData_, this->classes_, cv::Mat(),
cv::Mat(), SVM_params, 10,
CvSVM::get_default_grid(CvSVM::C),
CvSVM::get_default_grid(CvSVM::GAMMA),
CvSVM::get_default_grid(CvSVM::P),
CvSVM::get_default_grid(CvSVM::NU),
CvSVM::get_default_grid(CvSVM::COEF),
CvSVM::get_default_grid(CvSVM::DEGREE), true);
} catch (const cv::Exception& err) {
std::cout << err.what() << std::endl;
}
utils::mkdir(out_svm_path);
//cv::FileStorage fsTo(out_svm_path, cv::FileStorage::WRITE);
/**add by sunjc for svm file output 2015-12-1*/
cv::FileStorage fsTo(out_svm_file, cv::FileStorage::WRITE);
svm.write(*fsTo, "svm");
std::cout << "Generate done! The model file is located at "
<< out_svm_file << std::endl;
/**end add by sunjc for out svm file 2015-12-1*/
}
/**delete by sunjc for out svm file 2015-12-1*/
/*else {
// don't train, use ready-made model file
try {
svm.load("resources/train/svm.xml", "svm");
} catch (const cv::Exception& err) {
std::cout << err.what() << std::endl;
}
}
*/
/**end delete by sunjc for out svm file 2015-12-1*/
} // if train
// TODO Check whether the model file exists or not.
/**add by sunjc for svm file output 2015-12-1*/
//svm.load(out_svm_path, "svm"); // make sure svm model was loaded
try {
svm.load(out_svm_file, "svm");
} catch (const cv::Exception& err) {
std::cout << err.what() << std::endl;
}
/**end add by sunjc for svm file output 2015-12-1*/
// 30% testing procedure
this->get_test();
std::cout << "Testing..." << std::endl;
double count_all = test_imgaes_.size();
double ptrue_rtrue = 0;
double ptrue_rfalse = 0;
double pfalse_rtrue = 0;
double pfalse_rfalse = 0;
size_t label_index = 0;
for (auto image : test_imgaes_) {
//调用回调函数决定特征
auto features = easypr::histeq(image);
features = features.reshape(1, 1);
cv::Mat out;
features.convertTo(out, CV_32FC1);
Label predict = ((int)svm.predict(out)) == 1 ? kForward : kInverse;
Label real = test_labels_[label_index++];
if (predict == kForward && real == kForward) ptrue_rtrue++;
if (predict == kForward && real == kInverse) ptrue_rfalse++;
if (predict == kInverse && real == kForward) pfalse_rtrue++;
if (predict == kInverse && real == kInverse) pfalse_rfalse++;
}
std::cout << "count_all: " << count_all << std::endl;
std::cout << "ptrue_rtrue: " << ptrue_rtrue << std::endl;
std::cout << "ptrue_rfalse: " << ptrue_rfalse << std::endl;
std::cout << "pfalse_rtrue: " << pfalse_rtrue << std::endl;
std::cout << "pfalse_rfalse: " << pfalse_rfalse << std::endl;
double precise = 0;
if (ptrue_rtrue + ptrue_rfalse != 0) {
precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse);
std::cout << "precise: " << precise << std::endl;
} else {
std::cout << "precise: "
<< "NA" << std::endl;
}
double recall = 0;
if (ptrue_rtrue + pfalse_rtrue != 0) {
recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue);
std::cout << "recall: " << recall << std::endl;
} else {
std::cout << "recall: "
<< "NA" << std::endl;
}
double Fsocre = 0;
if (precise + recall != 0) {
Fsocre = 2 * (precise * recall) / (precise + recall);
std::cout << "Fsocre: " << Fsocre << std::endl;
} else {
std::cout << "Fsocre: "
<< "NA" << std::endl;
}
}
} // namespace easypr