Skip to content

Commit

Permalink
Auto PEP8
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienCorenflos committed Apr 2, 2020
1 parent 1e2e118 commit 60943d0
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@

import multiprocessing
import sys

import numpy as np
from scipy.sparse import coo_matrix

from .import cvx

from . import cvx
from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
from ..utils import parmap
from .cvx import barycenter
from ..utils import dist
from ..utils import parmap

__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
Expand Down Expand Up @@ -458,7 +458,8 @@ def f(b):
return res


def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
Expand Down Expand Up @@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None

T_sum = np.zeros((k, d))

for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):

for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
weights.tolist()):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
Expand Down Expand Up @@ -651,8 +652,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]

x_a_1d = x_a.reshape((-1, ))
x_b_1d = x_b.reshape((-1, ))
x_a_1d = x_a.reshape((-1,))
x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)

Expand Down

0 comments on commit 60943d0

Please sign in to comment.