Skip to content

Commit

Permalink
[MRG] Correct pointer overflow in EMD (PythonOT#381)
Browse files Browse the repository at this point in the history
* avoid overflow on openmp version of emd solver

* monothread version updated

* Fixed typo in readme

* added PR in releases

* typo in releases.md

* added a precision to releases.md

* added a precision to releases.md

* correct readme

* forgot to cast

* lower error
  • Loading branch information
ncassereau authored Jun 13, 2022
1 parent 1f30759 commit e547fe3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 23 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ POT provides the following generic OT solvers (links to examples):
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Weak OT solver between empirical distributions [39]
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from Graph Dictionary Learning [38]
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
* [Stochastic
solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and
Expand Down
5 changes: 4 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
- Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377)
- Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379)
- Fixed an issue where the max number of iterations in ot.emd was not allow to go beyond 2^31 (PR #380)
- Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380)
- Fixed an issue where pointers would overflow in the EMD solver, returning an
incomplete transport plan above a certain size (slightly above 46k, its square being
roughly 2^31) (PR #381)


## 0.8.2
Expand Down
36 changes: 18 additions & 18 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// beware M and C are stored in row major C style!!!

using namespace lemon;
int n, m, cur;
uint64_t n, m, cur;

typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
Expand All @@ -51,15 +51,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,

// Define the graph

std::vector<int> indI(n), indJ(m);
std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter);

// Set supply and demand, don't account for 0 values (faster)

cur=0;
for (int i=0; i<n1; i++) {
for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
Expand All @@ -70,7 +70,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...

cur=0;
for (int i=0; i<n2; i++) {
for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
Expand All @@ -79,12 +79,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
}


net.supplyMap(&weights1[0], n, &weights2[0], m);
net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);

// Set the cost of each edge
int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
for (uint64_t i=0; i<n; i++) {
for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
Expand All @@ -95,7 +95,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm

int ret=net.run();
int i, j;
uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
Expand Down Expand Up @@ -126,7 +126,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// beware M and C are stored in row major C style!!!

using namespace lemon_omp;
int n, m, cur;
uint64_t n, m, cur;

typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
Expand All @@ -153,15 +153,15 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,

// Define the graph

std::vector<int> indI(n), indJ(m);
std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter, numThreads);

// Set supply and demand, don't account for 0 values (faster)

cur=0;
for (int i=0; i<n1; i++) {
for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
Expand All @@ -172,7 +172,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...

cur=0;
for (int i=0; i<n2; i++) {
for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
Expand All @@ -181,12 +181,12 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
}


net.supplyMap(&weights1[0], n, &weights2[0], m);
net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);

// Set the cost of each edge
int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
for (uint64_t i=0; i<n; i++) {
for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
Expand All @@ -197,7 +197,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm

int ret=net.run();
int i, j;
uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
Expand Down
4 changes: 2 additions & 2 deletions ot/lp/network_simplex_simple_omp.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
#undef EPSILON
#undef _EPSILON
#undef MAX_DEBUG_ITER
#define EPSILON std::numeric_limits<Cost>::epsilon()*10
#define _EPSILON 1e-8
#define EPSILON std::numeric_limits<Cost>::epsilon()
#define _EPSILON 1e-14
#define MAX_DEBUG_ITER 100000

/// \ingroup min_cost_flow_algs
Expand Down

0 comments on commit e547fe3

Please sign in to comment.