Skip to content

Commit

Permalink
Patch CDFt to add an option to control the support
Browse files Browse the repository at this point in the history
  • Loading branch information
yrobink committed Dec 8, 2022
1 parent b681a12 commit e297f9e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
34 changes: 25 additions & 9 deletions SBCK/__CDFt.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def __init__( self , **kwargs ):##{{{
self._normalize_cdf = kwargs.get("normalize_cdf")
if ~(type(self._normalize_cdf) in [bool,list]):
self._normalize_cdf = True
self._p_left = 0
self._p_right = 1
##}}}

def fit( self , Y0 , X0 , X1 ):##{{{
Expand Down Expand Up @@ -233,7 +235,20 @@ def predict( self , X1 , X0 = None ):##{{{
return Z1
##}}}

def _infer_Y1( self , Y0 , X0 , X1 , idist ):##{{{
def _infer_Y1( self , Y0_ , X0_ , X1_ , idist ):##{{{

Y0 = Y0_
X0 = X0_
X1 = X1_
# Y0 = np.sort(Y0_.squeeze())[5:-5]
# X0 = np.sort(X0_.squeeze())[5:-5]
# X1 = np.sort(X1_.squeeze())[5:-5]
# qY0 = np.quantile( Y0_ , [0.05,0.95] ).squeeze()
# qX0 = np.quantile( X0_ , [0.05,0.95] ).squeeze()
# qX1 = np.quantile( X1_ , [0.05,0.95] ).squeeze()
# Y0 = Y0_[ (Y0_ > qY0[0]) & (Y0_ < qY0[1]) ]
# X0 = X0_[ (X0_ > qX0[0]) & (X0_ < qX0[1]) ]
# X1 = X1_[ (X1_ > qX1[0]) & (X1_ < qX1[1]) ]

dsupp = self._dsupp

Expand Down Expand Up @@ -351,16 +366,16 @@ def support_test( rv , x ):
icdfY1 = sci.interp1d( cdfY1 , x , fill_value = (x[0],x[-1]) , bounds_error = False )

## Now find cut
lsuppl_Y0 = np.median(Y0) - np.min(Y0)
lsuppl_X0 = np.median(X0) - np.min(X0)
lsuppl_X1 = np.median(X1) - np.min(X1)
lsuppl_Y0 = np.median(Y0) - np.quantile(Y0,self._p_left)
lsuppl_X0 = np.median(X0) - np.quantile(X0,self._p_left)
lsuppl_X1 = np.median(X1) - np.quantile(X1,self._p_left)
lsuppl_Y1 = lsuppl_Y0 * lsuppl_X1 / lsuppl_X0
lsuppl_pY1 = icdfY1(0.5) - icdfY1(0)
lsuppr_Y0 = np.max(Y0) - np.median(Y0)
lsuppr_X0 = np.max(X0) - np.median(X0)
lsuppr_X1 = np.max(X1) - np.median(X1)
lsuppl_pY1 = icdfY1(0.5) - icdfY1(self._p_left)
lsuppr_Y0 = np.quantile(Y0,self._p_right) - np.median(Y0)
lsuppr_X0 = np.quantile(X0,self._p_right) - np.median(X0)
lsuppr_X1 = np.quantile(X1,self._p_right) - np.median(X1)
lsuppr_Y1 = lsuppr_Y0 * lsuppr_X1 / lsuppr_X0
lsuppr_pY1 = icdfY1(1) - icdfY1(0.5)
lsuppr_pY1 = icdfY1(self._p_right) - icdfY1(0.5)

if lsuppl_pY1 > lsuppl_Y1 or lsuppr_pY1 > lsuppr_Y1:

Expand Down Expand Up @@ -405,6 +420,7 @@ def support_test( rv , x ):
# hY1 = icdfY1( np.random.uniform( size = self._samples_Y1 , low = p_min , high = p_max ) )
rvX1 = self._distX1.dist[idist]( *self._distX1.dist[idist].fit( X1.squeeze()) , **self._distX1.kwargs )
hY1 = icdfY1( rvX1.cdf(X1) )
# hY1 = icdfY1( rvX1.cdf(X1_) )

return hY1
##}}}
Expand Down
2 changes: 1 addition & 1 deletion SBCK/__release.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
version_major = 0
version_minor = 5
version_patch = 0
version_extra = "a24"
version_extra = "a25"
version = "{}.{}.{}{}".format(version_major,version_minor,version_patch,version_extra)

name = "SBCK"
Expand Down

0 comments on commit e297f9e

Please sign in to comment.