Skip to content

Commit

Permalink
Remove NetworkSimplex c++ code and add pot package as solver
Browse files Browse the repository at this point in the history
  • Loading branch information
yrobink committed Sep 6, 2023
1 parent b9843f4 commit c00593d
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 39 deletions.
15 changes: 6 additions & 9 deletions SBCK/__OTC.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

## Copyright(c) 2021 Yoann Robin
## Copyright(c) 2021 / 2023 Yoann Robin
##
## This file is part of SBCK.
##
Expand All @@ -25,8 +25,9 @@
import scipy.stats as sc
from .tools.__tools_cpp import SparseHist
from .tools.__bin_width_estimator import bin_width_estimator
from .tools.__OT import OTNetworkSimplex
from .tools.__OT import OTSinkhornLogDual
from .tools.__OT import POTemd
#from .tools.__OT import OTNetworkSimplex
#from .tools.__OT import OTSinkhornLogDual


###########
Expand All @@ -47,7 +48,7 @@ class OTC:
[1] Robin, Y., Vrac, M., Naveau, P., Yiou, P.: Multivariate stochastic bias corrections with optimal transport, Hydrol. Earth Syst. Sci., 23, 773–786, 2019, https://doi.org/10.5194/hess-23-773-2019
"""

def __init__( self , bin_width = None , bin_origin = None , ot = OTNetworkSimplex() ):##{{{
def __init__( self , bin_width = None , bin_origin = None , ot = POTemd() ):##{{{
"""
Initialisation of Optimal Transport bias Corrector.
Expand All @@ -58,7 +59,7 @@ def __init__( self , bin_width = None , bin_origin = None , ot = OTNetworkSimple
bin_origin : np.array( [shape = (n_features) ] )
Corner of one bin, see SBCK.SparseHist. If is None, np.repeat(0,n_features) is used
ot : OT*Solver*
A solver for Optimal transport, default is OTSinkhornLogDual()
A solver for Optimal transport, default is POTemd()
Attributes
----------
Expand Down Expand Up @@ -101,10 +102,6 @@ def fit( self , Y0 , X0 ):##{{{

## Optimal Transport
self._ot.fit( self.muX , self.muY )
if not self._ot.state:
print( "Warning: Error in network simplex, try SinkhornLogDual" )
self._ot = OTSinkhornLogDual()
self._ot.fit( self.muX , self.muY )

##
self._plan = np.copy( self._ot.plan() )
Expand Down
6 changes: 3 additions & 3 deletions SBCK/__dOTC.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

## Copyright(c) 2021 Yoann Robin
## Copyright(c) 2021 / 2023 Yoann Robin
##
## This file is part of SBCK.
##
Expand All @@ -26,7 +26,7 @@

from .tools.__tools_cpp import SparseHist
from .tools.__bin_width_estimator import bin_width_estimator
from .tools.__OT import OTNetworkSimplex
from .tools.__OT import POTemd
from .__OTC import OTC


Expand Down Expand Up @@ -72,7 +72,7 @@ def _eps_cholesky( self , M , nit = 200 ): #{{{
return MC
#}}}

def __init__( self , bin_width = None , bin_origin = None , cov_factor = "std" , ot = OTNetworkSimplex() ):##{{{
def __init__( self , bin_width = None , bin_origin = None , cov_factor = "std" , ot = POTemd() ):##{{{
"""
Initialisation of dynamical Optimal Transport bias Corrector.
Expand Down
8 changes: 4 additions & 4 deletions SBCK/metrics/__wasserstein.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

## Copyright(c) 2021 Yoann Robin
## Copyright(c) 2021 / 2023 Yoann Robin
##
## This file is part of SBCK.
##
Expand All @@ -22,7 +22,7 @@
###############

import numpy as np
from SBCK.tools.__OT import OTNetworkSimplex
from SBCK.tools.__OT import POTemd
from .__decorators import _to_SparseHist


Expand All @@ -31,7 +31,7 @@
##############

@_to_SparseHist
def wasserstein( muX , muY , p = 2. , ot = OTNetworkSimplex() , metric = "euclidean" ):
def wasserstein( muX , muY , p = 2. , ot = POTemd() , metric = "euclidean" ):
"""
Description
===========
Expand Down Expand Up @@ -63,7 +63,7 @@ def wasserstein( muX , muY , p = 2. , ot = OTNetworkSimplex() , metric = "euclid

ot.power = p
ot.fit( muX , muY )
if type(ot) == OTNetworkSimplex:
if type(ot) == POTemd:
w = cost(ot)
if not ot.state:
w = np.nan
Expand Down
39 changes: 25 additions & 14 deletions SBCK/tools/__OT.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

## Copyright(c) 2021 Yoann Robin
## Copyright(c) 2021 / 2023 Yoann Robin
##
## This file is part of SBCK.
##
Expand All @@ -23,8 +23,8 @@

import numpy as np
import scipy.spatial.distance as ssd
from .__tools_cpp import network_simplex

#from .__tools_cpp import network_simplex
import ot

#############
## Classes ##
Expand All @@ -36,19 +36,17 @@ def __init__( self , c , p ):
self.c = c
##}}}

class OTNetworkSimplex:##{{{
class POTemd:##{{{
"""
SBCK.tools.NetworkSimplex
=========================
SBCK.tools.POTemd
=================
Network simplex method to solve optimal transport problem
Earth Mover Distance (emd) solver from the POT package.
see https://pythonot.github.io
References
==========
Bazaraa, M. S., Jarvis, J. J., and Sherali, H. D.: Linear Programming and Network Flows, 4th edn., John Wiley & Sons, 2009.
"""

def __init__( self , power = 2 ):##{{{
def __init__( self , power = 2 , numItermax = 100_000_000):##{{{
"""
Initialisation of solver
Expand All @@ -60,6 +58,9 @@ def __init__( self , power = 2 ):##{{{
self.C = None
self.P = None
self.power = power
self.state = True
self.numItermax = numItermax

##}}}

def fit( self , mu0 , mu1 , C = None ):##{{{
Expand All @@ -73,10 +74,18 @@ def fit( self , mu0 , mu1 , C = None ):##{{{
mu1 : (SBCK.SparseHist)
Target histogram
"""
self.C = ssd.cdist( mu0.c , mu1.c )**self.power if C is None else C

self.P,self.state = network_simplex( mu0.p , mu1.p , self.C )
self.C = C
if C is None:
self.C = ssd.cdist( mu0.c , mu1.c )
self.C = self.C**self.power

try:
_,out = ot.emd2( mu0.p , mu1.p , self.C , return_matrix = True , numItermax = self.numItermax )
self.P = out['G']
self.state = True
except Exception as e:
print(e)
self.state = False
##}}}

def plan(self):##{{{
Expand All @@ -90,8 +99,10 @@ def plan(self):##{{{
"""
return self.P
##}}}

##}}}


class OTSinkhorn:##{{{
"""
SBCK.tools.OTSinkhorn
Expand Down
4 changes: 2 additions & 2 deletions SBCK/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

## Copyright(c) 2021 Yoann Robin
## Copyright(c) 2021 / 2023 Yoann Robin
##
## This file is part of SBCK.
##
Expand All @@ -20,7 +20,7 @@

from .__tools_cpp import SparseHist
from .__bin_width_estimator import bin_width_estimator
from .__OT import OTNetworkSimplex
from .__OT import POTemd
from .__OT import OTSinkhorn
from .__OT import OTSinkhornLogDual
from .__shuffle import schaake_shuffle
Expand Down
5 changes: 1 addition & 4 deletions SBCK/tools/src/tools.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

// Copyright(c) 2021 Yoann Robin
// Copyright(c) 2021 / 2023 Yoann Robin
//
// This file is part of SBCK.
//
Expand All @@ -25,7 +25,6 @@
#include <pybind11/eigen.h>

#include "SparseHist.hpp"
#include "NetworkSimplex.hpp"

//============//
// namespaces //
Expand All @@ -44,8 +43,6 @@ PYBIND11_MODULE( __tools_cpp , m )
// Functions //
//===========//

m.def( "network_simplex" , &network_simplex ) ;

//=======//
// Class //
//=======//
Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ def build_extensions(self):
language='c++',
depends = [
"SBCK/tools/src/SparseHist.hpp"
"SBCK/tools/src/NetworkSimplex.hpp"
"SBCK/tools/src/NetworkSimplexLemon.hpp"
]
),
]
Expand Down Expand Up @@ -223,7 +221,7 @@ def build_extensions(self):
"Topic :: Scientific/Engineering :: Mathematics"
],
ext_modules = ext_modules,
install_requires = [ "numpy" , "scipy" , "matplotlib" , "pybind11>=2.2" ],
install_requires = [ "numpy" , "scipy" , "matplotlib" , "pybind11>=2.2" , "pot>=0.9.0"],
cmdclass = {'build_ext': BuildExt},
zip_safe = False,
packages = list_packages,
Expand Down

0 comments on commit c00593d

Please sign in to comment.