Skip to content

Commit

Permalink
Move ComputeP from gmmreg_utils to utils/em_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
bing-jian committed Jun 14, 2019
1 parent c50b51d commit 0f22c7e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 51 deletions.
6 changes: 1 addition & 5 deletions C++/gmmreg_cpd.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
#include "gmmreg_cpd.h"

#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vnl/algo/vnl_determinant.h>
#include <vnl/algo/vnl_qr.h>
#include <vnl/algo/vnl_svd.h>
#include <vnl/vnl_trace.h>

#include "gmmreg_utils.h"
#include "utils/em_utils.h"
#include "utils/io_utils.h"

namespace gmmreg {
Expand Down
46 changes: 0 additions & 46 deletions C++/gmmreg_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ template <typename T>
void Denormalize(vnl_matrix<T>& x, const vnl_vector<T>& centroid,
const T scale);

template <typename T>
void ComputeP(const vnl_matrix<T>& x, const vnl_matrix<T>& y, vnl_matrix<T>& P,
T& E, T sigma, int outliers);

#define SQR(X) ((X) * (X))

Expand Down Expand Up @@ -273,49 +270,6 @@ void Denormalize(vnl_matrix<T>& x, const vnl_vector<T>& centroid,
}
}

template <typename T>
void ComputeP(const vnl_matrix<T>& x, const vnl_matrix<T>& y, vnl_matrix<T>& P,
T& E, T sigma, int outliers) {
T k;
k = -2 * sigma * sigma;

vnl_vector<T> column_sum;
int m = x.rows();
int s = y.rows();
int d = x.cols();
column_sum.set_size(s);
column_sum.fill(0);
T outlier_term = outliers * pow((2 * sigma * sigma * 3.1415926), 0.5 * d);
#pragma omp for
for (int i = 0; i < m; ++i) {
for (int j = 0; j < s; ++j) {
T r = 0;
for (int t = 0; t < d; ++t) {
r += (x(i, t) - y(j, t)) * (x(i, t) - y(j, t));
}
P(i, j) = exp(r / k);
column_sum[j] += P(i, j);
}
}

if (outliers != 0) {
#pragma omp for
for (int i = 0; i < s; ++i) column_sum[i] += outlier_term;
}
if (column_sum.min_value() > (1e-12)) {
E = 0;
#pragma omp for
for (int i = 0; i < s; ++i) {
for (int j = 0; j < m; ++j) {
P(j, i) = P(j, i) / column_sum[i];
}
E -= log(column_sum[i]);
}
} else {
P.empty();
}
}

} // namespace gmmreg

#endif // GMMREG_UTILS_H_
48 changes: 48 additions & 0 deletions C++/utils/em_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "em_utils.h"

namespace gmmreg {

template <typename T>
void ComputeP(const vnl_matrix<T>& x, const vnl_matrix<T>& y, vnl_matrix<T>& P,
T& E, T sigma, int outliers) {
T k;
k = -2 * sigma * sigma;

vnl_vector<T> column_sum;
int m = x.rows();
int s = y.rows();
int d = x.cols();
column_sum.set_size(s);
column_sum.fill(0);
T outlier_term = outliers * pow((2 * sigma * sigma * 3.1415926), 0.5 * d);
#pragma omp for
for (int i = 0; i < m; ++i) {
for (int j = 0; j < s; ++j) {
T r = 0;
for (int t = 0; t < d; ++t) {
r += (x(i, t) - y(j, t)) * (x(i, t) - y(j, t));
}
P(i, j) = exp(r / k);
column_sum[j] += P(i, j);
}
}

if (outliers != 0) {
#pragma omp for
for (int i = 0; i < s; ++i) column_sum[i] += outlier_term;
}
if (column_sum.min_value() > (1e-12)) {
E = 0;
#pragma omp for
for (int i = 0; i < s; ++i) {
for (int j = 0; j < m; ++j) {
P(j, i) = P(j, i) / column_sum[i];
}
E -= log(column_sum[i]);
}
} else {
P.empty();
}
}

} // namespace gmmreg
16 changes: 16 additions & 0 deletions C++/utils/em_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef GMMREG_UTILS_EM_UTILS_H_
#define GMMREG_UTILS_EM_UTILS_H_

#include <vnl/vnl_matrix.h>

namespace gmmreg {

template <typename T>
void ComputeP(const vnl_matrix<T>& x, const vnl_matrix<T>& y, vnl_matrix<T>& P,
T& E, T sigma, int outliers);

} // namespace gmmreg

#include "em_utils.cc"

#endif // GMMREG_UTILS_EM_UTILS_H_

0 comments on commit 0f22c7e

Please sign in to comment.