forked from PythonOT/POT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_utils.py
41 lines (31 loc) · 1.24 KB
/
_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# -*- coding: utf-8 -*-
"""
Common tools of Bregman projections solvers for entropic regularized OT
"""
# Author: Remi Flamary <[email protected]>
# Nicolas Courty <[email protected]>
#
# License: MIT License
from ..utils import list_to_array
from ..backend import get_backend
def geometricBar(weights, alldistribT):
    """Return the weighted geometric mean of distributions.

    The result is ``exp(dot(log(alldistribT), weights.T))`` evaluated with
    the backend matching the inputs.

    Parameters
    ----------
    weights : array-like
        One weight per distribution; its length must equal
        ``alldistribT.shape[1]``.
    alldistribT : array-like
        Distributions stored along the second axis (one per column).
        NOTE(review): presumably strictly positive, since the log is taken
        element-wise -- confirm against callers.

    Returns
    -------
    Backend array holding the weighted geometric mean.
    """
    weights, alldistribT = list_to_array(weights, alldistribT)
    backend = get_backend(weights, alldistribT)

    # One weight is required for each column of alldistribT.
    n_weights = len(weights)
    assert n_weights == alldistribT.shape[1]

    # Average in the log domain, then map back with exp.
    log_distribs = backend.log(alldistribT)
    weighted_log = backend.dot(log_distribs, weights.T)
    return backend.exp(weighted_log)
def geometricMean(alldistribT):
    """Return the (unweighted) geometric mean of distributions.

    The result is ``exp(mean(log(alldistribT), axis=1))`` evaluated with
    the backend matching the input.

    Parameters
    ----------
    alldistribT : array-like
        Distributions to average; the mean is taken over axis 1.
        NOTE(review): presumably strictly positive, since the log is taken
        element-wise -- confirm against callers.

    Returns
    -------
    Backend array holding the geometric mean along axis 1.
    """
    alldistribT = list_to_array(alldistribT)
    backend = get_backend(alldistribT)

    # Arithmetic mean of the logs over axis 1, mapped back with exp.
    mean_log = backend.mean(backend.log(alldistribT), axis=1)
    return backend.exp(mean_log)
def projR(gamma, p):
    """Return the KL projection on the row constraints.

    Each row of ``gamma`` is rescaled by ``p / row_sum``, where the row
    sums are floored at ``1e-10`` before dividing.

    Parameters
    ----------
    gamma : array-like
        Coupling matrix to project.
        NOTE(review): assumed 2-D -- confirm against callers.
    p : array-like
        Target row marginals, broadcast against the rows of ``gamma``.

    Returns
    -------
    Backend array shaped like ``gamma`` with rescaled rows.
    """
    gamma, p = list_to_array(gamma, p)
    backend = get_backend(gamma, p)

    # Work on the transpose so that p broadcasts along the row index.
    weighted = gamma.T * p
    # Floor the row sums so (near-)empty rows do not divide by zero.
    row_mass = backend.maximum(backend.sum(gamma, axis=1), 1e-10)
    return (weighted / row_mass).T
def projC(gamma, q):
    """Return the KL projection on the column constraints.

    Each column of ``gamma`` is rescaled by ``q / column_sum``, where the
    column sums are floored at ``1e-10`` before dividing.

    Parameters
    ----------
    gamma : array-like
        Coupling matrix to project.
        NOTE(review): assumed 2-D -- confirm against callers.
    q : array-like
        Target column marginals, broadcast against the columns of ``gamma``.

    Returns
    -------
    Backend array shaped like ``gamma`` with rescaled columns.
    """
    gamma, q = list_to_array(gamma, q)
    backend = get_backend(gamma, q)

    weighted = gamma * q
    # Floor the column sums so (near-)empty columns do not divide by zero.
    col_mass = backend.maximum(backend.sum(gamma, axis=0), 1e-10)
    return weighted / col_mass