Skip to content

Commit

Permalink
added jitter option for plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
okasag committed Aug 28, 2022
1 parent 6957045 commit 1f96fab
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 28 deletions.
34 changes: 22 additions & 12 deletions samplefit/Reliability.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,9 +795,11 @@ def plot(self,
figsize=None,
s=None,
ylim=None,
xlim=None,
xlabel=None,
dpi=None,
fname=None):
fname=None,
jitter=False):
"""
Plot the Reliability Scores based on the RSR algorithm.
Expand Down Expand Up @@ -827,23 +829,29 @@ def plot(self,
Default is automatic.
ylim : tuple, list or NoneType
Tuple of upper and lower limits of y axis. Default is automatic.
xlim : tuple, list or NoneType
Tuple of upper and lower limits of x axis. Default is automatic.
xlabel : str or NoneType
Label for the x axis for the exog variable. Default is 'xname'.
dpi : float, int or NoneType
The resolution for matplotlib scatter plot. Default is 100.
fname : str or NoneType
Valid figure name to save the plot. If None, generic name is used.
Default is None.
jitter : bool
Whether the scatterplot should be jittered for categorical
features. Note that this randomly perturbs the feature values
along the x axis, so fixing the random seed is necessary for
reproducibility. Default is False.
Returns
-------
Dictionary of matplotlib figures. Prints annealing plots.
Dictionary of matplotlib figures. Prints scoring plots.
Notes
-----
[`.plot()`](#samplefit.Reliability.RSRAnnealResults.plot) produces
an annealing plot for assessment of sample fit sensitivity, together
with parameters and confidence intervals.
[`.plot()`](#samplefit.Reliability.RSRScoreResults.plot) produces
a scoring plot for assessment of sample fit reliability.
Examples
--------
Expand All @@ -864,14 +872,14 @@ def plot(self,
# specify sample
sample = sf.RSR(linear_model=model)
# sample annealing
sample_annealing = sample.anneal()
# sample reliability
sample_scores = sample.score()
# default annealing plot
sample_annealing.plot()
# default scoring plot
sample_scores.plot()
# custom annealing
sample_annealing.plot(title='My Title')
# custom scoring
sample_scores.plot(title='My Title')
```
"""
return super().plot(
Expand All @@ -883,7 +891,9 @@ def plot(self,
figsize=figsize,
s=s,
ylim=ylim,
xlim=xlim,
xlabel=xlabel,
dpi=dpi,
fname=fname
fname=fname,
jitter=jitter
)
21 changes: 16 additions & 5 deletions samplefit/_BaseReliability.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,27 @@ def _input_checks(self):
# if float, then take it as a share
if isinstance(min_samples, float):
# check if it's within (0,1)
if (min_samples > 0 and min_samples <= 1):
# assign the input value
self.min_samples = int(np.ceil(min_samples * self.n_obs))
if (min_samples > 0 and min_samples < 1):
# get the integer value for sampling size
min_samples_int = int(np.ceil(min_samples * self.n_obs))
# check if it's within [p+1,N-1]
if ((min_samples_int >= self.n_exog + 1) and
(min_samples_int <= self.n_obs - 1)):
# assign the input value
self.min_samples = min_samples_int
else:
# raise value error
raise ValueError("min_samples must be within [p+1,N-1]"
", increase the min_samples share or "
"specify number of minimum samples "
"directly as an integer.")
else:
# raise value error
raise ValueError("min_samples must be within (0,1]"
raise ValueError("min_samples must be within (0,1)"
", got %s" % min_samples)
# if int, then take it as an absolute number
else:
# check if its within [p,N-1]
# check if its within [p+1,N-1]
if ((min_samples >= self.n_exog + 1) and
(min_samples <= self.n_obs - 1)):
# assign the input value
Expand Down
75 changes: 64 additions & 11 deletions samplefit/_BaseResultsReliability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,15 +1125,15 @@ def _output_checks(self):
# %% in-class functions
# function to plot reliability scores
def plot(self, yname=None, xname=None, title=None, cmap=None, path=None,
figsize=None, s=None, ylim=None, xlabel=None, dpi=None,
fname=None):
figsize=None, s=None, ylim=None, xlim=None, xlabel=None, dpi=None,
fname=None, jitter=False):
"""
RSR Reliability Scores Plot.
"""

# check inputs for plot
self._check_plot_inputs(yname, xname, title, cmap, path, figsize, s,
ylim, xlabel, dpi, fname)
ylim, xlim, xlabel, dpi, fname, jitter)

# set resolution
plt.rcParams['figure.dpi'] = self.dpi
Expand All @@ -1151,19 +1151,50 @@ def plot(self, yname=None, xname=None, title=None, cmap=None, path=None,
# define the plot layout
fig, ax = plt.subplots(nrows=1, ncols=1, figsize = self.figsize)

# plot scores
plot = ax.scatter(x=self.exog[:, var_idx],
y=self.endog,
c=self.scores,
cmap=self.cmap,
s=self.s)
# add titles, ticks, etc.
# check if X is categorical
if np.sum(self.exog[:, var_idx].astype(int) -
self.exog[:, var_idx]) == 0:
# get distinct values
cat_values = list(np.sort(np.unique(self.exog[:, var_idx])
).astype(int))
# check if jitter should be applied
if self.jitter:
# apply jitter random noise for visualisation purposes
exog_jitter = (self.exog[:, var_idx].copy() +
np.random.uniform(-0.1, 0.1,
len(self.endog)))
else:
# keep original values
exog_jitter = self.exog[:, var_idx].copy()
# scatter plot
plot = ax.scatter(x=exog_jitter,
y=self.endog,
c=self.scores,
cmap=self.cmap,
s=self.s)
# plot ticks
cat_values.insert(0, (np.min(cat_values) - 0.5))
cat_values.append((np.max(cat_values) + 0.5))
ticks = cat_values.copy()
# add ticks
plt.xticks(ticks, cat_values)
else:
# scatter plot
plot = ax.scatter(x=self.exog[:, var_idx],
y=self.endog,
c=self.scores,
cmap=self.cmap,
s=self.s)

# add titles, labels, etc.
ax.title.set_text(self.title)
ax.set_xlabel(self.xlabel[iter_idx])
ax.set_ylabel(self.yname)
# set limits if specified
if not self.ylim is None:
ax.set_ylim(self.ylim)
if not self.xlim is None:
ax.set_xlim(self.xlim)
# add legend
legend = ax.legend(*plot.legend_elements(),
title="Reliability Score",
Expand Down Expand Up @@ -1206,7 +1237,7 @@ def plot(self, yname=None, xname=None, title=None, cmap=None, path=None,

# check inputs for score plot
def _check_plot_inputs(self, yname, xname, title, cmap, path, figsize, s,
ylim, xlabel, dpi, fname):
ylim, xlim, xlabel, dpi, fname, jitter):
"""Input checks for the .plot() function."""

# check name for y
Expand Down Expand Up @@ -1342,6 +1373,19 @@ def _check_plot_inputs(self, yname, xname, title, cmap, path, figsize, s,
raise ValueError("ylim must be a tuple or a list"
", got %s" % type(ylim))

# check limits for x axis
if xlim is None:
# set default auto
self.xlim = xlim
# if supplied check if its valid
elif isinstance(xlim, (tuple, list)):
# set value to user supplied
self.xlim = xlim
else:
# raise value error
raise ValueError("xlim must be a tuple or a list"
", got %s" % type(xlim))

# check markersize s
if s is None:
# set default auto
Expand Down Expand Up @@ -1396,3 +1440,12 @@ def _check_plot_inputs(self, yname, xname, title, cmap, path, figsize, s,
# raise value error
raise ValueError("dpi must be float or int"
", got %s" % type(dpi))

# check jitter
if isinstance(jitter, bool):
# assign value
self.jitter = jitter
else:
# raise value error
raise ValueError("jitter must be boolean"
", got %s" % type(jitter))

0 comments on commit 1f96fab

Please sign in to comment.