Skip to content

Commit

Permalink
Move radial basis functions into utils/rbf_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
bing-jian committed Jun 14, 2019
1 parent 1ba0546 commit b8d4b4b
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 129 deletions.
3 changes: 2 additions & 1 deletion C++/gmmreg_cpd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
#include <vnl/algo/vnl_svd.h>
#include <vnl/vnl_trace.h>

#include "gmmreg_utils.h"
//#include "gmmreg_utils.h"
#include "utils/em_utils.h"
#include "utils/io_utils.h"
#include "utils/rbf_utils.h"

namespace gmmreg {

Expand Down
8 changes: 3 additions & 5 deletions C++/gmmreg_grbf.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "gmmreg_grbf.h"

#include <assert.h>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>

Expand All @@ -11,9 +9,9 @@
#include <vnl/vnl_matrix.h>
#include <vnl/vnl_trace.h>

#include "gmmreg_utils.h"
#include "utils/io_utils.h"
#include "utils/misc_utils.h"
#include "utils/rbf_utils.h"

namespace gmmreg {

Expand Down Expand Up @@ -42,8 +40,8 @@ void GrbfRegistration::StartRegistration(vnl_vector<double>& params) {
SetParam(params);
int n_max_func_evals = v_func_evals_[k];
minimizer.set_max_function_evals(n_max_func_evals);
// For more options, see
// http://public.kitware.com/vxl/doc/release/core/vnl/html/vnl__nonlinear__minimizer_8h-source.html
// For more options, please see
// https://public.kitware.com/vxl/doc/release/core/vnl/html/vnl__nonlinear__minimizer_8h_source.html
minimizer.minimize(params);
if (minimizer.get_failure_code() < 0) {
break;
Expand Down
8 changes: 3 additions & 5 deletions C++/gmmreg_tps.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "gmmreg_tps.h"

#include <assert.h>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>

Expand All @@ -11,9 +9,9 @@
#include <vnl/vnl_matrix.h>
#include <vnl/vnl_trace.h>

#include "gmmreg_utils.h"
#include "utils/io_utils.h"
#include "utils/misc_utils.h"
#include "utils/rbf_utils.h"

namespace gmmreg {

Expand Down Expand Up @@ -41,8 +39,8 @@ void TpsRegistration::StartRegistration(vnl_vector<double>& params) {
SetParam(params);
int n_max_func_evals = v_func_evals_[k];
minimizer.set_max_function_evals(n_max_func_evals);
// For more options, see
// http://public.kitware.com/vxl/doc/release/core/vnl/html/vnl__nonlinear__minimizer_8h-source.html
// For more options, please see
// https://public.kitware.com/vxl/doc/release/core/vnl/html/vnl__nonlinear__minimizer_8h_source.html
minimizer.minimize(params);
if (minimizer.get_failure_code() < 0) {
break;
Expand Down
3 changes: 2 additions & 1 deletion C++/gmmreg_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
#include "port_ini.h"
#endif

#include "gmmreg_utils.h"
//#include "gmmreg_utils.h"
#include "utils/io_utils.h"
#include "utils/normalization_utils.h"
#include "utils/rbf_utils.h"

namespace gmmreg {

Expand Down
119 changes: 2 additions & 117 deletions C++/gmmreg_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <vnl/vnl_matrix.h>
#include <vnl/vnl_vector.h>

#include "utils/macros.h"

namespace gmmreg {

template <typename T>
Expand All @@ -13,17 +15,6 @@ template <typename T>
T GaussTransform(const vnl_matrix<T>& A, const vnl_matrix<T>& B, T scale,
vnl_matrix<T>& gradient);

template <typename T>
void ComputeTPSKernel(const vnl_matrix<T>& model, const vnl_matrix<T>& ctrl_pts,
vnl_matrix<T>& U, vnl_matrix<T>& K);

template <typename T>
void ComputeGaussianKernel(const vnl_matrix<T>& model,
const vnl_matrix<T>& ctrl_pts, vnl_matrix<T>& G,
vnl_matrix<T>& K, T beta);


#define SQR(X) ((X) * (X))

/*
* Note: The input point set containing 'n' points in 'd'-dimensional
Expand Down Expand Up @@ -51,24 +42,6 @@ T GaussTransform(const T* A, const T* B, int m, int n, int dim, T scale) {
return cross_term / (m * n);
}

template <typename T>
void GaussianAffinityMatrix(const T* A, const T* B, int m, int n, int dim,
T scale, T* dist) {
scale = -2.0 * SQR(scale);
int k = 0;
#pragma omp for
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
T dist_ij = 0;
for (int d = 0; d < dim; ++d) {
dist_ij += SQR(A[i * dim + d] - B[j * dim + d]);
}
dist[k] = exp(dist_ij / scale);
++k;
}
}
}

template <typename T>
T GaussTransform(const T* A, const T* B, int m, int n, int dim, T scale,
T* grad) {
Expand Down Expand Up @@ -119,94 +92,6 @@ T GaussTransform(const vnl_matrix<T>& A, const vnl_matrix<T>& B, T scale,
A.cols(), scale, gradient.data_block());
}

// TODO: add one more version when the model is same as ctrl_pts
// reference: Landmark-based Image Analysis, Karl Rohr, p195
template <typename T>
void ComputeTPSKernel(const vnl_matrix<T>& model, const vnl_matrix<T>& ctrl_pts,
vnl_matrix<T>& U, vnl_matrix<T>& K) {
int m = model.rows();
int n = ctrl_pts.rows();
int d = ctrl_pts.cols();
// asssert(model.cols()==d==(2|3));
K.set_size(n, n);
K.fill(0);
U.set_size(m, n);
U.fill(0);
T eps = 1e-006;

#pragma omp for
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
vnl_vector<T> v_ij = model.get_row(i) - ctrl_pts.get_row(j);
if (d == 2) {
T r = v_ij.squared_magnitude();
if (r > eps) {
U(i, j) = r * log(r) / 2;
}
} else if (d == 3) {
T r = v_ij.two_norm();
U(i, j) = -r;
}
}
}

#pragma omp for
for (int i = 0; i < n; ++i) {
for (int j = i + 1; j < n; ++j) {
vnl_vector<T> v_ij = ctrl_pts.get_row(i) - ctrl_pts.get_row(j);
if (d == 2) {
T r = v_ij.squared_magnitude();
if (r > eps) {
K(i, j) = r * log(r) / 2;
}
} else if (d == 3) {
T r = v_ij.two_norm();
K(i, j) = -r;
}
}
}

#pragma omp for
for (int i = 0; i < n; ++i) {
for (int j = 0; j < i; ++j) {
K(i, j) = K(j, i);
}
}
}

/*
Matlab code in cpd_G.m:
k=-2*beta^2;
[n, d]=size(x); [m, d]=size(y);
G=repmat(x,[1 1 m])-permute(repmat(y,[1 1 n]),[3 2 1]);
G=squeeze(sum(G.^2,2));
G=G/k;
G=exp(G);
*/
template <typename T>
void ComputeGaussianKernel(const vnl_matrix<T>& model,
const vnl_matrix<T>& ctrl_pts, vnl_matrix<T>& G,
vnl_matrix<T>& K, T lambda) {
int m, n, d;
m = model.rows();
n = ctrl_pts.rows();
d = ctrl_pts.cols();
// asssert(model.cols()==d);
// assert(lambda>0);

G.set_size(m, n);
GaussianAffinityMatrix(model.data_block(), ctrl_pts.data_block(), m, n, d,
lambda, G.data_block());

if (model == ctrl_pts) {
K = G;
} else {
K.set_size(n, n);
GaussianAffinityMatrix(ctrl_pts.data_block(), ctrl_pts.data_block(), n, n,
d, lambda, K.data_block());
}
}

} // namespace gmmreg

Expand Down
10 changes: 10 additions & 0 deletions C++/utils/macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef GMMREG_UTILS_MACROS_H_
#define GMMREG_UTILS_MACROS_H_

namespace gmmreg {

#define SQR(X) ((X) * (X))

} // namespace gmmreg

#endif // GMMREG_UTILS_MACROS_H_
114 changes: 114 additions & 0 deletions C++/utils/rbf_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#include "rbf_utils.h"

#include "macros.h"

namespace gmmreg {

template <typename T>
void GaussianAffinityMatrix(const T* A, const T* B, int m, int n, int dim,
T scale, T* dist) {
scale = -2.0 * SQR(scale);
int k = 0;
#pragma omp for
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
T dist_ij = 0;
for (int d = 0; d < dim; ++d) {
dist_ij += SQR(A[i * dim + d] - B[j * dim + d]);
}
dist[k] = exp(dist_ij / scale);
++k;
}
}
}

// TODO: add one more version when the model is same as ctrl_pts
// reference: Landmark-based Image Analysis, Karl Rohr, p195
template <typename T>
void ComputeTPSKernel(const vnl_matrix<T>& model, const vnl_matrix<T>& ctrl_pts,
vnl_matrix<T>& U, vnl_matrix<T>& K) {
int m = model.rows();
int n = ctrl_pts.rows();
int d = ctrl_pts.cols();
// asssert(model.cols()==d==(2|3));
K.set_size(n, n);
K.fill(0);
U.set_size(m, n);
U.fill(0);
T eps = 1e-006;

#pragma omp for
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
vnl_vector<T> v_ij = model.get_row(i) - ctrl_pts.get_row(j);
if (d == 2) {
T r = v_ij.squared_magnitude();
if (r > eps) {
U(i, j) = r * log(r) / 2;
}
} else if (d == 3) {
T r = v_ij.two_norm();
U(i, j) = -r;
}
}
}

#pragma omp for
for (int i = 0; i < n; ++i) {
for (int j = i + 1; j < n; ++j) {
vnl_vector<T> v_ij = ctrl_pts.get_row(i) - ctrl_pts.get_row(j);
if (d == 2) {
T r = v_ij.squared_magnitude();
if (r > eps) {
K(i, j) = r * log(r) / 2;
}
} else if (d == 3) {
T r = v_ij.two_norm();
K(i, j) = -r;
}
}
}

#pragma omp for
for (int i = 0; i < n; ++i) {
for (int j = 0; j < i; ++j) {
K(i, j) = K(j, i);
}
}
}

/*
Matlab code in cpd_G.m:
k=-2*beta^2;
[n, d]=size(x); [m, d]=size(y);
G=repmat(x,[1 1 m])-permute(repmat(y,[1 1 n]),[3 2 1]);
G=squeeze(sum(G.^2,2));
G=G/k;
G=exp(G);
*/
template <typename T>
void ComputeGaussianKernel(const vnl_matrix<T>& model,
const vnl_matrix<T>& ctrl_pts, vnl_matrix<T>& G,
vnl_matrix<T>& K, T lambda) {
int m, n, d;
m = model.rows();
n = ctrl_pts.rows();
d = ctrl_pts.cols();
// asssert(model.cols()==d);
// assert(lambda>0);

G.set_size(m, n);
GaussianAffinityMatrix(model.data_block(), ctrl_pts.data_block(), m, n, d,
lambda, G.data_block());

if (model == ctrl_pts) {
K = G;
} else {
K.set_size(n, n);
GaussianAffinityMatrix(ctrl_pts.data_block(), ctrl_pts.data_block(), n, n,
d, lambda, K.data_block());
}
}

} // namespace gmmreg
25 changes: 25 additions & 0 deletions C++/utils/rbf_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef GMMREG_UTILS_RBF_UTILS_H_
#define GMMREG_UTILS_RBF_UTILS_H_

#include <vnl/vnl_matrix.h>

namespace gmmreg {

template <typename T>
void GaussianAffinityMatrix(const T* A, const T* B, int m, int n, int dim,
T scale, T* dist);

template <typename T>
void ComputeTPSKernel(const vnl_matrix<T>& model, const vnl_matrix<T>& ctrl_pts,
vnl_matrix<T>& U, vnl_matrix<T>& K);

template <typename T>
void ComputeGaussianKernel(const vnl_matrix<T>& model,
const vnl_matrix<T>& ctrl_pts, vnl_matrix<T>& G,
vnl_matrix<T>& K, T beta);

} // namespace gmmreg

#include "rbf_utils.cc"

#endif // GMMREG_UTILS_RBF_UTILS_H_

0 comments on commit b8d4b4b

Please sign in to comment.