Skip to content

Commit

Permalink
Version 1.4.0, major improvement of CDFt
Browse files Browse the repository at this point in the history
  • Loading branch information
yrobink committed Oct 11, 2023
1 parent 1baaa8d commit 4f2f165
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 25 deletions.
85 changes: 62 additions & 23 deletions SBCK/__CDFt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

## Copyright(c) 2021 Yoann Robin
## Copyright(c) 2021 / 2023 Yoann Robin
##
## This file is part of SBCK.
##
Expand Down Expand Up @@ -97,6 +97,8 @@ def __init__( self , **kwargs ):##{{{
if X0 and Y0 are not None.
tol : float
Numerical tolerance, default 1e-6
version: int, optional
...
scale_left_tail: float, optional
Scale applied on the left support (min to median) between
calibration and projection period. If None (default), it is
Expand All @@ -115,9 +117,11 @@ def __init__( self , **kwargs ):##{{{
"""
self.n_features = kwargs.get("n_features")
self._tol = kwargs.get("tol") if kwargs.get("tol") is not None else 1e-6
self._dsupp = kwargs.get("dsupp") if kwargs.get("dsupp") is not None else 1000
self._samples_Y1 = kwargs.get("samples_Y1") if kwargs.get("samples_Y1") is not None else 10000
self._tol = kwargs.get( "tol" , 1e-6 )
self._dsupp = kwargs.get( "dsupp" , 1000 )
self._samples_Y1 = kwargs.get("samples_Y1" , 10000 )
self._version = kwargs.get("version" , 3 )
self._v3_e = kwargs.get("v3_e" , "auto" )

self._distY0 = _Dist( dist = kwargs.get("distY0") , kwargs = kwargs.get("kwargsY0") )
self._distY1 = _Dist( dist = kwargs.get("distY1") , kwargs = kwargs.get("kwargsY1") )
Expand Down Expand Up @@ -235,20 +239,19 @@ def predict( self , X1 , X0 = None ):##{{{
return Z1
##}}}

def _infer_Y1( self , Y0_ , X0_ , X1_ , idist ):##{{{

Y0 = Y0_
X0 = X0_
X1 = X1_
# Y0 = np.sort(Y0_.squeeze())[5:-5]
# X0 = np.sort(X0_.squeeze())[5:-5]
# X1 = np.sort(X1_.squeeze())[5:-5]
# qY0 = np.quantile( Y0_ , [0.05,0.95] ).squeeze()
# qX0 = np.quantile( X0_ , [0.05,0.95] ).squeeze()
# qX1 = np.quantile( X1_ , [0.05,0.95] ).squeeze()
# Y0 = Y0_[ (Y0_ > qY0[0]) & (Y0_ < qY0[1]) ]
# X0 = X0_[ (X0_ > qX0[0]) & (X0_ < qX0[1]) ]
# X1 = X1_[ (X1_ > qX1[0]) & (X1_ < qX1[1]) ]
def _CDFt_V1( self , Y0 , X0 , X1 , idist ):##{{{

## CDF
rvY0 = self._distY0.law[idist]
rvX0 = self._distX0.dist[idist]( *self._distX0.dist[idist].fit( X0.squeeze()) , **self._distX0.kwargs )
rvX1 = self._distX1.dist[idist]( *self._distX1.dist[idist].fit( X1.squeeze()) , **self._distX1.kwargs )

hY1 = rvX1.ppf( rvX0.cdf( rvY0.ppf( rvX1.cdf( X1 ) ) ) )

return hY1
##}}}

def _CDFt_V2( self , Y0 , X0 , X1 , idist ):##{{{

dsupp = self._dsupp

Expand Down Expand Up @@ -419,13 +422,49 @@ def support_test( rv , x ):
# print(cdfY1[-1])

## Draw hY1
hY1 = icdfY1( np.random.uniform( size = self._samples_Y1 , low = 0 , high = 1 ) )
# hY1 = icdfY1( np.random.uniform( size = self._samples_Y1 , low = p_min , high = p_max ) )
# rvX1 = self._distX1.dist[idist]( *self._distX1.dist[idist].fit( X1.squeeze()) , **self._distX1.kwargs )
# hY1 = icdfY1( rvX1.cdf(X1) )
rvX1 = self._distX1.dist[idist]( *self._distX1.dist[idist].fit( X1.squeeze()) , **self._distX1.kwargs )
hY1 = icdfY1( rvX1.cdf(X1) )

return hY1
##}}}

def _CDFt_V3( self , Y0 , X0 , X1 , idist ):##{{{

if (X0.min() <= Y0.min()) and (X0.max() >= Y0.max()):
return self._CDFt_V1( Y0 , X0 , X1 , idist )

if not type(self._v3_e) is float:
self._v3_e = 5 * max( 1 / Y0.size , 1 / X0.size , 1 / X1.size )
e = self._v3_e

lX0 = np.quantile( X0 , e )
lY0 = np.quantile( Y0 , e )
lX1 = np.quantile( X1 , e )

uX0 = np.quantile( X0 , 1 - e )
uY0 = np.quantile( Y0 , 1 - e )
uX1 = np.quantile( X1 , 1 - e )

X0s = ( X0 - lX0 ) / ( uX0 - lX0 ) * ( uY0 - lY0 ) + lY0
X1s = ( X1 - lX1 ) / ( uX0 - lX0 ) * ( uY0 - lY0 ) + lY0 + lX1 - lX0

rvY0 = self._distY0.law[idist]
rvX0s = self._distX0.dist[idist]( *self._distX0.dist[idist].fit( X0s.squeeze()) , **self._distX0.kwargs )
rvX1s = self._distX1.dist[idist]( *self._distX1.dist[idist].fit( X1s.squeeze()) , **self._distX1.kwargs )

hY1 = rvX1s.ppf( rvX0s.cdf( rvY0.ppf( rvX1s.cdf( X1s ) ) ) )

return hY1
##}}}


def _infer_Y1( self , Y0 , X0 , X1 , idist ):##{{{

if self._version == 1:
return self._CDFt_V1( Y0 , X0 , X1 , idist )
elif self._version == 2:
return self._CDFt_V2( Y0 , X0 , X1 , idist )
else:
return self._CDFt_V3( Y0 , X0 , X1 , idist )
##}}}


4 changes: 2 additions & 2 deletions SBCK/__release.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@


version_major = 1
version_minor = 3
version_patch = 2
version_minor = 4
version_patch = 0
version_extra = ""
version = "{}.{}.{}{}".format(version_major,version_minor,version_patch,version_extra)

Expand Down

0 comments on commit 4f2f165

Please sign in to comment.