Skip to content

Commit

Permalink
finishing emd part.
Browse files Browse the repository at this point in the history
  • Loading branch information
william committed Apr 16, 2024
1 parent 0210a53 commit b7d9645
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 32 deletions.
166 changes: 157 additions & 9 deletions ot/ot.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "ot.h"
#include "EMD.h"
using std::vector, std::pair;

namespace ot {
std::string check_result(int result_code) {
Expand All @@ -17,6 +18,39 @@ namespace ot {
return message;
}

vector<pair<int, int>> where(const ot::RowMajorMatrixXd& M, double thr) {
vector<pair<int, int>> indices;

for (int i = 0; i < M.rows(); ++i) {
for (int j = 0; j < M.cols(); ++j) {
if (M.coeffRef(i, j) <= thr) {
indices.emplace_back(i, j);
}
}
}

return indices;
}

vector<int> where(const Eigen::ArrayXd& A, double thr) {
vector<int> indices;

for (int i = 0; i < A.size(); ++i) {
if (A(i) <= thr) {
indices.emplace_back(i);
}
}

return indices;
}

void indexing_op(ot::RowMajorMatrixXd& M,
const vector<pair<int, int>>& indices, double v) {
for (const auto& index : indices) {
M(index.first, index.second) = v;
}
}

EMDCluster emd_c(Eigen::ArrayXd a, Eigen::ArrayXd b, RowMajorMatrixXd M,
uint64_t max_iter, int numThreads) {
/**
Expand Down Expand Up @@ -156,8 +190,54 @@ namespace ot {
return dataset;
}

void center_ot_dual(Eigen::ArrayXd& alpha0, Eigen::ArrayXd& beta0,
Eigen::ArrayXd& a, Eigen::ArrayXd& b) {
// void center_ot_dual(Eigen::ArrayXd& alpha0, Eigen::ArrayXd& beta0,
// Eigen::ArrayXd& a, Eigen::ArrayXd& b) {
// /**
// The main idea of this function is to find unique dual potentials
// that ensure some kind of centering/fairness. The main idea is to find
// dual potentials that lead to the same final objective value for both
// source and targets (see below for more details). It will help having
// stability when multiple calling of the OT solver with small changes.

// Basically we add another constraint to the potential that will not
// change the objective value but will ensure unicity. The constraint
// is the following:

// .. math::
// \alpha^T \mathbf{a} = \beta^T \mathbf{b}

// in addition to the OT problem constraints.

// since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing
// a constant from both :math:`\alpha_0` and :math:`\beta_0`.

// .. math::
// c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T
// \mathbf{b} + \mathbf{1}^T \mathbf{a}}

// \alpha &= \alpha_0 + c

// \beta &= \beta_0 + c
// */
// if (a.size() == 0) {
// a = Eigen::ArrayXd::Ones(alpha0.size()) / alpha0.size();
// }

// if (b.size() == 0) {
// b = Eigen::ArrayXd::Ones(beta0.size()) / beta0.size();
// }

// auto c =
// (b.matrix().dot(beta0.matrix()) - a.matrix().dot(alpha0.matrix())) /
// (a.sum() + b.sum());

// alpha0 += c;
// beta0 -= c;
// }

AlphaBetaCrater center_ot_dual(const Eigen::ArrayXd& alpha0,
const Eigen::ArrayXd& beta0, Eigen::ArrayXd& a,
Eigen::ArrayXd& b) {
/**
The main idea of this function is to find unique dual potentials
that ensure some kind of centering/fairness. The main idea is to find dual
Expand Down Expand Up @@ -185,6 +265,8 @@ namespace ot {
\beta &= \beta_0 + c
*/
AlphaBetaCrater AB;

if (a.size() == 0) {
a = Eigen::ArrayXd::Ones(alpha0.size()) / alpha0.size();
}
Expand All @@ -197,14 +279,16 @@ namespace ot {
(b.matrix().dot(beta0.matrix()) - a.matrix().dot(alpha0.matrix())) /
(a.sum() + b.sum());

alpha0 += c;
beta0 -= c;
AB.alpha = alpha0 + c;
AB.beta = beta0 - c;
return AB;
}

void estimate_dual_null_weights(Eigen::ArrayXd& alpha0, Eigen::ArrayXd& beta0,
const Eigen::ArrayXd& a,
const Eigen::ArrayXd& b,
const RowMajorMatrixXd& M) {
AlphaBetaCrater estimate_dual_null_weights(Eigen::ArrayXd& alpha0,
Eigen::ArrayXd& beta0,
Eigen::ArrayXd& a,
Eigen::ArrayXd& b,
const RowMajorMatrixXd& M) {
/**Estimate feasible values for 0-weighted dual potentials
The feasible values are computed efficiently but rather coarsely.
Expand Down Expand Up @@ -245,21 +329,85 @@ namespace ot {
Note that all those updates do not change the objective value of the
solution but provide dual potentials that do not violate the constraints.
*/

auto adel_ = where(a, 0);
auto bdel_ = where(b, 0);

Eigen::ArrayXd adel =
Eigen::Map<Eigen::ArrayXi>(adel_.data(), adel_.size()).cast<double>();
Eigen::ArrayXd bdel =
Eigen::Map<Eigen::ArrayXi>(bdel_.data(), bdel_.size()).cast<double>();

Eigen::Map<RowMajorMatrixXd> alpha0_2d(alpha0.data(), alpha0.size(), 1);
Eigen::Map<RowMajorMatrixXd> beta0_2d(beta0.data(), beta0.size(), 1);

auto constraint_violation = alpha0_2d + beta0_2d - M;

auto aviol = constraint_violation.rowwise().maxCoeff().eval();
auto bviol = constraint_violation.colwise().maxCoeff().eval();
aviol.resize(adel.size());
auto alpha_up = -1 * adel * aviol.cwiseMax(0).eval().array();
auto beta_up = -1 * bdel * bviol.cwiseMax(0).eval().array();

return center_ot_dual(alpha_up.eval(), beta_up.eval(), a, b);
}

RowMajorMatrixXd emd(Eigen::ArrayXd& a, Eigen::ArrayXd& b,
const RowMajorMatrixXd& M, uint64_t numIterMax,
int numThreads, bool center_dual) {
b = b * a.sum() / b.sum();

auto adel = where(a, 0);
auto bdel = where(b, 0);

EMDCluster crater;
crater = emd_c(a, b, M, numIterMax, numThreads);

AlphaBetaCrater AB;
if (center_dual) {
center_ot_dual(crater.alpha, crater.beta, a, b);
AB = center_ot_dual(crater.alpha, crater.beta, a, b);
// update parameters
crater.alpha = AB.alpha;
crater.beta = AB.beta;
}

if (!adel.empty() || !bdel.empty()) {
AB = estimate_dual_null_weights(crater.alpha, crater.beta, a, b, M);
// update parameters
crater.alpha = AB.alpha;
crater.beta = AB.beta;
}

return crater.G;
}

EMDCluster emd_full(Eigen::ArrayXd& a, Eigen::ArrayXd& b,
const RowMajorMatrixXd& M, uint64_t numIterMax,
int numThreads, bool center_dual) {
b = b * a.sum() / b.sum();

auto adel = where(a, 0);
auto bdel = where(b, 0);

EMDCluster crater;
crater = emd_c(a, b, M, numIterMax, numThreads);

AlphaBetaCrater AB;
if (center_dual) {
AB = center_ot_dual(crater.alpha, crater.beta, a, b);
// update parameters
crater.alpha = AB.alpha;
crater.beta = AB.beta;
}

if (!adel.empty() || !bdel.empty()) {
AB = estimate_dual_null_weights(crater.alpha, crater.beta, a, b, M);
// update parameters
crater.alpha = AB.alpha;
crater.beta = AB.beta;
}

return crater;
}

} // namespace ot
9 changes: 9 additions & 0 deletions ot/ot.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,16 @@ namespace ot {
double cost;
};

struct AlphaBetaCrater {
Eigen::ArrayXd alpha;
Eigen::ArrayXd beta;
};

std::string check_result(int result_code);
std::vector<std::pair<int, int>> where(const ot::RowMajorMatrixXd& M,
double thr);
void indexing_op(ot::RowMajorMatrixXd& M,
const std::vector<std::pair<int, int>>& indices, double v);
EMDCluster emd_c(Eigen::ArrayXd a, Eigen::ArrayXd b, RowMajorMatrixXd M,
uint64_t max_iter, int numThreads);
EMDCluster1d emd_1d_sorted(Eigen::ArrayXd u_weights, Eigen::ArrayXd v_weights,
Expand Down
23 changes: 2 additions & 21 deletions ot/pipe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include "interp2d.h"
#include <memory>

using namespace ot;

struct NormalDistCrate {
ot::RowMajorMatrixXd norm;
double mu;
Expand All @@ -27,27 +29,6 @@ double* load_hdf5(std::string data_path, std::string dataset_name) {
return data;
}

vector<pair<int, int>> where(const ot::RowMajorMatrixXd& M, double thr) {
vector<pair<int, int>> indices;

for (int i = 0; i < M.rows(); ++i) {
for (int j = 0; j < M.cols(); ++j) {
if (M.coeffRef(i, j) <= thr) {
indices.emplace_back(i, j);
}
}
}

return indices;
}

void indexing_op(ot::RowMajorMatrixXd& M, const vector<pair<int, int>>& indices,
double v) {
for (const auto& index : indices) {
M(index.first, index.second) = v;
}
}

void free_hdf5(double* data) { delete[] data; }

ot::RowMajorMatrixXd distribution_normalize(const ot::RowMajorMatrixXd& M) {
Expand Down
3 changes: 1 addition & 2 deletions ot/pipe.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,4 @@ using std::vector, std::pair;
Eigen::MatrixXd load_hdf5_to_eigen_col_major(std::string data_path,
std::string dataset_name);
ot::RowMajorMatrixXd load_hdf5_to_eigen_row_major(std::string data_path,
std::string dataset_name);
vector<pair<int, int>> where(const ot::RowMajorMatrixXd& M, double thr);
std::string dataset_name);

0 comments on commit b7d9645

Please sign in to comment.