1.2K Star 4.1K Fork 1.6K

GVPliuruoze / EasyPR

 / 详情

在 SVM 训练时总是有异常抛出,最后发现原因是训练前需要手动创建保存 .svm 模型文件的子目录,这样比较麻烦,也不够自动化。我对代码做了一些修改,在此提供一个思路,希望作者改进。

待办的
创建于  
2015-12-04 11:09

修改文件:EasyPR/src/train/svm_train.cpp

修改函数:Svm::train

修改注释:add by sunjc for svm file output 2015-12-1

修改后的代码如下:



/**
 * Optionally train, then evaluate, the SVM that classifies samples as
 * kForward or kInverse.
 *
 * @param divide            when true, first split the forward_ and inverse_
 *                          sample sets with Svm::divide().
 * @param divide_percentage split ratio forwarded to Svm::divide().
 * @param train             when true, train a new model with
 *                          CvSVM::train_auto() and save it to
 *                          "<out_svm_path>/svm.xml"; when false (or when no
 *                          model could be trained) that file is loaded
 *                          instead.
 * @param out_svm_path      directory holding svm.xml; "resources/model" when
 *                          NULL. utils::mkdir() is called on it before saving.
 *
 * The model is then run on the samples from get_test() and the confusion
 * counts, precision, recall and F-score are printed to stdout. If no model
 * can be trained or loaded, the test step is skipped with a message.
 */
void Svm::train(bool divide /* = true */, float divide_percentage /* = 0.7 */,
                bool train /* = true */,
                const char* out_svm_path /* = NULL */) {
  if (out_svm_path == nullptr) {
    out_svm_path = "resources/model";
  }
  // std::string instead of a fixed char[128] filled by sprintf: a long
  // directory path used to overflow that buffer.
  const std::string out_svm_file = std::string(out_svm_path) + "/svm.xml";

  if (divide) {
    std::cout << "Dividing data to be trained and tested..." << std::endl;
    this->divide(forward_, divide_percentage);
    this->divide(inverse_, divide_percentage);
  }

  CvSVM svm;
  // Set once `svm` holds a usable model, either trained below or loaded.
  bool model_ready = false;

  // Training procedure (training split produced by get_train()).
  if (train) {
    this->get_train();

    if (!this->classes_.empty() && !this->trainingData_.empty()) {
      CvSVMParams SVM_params;
      SVM_params.svm_type = CvSVM::C_SVC;
      // CvSVM::LINEAR would be a linear, i.e. kernel-free, SVM.
      // CvSVM::RBF is the radial basis function, i.e. Gaussian kernel.
      SVM_params.kernel_type = CvSVM::RBF;
      SVM_params.degree = 0.1;
      SVM_params.gamma = 1;
      SVM_params.coef0 = 0.1;
      SVM_params.C = 1;
      SVM_params.nu = 0.1;
      SVM_params.p = 0.1;
      SVM_params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100000, 0.0001);

      std::cout << "Generating svm model file, please wait..." << std::endl;

      try {
        // train_auto signals failure through its bool result as well as by
        // throwing, so keep the result instead of assuming success.
        model_ready = svm.train_auto(this->trainingData_, this->classes_,
                                     cv::Mat(), cv::Mat(), SVM_params, 10,
                                     CvSVM::get_default_grid(CvSVM::C),
                                     CvSVM::get_default_grid(CvSVM::GAMMA),
                                     CvSVM::get_default_grid(CvSVM::P),
                                     CvSVM::get_default_grid(CvSVM::NU),
                                     CvSVM::get_default_grid(CvSVM::COEF),
                                     CvSVM::get_default_grid(CvSVM::DEGREE),
                                     true);
      } catch (const cv::Exception& err) {
        std::cout << err.what() << std::endl;
      }

      // Only persist a successfully trained model, so a failed run does not
      // overwrite an existing svm.xml with an empty one.
      if (model_ready) {
        utils::mkdir(out_svm_path);
        cv::FileStorage fsTo(out_svm_file, cv::FileStorage::WRITE);
        // If the file could not be opened, *fsTo would not be a valid
        // CvFileStorage*, so check before writing.
        if (fsTo.isOpened()) {
          svm.write(*fsTo, "svm");
          std::cout << "Generate done! The model file is located at "
                    << out_svm_file << std::endl;
        } else {
          std::cout << "Failed to open " << out_svm_file
                    << " for writing, the model was not saved." << std::endl;
        }
      }
    }
  }  // if train

  // No freshly trained model: fall back to the model file on disk.
  if (!model_ready) {
    try {
      svm.load(out_svm_file.c_str(), "svm");
      model_ready = true;
    } catch (const cv::Exception& err) {
      std::cout << err.what() << std::endl;
    }
  }

  // predict() needs a trained model; stop here instead of failing mid-loop.
  if (!model_ready) {
    std::cout << "No svm model available, skipping the test." << std::endl;
    return;
  }

  // Testing procedure (test split produced by get_test()).
  this->get_test();

  std::cout << "Testing..." << std::endl;

  double count_all = test_imgaes_.size();
  double ptrue_rtrue = 0;
  double ptrue_rfalse = 0;
  double pfalse_rtrue = 0;
  double pfalse_rfalse = 0;

  size_t label_index = 0;
  for (auto image : test_imgaes_) {
    // Feature: easypr::histeq() of the image, flattened to a single-row
    // matrix and converted to CV_32FC1 as the SVM input.
    auto features = easypr::histeq(image);
    features = features.reshape(1, 1);
    cv::Mat out;
    features.convertTo(out, CV_32FC1);

    Label predict =
        static_cast<int>(svm.predict(out)) == 1 ? kForward : kInverse;
    Label real = test_labels_[label_index++];

    if (predict == kForward && real == kForward) ptrue_rtrue++;
    if (predict == kForward && real == kInverse) ptrue_rfalse++;
    if (predict == kInverse && real == kForward) pfalse_rtrue++;
    if (predict == kInverse && real == kInverse) pfalse_rfalse++;
  }

  std::cout << "count_all: " << count_all << std::endl;
  std::cout << "ptrue_rtrue: " << ptrue_rtrue << std::endl;
  std::cout << "ptrue_rfalse: " << ptrue_rfalse << std::endl;
  std::cout << "pfalse_rtrue: " << pfalse_rtrue << std::endl;
  std::cout << "pfalse_rfalse: " << pfalse_rfalse << std::endl;

  // Precision: share of kForward predictions that are really kForward.
  double precise = 0;
  if (ptrue_rtrue + ptrue_rfalse != 0) {
    precise = ptrue_rtrue / (ptrue_rtrue + ptrue_rfalse);
    std::cout << "precise: " << precise << std::endl;
  } else {
    std::cout << "precise: "
              << "NA" << std::endl;
  }

  // Recall: share of real kForward samples predicted as kForward.
  double recall = 0;
  if (ptrue_rtrue + pfalse_rtrue != 0) {
    recall = ptrue_rtrue / (ptrue_rtrue + pfalse_rtrue);
    std::cout << "recall: " << recall << std::endl;
  } else {
    std::cout << "recall: "
              << "NA" << std::endl;
  }

  // F-score: harmonic mean of precision and recall.
  double Fsocre = 0;
  if (precise + recall != 0) {
    Fsocre = 2 * (precise * recall) / (precise + recall);
    std::cout << "Fsocre: " << Fsocre << std::endl;
  } else {
    std::cout << "Fsocre: "
              << "NA" << std::endl;
  }
}

}  // namespace easypr

评论 (0)

登录 后才可以发表评论

状态
负责人
里程碑
Pull Requests
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
开始日期   -   截止日期
-
置顶选项
优先级
参与者(1)
C++
1
https://gitee.com/liuruoze/EasyPR.git
git@gitee.com:liuruoze/EasyPR.git
liuruoze
EasyPR
EasyPR

搜索帮助