forked from PythonOT/POT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathweak.py
130 lines (102 loc) · 3.7 KB
/
weak.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Weak optimal transport solvers
"""
# Author: Remi Flamary <[email protected]>
#
# License: MIT License
from .backend import get_backend
from .optim import cg
import numpy as np
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["weak_optimal_transport"]
def weak_optimal_transport(
    Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs
):
    r"""Solves the weak optimal transport problem between two empirical distributions

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \sum_i \mathbf{a}_i \left\|\mathbf{X^a}_i - \frac{1}{\mathbf{a}_i} \sum_j \gamma_{ij} \mathbf{X^b}_j \right\|_2^2

        s.t. \ \gamma \mathbf{1} = \mathbf{a}

             \gamma^T \mathbf{1} = \mathbf{b}

             \gamma \geq 0

    where :

    - :math:`X^a` and :math:`X^b` are the sample matrices.
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends. But the algorithm uses the C++ CPU backend
        which can lead to copy overhead on GPU arrays.

    Uses the conditional gradient algorithm to solve the problem proposed
    in :ref:`[39] <references-weak>`.

    Parameters
    ----------
    Xa : (ns,d) array-like, float
        Source samples
    Xb : (nt,d) array-like, float
        Target samples
    a : (ns,) array-like, float, optional
        Source histogram (uniform weight if None or an empty list). All
        entries must be strictly positive: the objective divides by them.
    b : (nt,) array-like, float, optional
        Target histogram (uniform weight if None or an empty list)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    G0 : (ns,nt) array-like, float, optional
        initial guess (default is the independent joint density
        :math:`\mathbf{a}\mathbf{b}^T`)
    **kwargs :
        Extra keyword arguments forwarded unchanged to :func:`ot.optim.cg`,
        for instance:

        numItermax : int, optional
            Max number of iterations
        numItermaxEmd : int, optional
            Max number of iterations for emd
        stopThr : float, optional
            Stop threshold on the relative variation (>0)
        stopThr2 : float, optional
            Stop threshold on the absolute variation (>0)

    Returns
    -------
    gamma: array-like, shape (ns, nt)
        Optimal transportation matrix for the given parameters, in the
        backend/type of ``Xa``
    log: dict, optional
        Only returned if input ``log`` is True: the log dictionary produced by
        :func:`ot.optim.cg`, with its dual variables ``'u'`` and ``'v'``
        converted back to the backend of ``Xa`` and ``Xb`` respectively.

    .. _references-weak:
    References
    ----------
    .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
            Kantorovich duality for general transport costs and applications.
            Journal of Functional Analysis, 273(11), 3327-3405.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT
    ot.optim.cg : General regularized OT
    """
    nx = get_backend(Xa, Xb)

    def _weights_to_numpy(w, n):
        # Convert a weight vector to NumPy; None/empty means uniform weights.
        if w is None or len(w) == 0:
            return np.ones(n) / n
        # Plain Python sequences are not backend arrays, so nx.to_numpy
        # cannot be relied on to turn them into ndarrays.
        if isinstance(w, (list, tuple)):
            return np.asarray(w, dtype=np.float64)
        return nx.to_numpy(w)

    # The solver works on NumPy arrays; results are converted back at the end.
    Xa2 = nx.to_numpy(Xa)
    Xb2 = nx.to_numpy(Xb)
    a2 = _weights_to_numpy(a, Xa.shape[0])
    b2 = _weights_to_numpy(b, Xb.shape[0])

    # Default initialization: the independent coupling a b^T (outer product).
    if G0 is None:
        T0 = a2[:, None] * b2[None, :]
    else:
        T0 = nx.to_numpy(G0)

    # Weak OT loss: sum_i a_i * || Xa_i - (T Xb)_i / a_i ||^2.
    # NOTE: divides by a2, so zero source weights yield inf/nan.
    def f(T):
        return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None]) ** 2, 1))

    # Gradient of f w.r.t. T: -2 * (Xa - T Xb / a) Xb^T (the a_i factors cancel).
    def df(T):
        return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T)

    # Presumably cg(a, b, M, reg, f, df, G0, ...): M=0 (no linear cost) and
    # reg=1, so the minimized objective is exactly f — confirm against ot.optim.cg.
    if log:
        G, cg_log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
        cg_log["u"] = nx.from_numpy(cg_log["u"], type_as=Xa)
        cg_log["v"] = nx.from_numpy(cg_log["v"], type_as=Xb)
        return nx.from_numpy(G, type_as=Xa), cg_log

    G = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs)
    return nx.from_numpy(G, type_as=Xa)