Skip to content

Commit

Permalink
first implementation KNearest wrapper on KDTree
Browse files Browse the repository at this point in the history
  • Loading branch information
avdmitry committed Aug 23, 2014
1 parent 37b1a75 commit 9ddb23e
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 7 deletions.
8 changes: 6 additions & 2 deletions modules/ml/include/opencv2/ml.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,22 @@ class CV_EXPORTS_W KNearest : public StatModel
// Parameter bundle for KNearest models; it is passed to KNearest::create() and
// exchanged through setParams()/getParams().
class CV_EXPORTS_W_MAP Params
{
public:
Params(int defaultK=10, bool isclassifier=true);
Params(int defaultK=10, bool isclassifier_=true, int Emax_=INT_MAX);

// Number of neighbours used when the caller does not supply k explicitly
// (the KDTree implementation's predict() forwards it to findNearest()).
CV_PROP_RW int defaultK;
// Value reported by the model's isClassifier().
CV_PROP_RW bool isclassifier;
// Forwarded as the Emax argument of KDTree::findNearest by the KDTree implementation.
CV_PROP_RW int Emax; // for implementation with KDTree
};
virtual void setParams(const Params& p) = 0;
virtual Params getParams() const = 0;
virtual float findNearest( InputArray samples, int k,
OutputArray results,
OutputArray neighborResponses=noArray(),
OutputArray dist=noArray() ) const = 0;
static Ptr<KNearest> create(const Params& params=Params());

// Implementation selector accepted by create(): KDTREE picks the KD-tree backed
// model, every other value picks the default implementation.
enum { DEFAULT=1, KDTREE=2 };

// Creates a KNearest model initialised with `params`, using the implementation
// selected by `type` (one of the enum values above).
static Ptr<KNearest> create(const Params& params=Params(), int type=DEFAULT);
};

/****************************************************************************************\
Expand Down
153 changes: 151 additions & 2 deletions modules/ml/src/knearest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@
namespace cv {
namespace ml {

KNearest::Params::Params(int k, bool isclassifier_)
// Stores the three settings verbatim: the default neighbour count, the
// classifier flag and the Emax value used by the KDTree-based implementation.
KNearest::Params::Params(int k, bool isclassifier_, int Emax_)
    : defaultK(k),
      isclassifier(isclassifier_),
      Emax(Emax_)
{
}


Expand Down Expand Up @@ -352,8 +353,156 @@ class KNearestImpl : public KNearest
Params params;
};

Ptr<KNearest> KNearest::create(const Params& p)

// KNearest implementation that delegates the neighbour search to a KDTree
// built over the accumulated training samples.
//
// Training samples and responses are kept as matrices with one sample per row.
// The tree is rebuilt from the whole sample matrix after every train() call
// and after read(), so that the tree always matches `samples`.
class KNearestKDTreeImpl : public KNearest
{
public:
    KNearestKDTreeImpl(const Params& p)
    {
        params = p;
    }

    virtual ~KNearestKDTreeImpl() {}

    Params getParams() const { return params; }
    void setParams(const Params& p) { params = p; }

    bool isClassifier() const { return params.isclassifier; }
    // A model counts as trained as soon as it holds at least one sample.
    bool isTrained() const { return !samples.empty(); }

    String getDefaultModelName() const { return "opencv_ml_knn_kd"; }

    // Drops the stored training samples and responses.
    void clear()
    {
        samples.release();
        responses.release();
    }

    int getVarCount() const { return samples.cols; }

    // Adds the training data to the model and rebuilds the KD-tree.
    //
    // Samples must be CV_32F; responses are converted to CV_32F. When `flags`
    // contains UPDATE_MODEL and the model already holds samples, the new rows
    // are appended (column counts must match); otherwise the previous data is
    // discarded first. Always returns true.
    bool train( const Ptr<TrainData>& data, int flags )
    {
        Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
        Mat new_responses;
        data->getTrainResponses().convertTo(new_responses, CV_32F);
        bool update = (flags & UPDATE_MODEL) != 0 && !samples.empty();

        CV_Assert( new_samples.type() == CV_32F );

        if( !update )
        {
            clear();
        }
        else
        {
            CV_Assert( new_samples.cols == samples.cols &&
                       new_responses.cols == responses.cols );
        }

        samples.push_back(new_samples);
        responses.push_back(new_responses);

        // The tree indexes the complete sample set, so rebuild it from scratch.
        tr.build(samples);

        return true;
    }

    // Runs a KD-tree query for every row of `_samples`.
    //
    // `k` must be positive and `_samples` must be CV_32F with as many columns
    // as the training samples. Requested outputs are allocated as CV_32F with
    // one row per test sample (1 column for `_results`, k columns for the
    // other two). An empty input releases all outputs.
    //
    // Returns 0 in every case; no aggregate prediction is computed yet.
    float findNearest( InputArray _samples, int k,
                       OutputArray _results,
                       OutputArray _neighborResponses,
                       OutputArray _dists ) const
    {
        CV_Assert( 0 < k );

        Mat test_samples = _samples.getMat();
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
        int testcount = test_samples.rows;

        if( testcount == 0 )
        {
            _results.release();
            _neighborResponses.release();
            _dists.release();
            return 0.f;
        }

        Mat res, nr, d;
        if( _results.needed() )
        {
            _results.create(testcount, 1, CV_32F);
            res = _results.getMat();
        }
        if( _neighborResponses.needed() )
        {
            _neighborResponses.create(testcount, k, CV_32F);
            nr = _neighborResponses.getMat();
        }
        if( _dists.needed() )
        {
            _dists.create(testcount, k, CV_32F);
            d = _dists.getMat();
        }

        for (int i=0; i<test_samples.rows; ++i)
        {
            // Row views into the requested outputs; they stay empty for
            // outputs the caller did not ask for.
            Mat _res, _nr, _d;
            if (res.rows>i)
            {
                _res = res.row(i);
            }
            if (nr.rows>i)
            {
                _nr = nr.row(i);
            }
            if (d.rows>i)
            {
                _d = d.row(i);
            }
            // NOTE(review): the meaning of each output argument is defined by
            // KDTree::findNearest (not visible here). Confirm that it writes
            // into these row views in place rather than reallocating them,
            // and that the values match what KNearest callers expect.
            tr.findNearest(test_samples.row(i), k, params.Emax, _res, _nr, _d, noArray());
        }

        return 0.f; // currently always 0
    }

    // Prediction is a findNearest() call with the default neighbour count.
    float predict(InputArray inputs, OutputArray outputs, int) const
    {
        return findNearest( inputs, params.defaultK, outputs, noArray(), noArray() );
    }

    // Serialises the parameters (including Emax) and the training data.
    void write( FileStorage& fs ) const
    {
        fs << "is_classifier" << (int)params.isclassifier;
        fs << "default_k" << params.defaultK;
        fs << "emax" << params.Emax;

        fs << "samples" << samples;
        fs << "responses" << responses;
    }

    // Restores a model written by write() and rebuilds the KD-tree.
    void read( const FileNode& fn )
    {
        clear();
        params.isclassifier = (int)fn["is_classifier"] != 0;
        params.defaultK = (int)fn["default_k"];

        // Files written before "emax" was stored do not have this node; keep
        // the current Emax in that case instead of overwriting it.
        const FileNode emaxNode = fn["emax"];
        if( !emaxNode.empty() )
        {
            params.Emax = (int)emaxNode;
        }

        fn["samples"] >> samples;
        fn["responses"] >> responses;

        // Without this the model would report isTrained() while the tree is
        // still empty (or stale), so queries would not see the loaded samples.
        if( !samples.empty() )
        {
            tr.build(samples);
        }
    }

    KDTree tr;          // search structure built over `samples`

    Mat samples;        // accumulated training samples, one per row
    Mat responses;      // accumulated training responses, converted to CV_32F
    Params params;
};

// Factory for KNearest models: KDTREE selects the KD-tree backed
// implementation, any other `type` value selects the default one.
Ptr<KNearest> KNearest::create(const Params& p, int type)
{
    switch (type)
    {
    case KDTREE:
        return makePtr<KNearestKDTreeImpl>(p);
    default:
        return makePtr<KNearestImpl>(p);
    }
}

Expand Down
19 changes: 16 additions & 3 deletions modules/ml/test/test_emknearestkmeans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,11 @@ void CV_KNearestTest::run( int /*start_from*/ )
generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );

int code = cvtest::TS::OK;
Ptr<KNearest> knearest = KNearest::create(true);
knearest->train(trainData, cv::ml::ROW_SAMPLE, trainLabels);
knearest->findNearest( testData, 4, bestLabels);

// KNearest default implementation
Ptr<KNearest> knearest = KNearest::create();
knearest->train(trainData, ml::ROW_SAMPLE, trainLabels);
knearest->findNearest(testData, 4, bestLabels);
float err;
if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
{
Expand All @@ -326,6 +328,17 @@ void CV_KNearestTest::run( int /*start_from*/ )
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}

// KNearest KDTree implementation
Ptr<KNearest> knearestKdt = KNearest::create(ml::KNearest::Params(), ml::KNearest::KDTREE);
knearestKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
knearestKdt->findNearest(testData, 4, bestLabels);
if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
{
ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );
code = cvtest::TS::FAIL_INVALID_OUTPUT;
}

ts->set_failed_test_info( code );
}

Expand Down

0 comments on commit 9ddb23e

Please sign in to comment.