
Commit

added support for compensating unbalanced training data by adding class weights to LibSvmClassifier again (this time it is optional)

added factory functions for creating binary or one-class SVMs to LibSvmClassifier
Peter Poschmann committed Nov 13, 2015
1 parent be4db09 commit ee15900
Showing 7 changed files with 95 additions and 35 deletions.
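
For orientation, a minimal usage sketch of the two new factory functions follows. It is based only on the signatures added in libSvm/include/libsvm/LibSvmClassifier.hpp below; the helper function name, the include path, and the parameter values (C = 10, nu = 0.05) are illustrative and not part of the commit.

#include <memory>
#include "libsvm/LibSvmClassifier.hpp"

using classification::Kernel;
using libsvm::LibSvmClassifier;

// Builds both classifier flavours from a kernel supplied by the caller
// (any classification::Kernel implementation).
void buildClassifiers(std::shared_ptr<Kernel> kernel) {
    // Binary SVM with C = 10; the third argument enables the new optional
    // class-weight compensation for unbalanced positive/negative counts.
    std::shared_ptr<LibSvmClassifier> binarySvm =
            LibSvmClassifier::createBinarySvm(kernel, 10, true);

    // One-class SVM with nu = 0.05; class-weight compensation does not apply
    // here (the constructor rejects that combination with invalid_argument).
    std::shared_ptr<LibSvmClassifier> oneClassSvm =
            LibSvmClassifier::createOneClassSvm(kernel, 0.05);
}
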
4 changes: 2 additions & 2 deletions adaptiveTrackingApp/AdaptiveTracking.cpp
@@ -325,7 +325,7 @@ unique_ptr<ExampleManagement> AdaptiveTracking::createExampleManagement(ptree& c

shared_ptr<TrainableSvmClassifier> AdaptiveTracking::createLibSvmClassifier(ptree& config, shared_ptr<Kernel> kernel) {
if (config.get_value<string>() == "binary") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("C"));
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createBinarySvm(kernel, config.get<double>("C"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
trainableSvm->setNegativeExampleManagement(
@@ -337,7 +337,7 @@ shared_ptr<TrainableSvmClassifier> AdaptiveTracking::createLibSvmClassifier(ptre
}
return trainableSvm;
} else if (config.get_value<string>() == "one-class") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("nu"), true);
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createOneClassSvm(kernel, config.get<double>("nu"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
return trainableSvm;
4 changes: 2 additions & 2 deletions benchmarkApp/BenchmarkRunner.cpp
@@ -341,7 +341,7 @@ unique_ptr<ExampleManagement> createExampleManagement(ptree& config, shared_ptr<

shared_ptr<TrainableSvmClassifier> createLibSvmClassifier(ptree& config, shared_ptr<Kernel> kernel) {
if (config.get_value<string>() == "binary") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("C"));
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createBinarySvm(kernel, config.get<double>("C"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
trainableSvm->setNegativeExampleManagement(
@@ -353,7 +353,7 @@ shared_ptr<TrainableSvmClassifier> createLibSvmClassifier(ptree& config, shared_
}
return trainableSvm;
} else if (config.get_value<string>() == "one-class") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("nu"), true);
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createOneClassSvm(kernel, config.get<double>("nu"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
return trainableSvm;
4 changes: 2 additions & 2 deletions headTrackingApp/HeadTracking.cpp
@@ -344,7 +344,7 @@ unique_ptr<ExampleManagement> HeadTracking::createExampleManagement(ptree& confi

shared_ptr<TrainableSvmClassifier> HeadTracking::createLibSvmClassifier(ptree& config, shared_ptr<Kernel> kernel) {
if (config.get_value<string>() == "binary") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("C"));
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createBinarySvm(kernel, config.get<double>("C"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
trainableSvm->setNegativeExampleManagement(
@@ -356,7 +356,7 @@ shared_ptr<TrainableSvmClassifier> HeadTracking::createLibSvmClassifier(ptree& c
}
return trainableSvm;
} else if (config.get_value<string>() == "one-class") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("nu"), true);
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createOneClassSvm(kernel, config.get<double>("nu"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
return trainableSvm;
60 changes: 53 additions & 7 deletions libSvm/include/libsvm/LibSvmClassifier.hpp
@@ -25,24 +25,70 @@ class LibSvmClassifier : public classification::TrainableSvmClassifier {
public:

/**
* Constructs a new libSVM classifier.
* Creates a new one-class support vector machine that will be trained with libSVM.
*
* @param[in] kernel The kernel function.
* @param[in] nu The nu.
*/
static std::shared_ptr<LibSvmClassifier> createOneClassSvm(std::shared_ptr<classification::Kernel> kernel, double nu = 1) {
return std::make_shared<LibSvmClassifier>(kernel, nu, true, false);
}

/**
* Creates a new one-class support vector machine that will be trained with libSVM.
*
* @param[in] svm The actual SVM.
* @param[in] nu The nu.
*/
static std::shared_ptr<LibSvmClassifier> createOneClassSvm(std::shared_ptr<classification::SvmClassifier> svm, double nu = 1) {
return std::make_shared<LibSvmClassifier>(svm, nu, true, false);
}

/**
* Creates a new binary support vector machine that will be trained with libSVM.
*
* @param[in] kernel The kernel function.
* @param[in] c The C.
* @param[in] compensateImbalance Flag that indicates whether to adjust class weights to compensate for unbalanced data.
*/
static std::shared_ptr<LibSvmClassifier> createBinarySvm(
std::shared_ptr<classification::Kernel> kernel, double c = 1, bool compensateImbalance = false) {
return std::make_shared<LibSvmClassifier>(kernel, c, false, compensateImbalance);
}

/**
* Creates a new binary support vector machine that will be trained with libSVM.
*
* @param[in] svm The actual SVM.
* @param[in] c The C.
* @param[in] compensateImbalance Flag that indicates whether to adjust class weights to compensate for unbalanced data.
*/
static std::shared_ptr<LibSvmClassifier> createBinarySvm(
std::shared_ptr<classification::SvmClassifier> svm, double c = 1, bool compensateImbalance = false) {
return std::make_shared<LibSvmClassifier>(svm, c, false, compensateImbalance);
}

/**
* Constructs a new libSVM classifier.
*
* @param[in] kernel The kernel function.
* @param[in] cnu The parameter C in case of an ordinary SVM, nu in case of a one-class SVM.
* @param[in] oneClass Flag that indicates whether a one-class SVM should be trained.
* @param[in] compensateImbalance Flag that indicates whether to adjust class weights to compensate for unbalanced data.
*/
explicit LibSvmClassifier(
std::shared_ptr<classification::SvmClassifier> svm, double cnu = 1, bool oneClass = false);
LibSvmClassifier(std::shared_ptr<classification::Kernel> kernel, double cnu, bool oneClass, bool compensateImbalance = false);

/**
* Constructs a new libSVM classifier.
*
* @param[in] kernel The kernel function.
* @param[in] svm The actual SVM.
* @param[in] cnu The parameter C in case of an ordinary SVM, nu in case of a one-class SVM.
* @param[in] oneClass Flag that indicates whether a one-class SVM should be trained.
* @param[in] compensateImbalance Flag that indicates whether to adjust class weights to compensate for unbalanced data.
*/
explicit LibSvmClassifier(
std::shared_ptr<classification::Kernel> kernel, double cnu = 1, bool oneClass = false);
LibSvmClassifier(std::shared_ptr<classification::SvmClassifier> svm, double cnu, bool oneClass, bool compensateImbalance = false);

public:

/**
* Loads static negative training examples from a file.
@@ -110,7 +156,7 @@ class LibSvmClassifier : public classification::TrainableSvmClassifier {
const std::vector<std::unique_ptr<struct svm_node[], NodeDeleter>>& negativeExamples,
const std::vector<std::unique_ptr<struct svm_node[], NodeDeleter>>& staticNegativeExamples);

bool oneClass; ///< Flag that indicates whether a one-class SVM should be trained.
bool compensateImbalance; ///< Flag that indicates whether to adjust class weights to compensate for unbalanced data.
LibSvmUtils utils; ///< Utils for using libSVM.
std::unique_ptr<struct svm_parameter, ParameterDeleter> param; ///< The libSVM parameters.
std::unique_ptr<classification::ExampleManagement> positiveExamples; ///< Storage of positive training examples.
50 changes: 32 additions & 18 deletions libSvm/src/libsvm/LibSvmClassifier.cpp
@@ -30,24 +30,22 @@ using std::invalid_argument;

namespace libsvm {

LibSvmClassifier::LibSvmClassifier(shared_ptr<SvmClassifier> svm, double cnu, bool oneClass) :
TrainableSvmClassifier(svm), oneClass(oneClass), utils(), param(),
positiveExamples(new UnlimitedExampleManagement()), negativeExamples(), staticNegativeExamples() {
if (oneClass)
negativeExamples.reset(new EmptyExampleManagement());
else
negativeExamples.reset(new UnlimitedExampleManagement());
createParameters(svm->getKernel(), cnu, oneClass);
}
LibSvmClassifier::LibSvmClassifier(shared_ptr<Kernel> kernel, double cnu, bool oneClass, bool compensateImbalance) :
LibSvmClassifier(make_shared<SvmClassifier>(kernel), cnu, oneClass, compensateImbalance) {}

LibSvmClassifier::LibSvmClassifier(shared_ptr<Kernel> kernel, double cnu, bool oneClass) :
TrainableSvmClassifier(kernel), oneClass(oneClass), utils(), param(),
positiveExamples(new UnlimitedExampleManagement()), negativeExamples(), staticNegativeExamples() {
LibSvmClassifier::LibSvmClassifier(shared_ptr<SvmClassifier> svm, double cnu, bool oneClass, bool compensateImbalance) :
TrainableSvmClassifier(svm),
compensateImbalance(compensateImbalance),
utils(),
param(),
positiveExamples(new UnlimitedExampleManagement()),
negativeExamples(new UnlimitedExampleManagement()),
staticNegativeExamples() {
if (oneClass && compensateImbalance)
throw invalid_argument("LibSvmClassifier: a one-class SVM cannot have unbalanced data it needs to compensate for");
if (oneClass)
negativeExamples.reset(new EmptyExampleManagement());
else
negativeExamples.reset(new UnlimitedExampleManagement());
createParameters(kernel, cnu, oneClass);
createParameters(svm->getKernel(), cnu, oneClass);
}

void LibSvmClassifier::createParameters(const shared_ptr<Kernel> kernel, double cnu, bool oneClass) {
@@ -61,9 +59,19 @@ void LibSvmClassifier::createParameters(const shared_ptr<Kernel> kernel, double
param->C = cnu;
param->svm_type = C_SVC;
}
param->nr_weight = 0;
param->weight_label = nullptr;
param->weight = nullptr;
if (compensateImbalance) {
param->nr_weight = 2;
param->weight_label = (int*)malloc(param->nr_weight * sizeof(int));
param->weight_label[0] = +1;
param->weight_label[1] = -1;
param->weight = (double*)malloc(param->nr_weight * sizeof(double));
param->weight[0] = 1;
param->weight[1] = 1;
} else {
param->nr_weight = 0;
param->weight_label = nullptr;
param->weight = nullptr;
}
param->shrinking = 0;
param->probability = 0;
param->gamma = 0; // necessary for kernels that do not use this parameter
@@ -116,6 +124,12 @@ bool LibSvmClassifier::retrain(const vector<Mat>& newPositiveExamples, const vec
bool LibSvmClassifier::train() {
vector<unique_ptr<struct svm_node[], NodeDeleter>> positiveExamples = move(createNodes(this->positiveExamples.get()));
vector<unique_ptr<struct svm_node[], NodeDeleter>> negativeExamples = move(createNodes(this->negativeExamples.get()));
if (compensateImbalance) {
double positiveCount = positiveExamples.size();
double negativeCount = negativeExamples.size() + staticNegativeExamples.size();
param->weight[0] = negativeCount / positiveCount;
param->weight[1] = positiveCount / negativeCount;
}
unique_ptr<struct svm_problem, ProblemDeleter> problem = move(createProblem(
positiveExamples, negativeExamples, staticNegativeExamples));
const char* message = svm_check_parameter(problem.get(), param.get());
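
To make the compensation scheme concrete, here is a small self-contained illustration of the weight update that train() now performs before each libSVM run; the example counts (20 positives, 80 negatives) are invented for this illustration. libSVM scales its penalty parameter C by the per-label weight, so with these counts errors on the rarer positive class are penalised with 4·C while errors on negatives use 0.25·C.

#include <cstdio>

// Worked illustration of the class-weight update done in
// LibSvmClassifier::train(); the counts are made-up example values.
int main() {
    double positiveCount = 20;                              // positive examples
    double negativeCount = 80;                              // negative + static negative examples
    double positiveWeight = negativeCount / positiveCount;  // weight for label +1 -> 4.0
    double negativeWeight = positiveCount / negativeCount;  // weight for label -1 -> 0.25
    std::printf("C(+1) = %.2f * C, C(-1) = %.2f * C\n", positiveWeight, negativeWeight);
    return 0;
}
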
4 changes: 2 additions & 2 deletions partiallyAdaptiveTrackingApp/PartiallyAdaptiveTracking.cpp
@@ -91,7 +91,7 @@ shared_ptr<Kernel> PartiallyAdaptiveTracking::createKernel(ptree config) {
shared_ptr<TrainableSvmClassifier> PartiallyAdaptiveTracking::createTrainableSvm(shared_ptr<Kernel> kernel, ptree config) {
shared_ptr<LibSvmClassifier> trainableSvm;
if (config.get_value<string>() == "fixedsize") {
trainableSvm = make_shared<LibSvmClassifier>(
trainableSvm = LibSvmClassifier::createBinarySvm(
kernel, config.get<double>("constraintsViolationCosts"));
trainableSvm->setPositiveExampleManagement(unique_ptr<ExampleManagement>(
new ConfidenceBasedExampleManagement(trainableSvm, config.get<size_t>("positiveExamples"), config.get<size_t>("minPositiveExamples"))));
@@ -100,7 +100,7 @@ shared_ptr<TrainableSvmClassifier> PartiallyAdaptiveTracking::createTrainableSvm
} else if (config.get_value<string>() == "framebased") {
size_t frameLength = config.get<size_t>("frameLength");
size_t minExamples = round(frameLength * config.get<float>("minAvgSamples"));
trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("constraintsViolationCosts"));
trainableSvm = LibSvmClassifier::createBinarySvm(kernel, config.get<double>("constraintsViolationCosts"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(new FrameBasedExampleManagement(frameLength, minExamples)));
trainableSvm->setNegativeExampleManagement(
4 changes: 2 additions & 2 deletions trackingBenchmarkApp/TrackingBenchmark.cpp
@@ -318,7 +318,7 @@ unique_ptr<ExampleManagement> TrackingBenchmark::createExampleManagement(ptree&

shared_ptr<TrainableSvmClassifier> TrackingBenchmark::createLibSvmClassifier(ptree& config, shared_ptr<Kernel> kernel) {
if (config.get_value<string>() == "binary") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("C"));
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createBinarySvm(kernel, config.get<double>("C"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
trainableSvm->setNegativeExampleManagement(
@@ -330,7 +330,7 @@ shared_ptr<TrainableSvmClassifier> TrackingBenchmark::createLibSvmClassifier(ptr
}
return trainableSvm;
} else if (config.get_value<string>() == "one-class") {
shared_ptr<LibSvmClassifier> trainableSvm = make_shared<LibSvmClassifier>(kernel, config.get<double>("nu"), true);
shared_ptr<LibSvmClassifier> trainableSvm = LibSvmClassifier::createOneClassSvm(kernel, config.get<double>("nu"));
trainableSvm->setPositiveExampleManagement(
unique_ptr<ExampleManagement>(createExampleManagement(config.get_child("positiveExamples"), trainableSvm, true)));
return trainableSvm;
