Skip to content

Commit

Permalink
Update graphs.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ellariel authored Oct 17, 2024
1 parent 5d4073b commit c0ccd4a
Showing 1 changed file with 85 additions and 1 deletion.
86 changes: 85 additions & 1 deletion metatools/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pandas as pd
import os



def merge_2axes(fig1, fig2, file_name1="f1_.png", file_name2="f2_.png", horizontal=True, dpi=600):
fig1.savefig(file_name1, dpi=dpi, bbox_inches='tight', pad_inches=0.5)
fig2.savefig(file_name2, dpi=dpi, bbox_inches='tight', pad_inches=0.5)
Expand All @@ -35,6 +37,7 @@ def merge_2axes(fig1, fig2, file_name1="f1_.png", file_name2="f2_.png", horizont
os.remove(file_name2)
return fig


def radar_factory(num_vars, frame='circle'):
"""Create a radar chart with `num_vars` axes.
This function creates a RadarAxes projection and registers it.
Expand All @@ -47,7 +50,7 @@ def radar_factory(num_vars, frame='circle'):
"""
# calculate evenly-spaced axis angles
theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)

class RadarTransform(PolarAxes.PolarTransform):
def transform_path_non_affine(self, path):
# Paths with non-unit interpolation steps correspond to gridlines,
Expand Down Expand Up @@ -124,6 +127,7 @@ def _gen_axes_spines(self):
register_projection(RadarAxes)
return theta


def ols_tree_graph(r, title, use_rlm=False, forecolor='mediumorchid', backcolor='thistle'):
def addlabels(x, y, sig):
for i, j in enumerate(x):
Expand Down Expand Up @@ -172,3 +176,83 @@ def addlabels(x, y, sig):
fig.subplots_adjust(wspace=0, top=0.93)
#plt.show()
return fig


def lm_tree_graph(results, file_name='fig1.png', report='stars', #p ci
title=None, exclude=[], rename_dict={}, dpi=600, x_label='$β$',
forecolor='mediumorchid', backcolor='thistle', sort=True,
coef='coef', p_value='p-value', CIL='CIL', CIR='CIR', z='z', sig='sig'):

def addlabels(x, y, rep):
for i, j in enumerate(x):
if j < 0.0:
plt.text(j-0.04, y[i]-0.04, rep[i], ha='right')
else:
plt.text(j+0.04, y[i]-0.04, rep[i], ha='left')

model = results.model.iloc[0]
results = results.copy()
if sort:
results = results.sort_values(by=z, key=lambda x: abs(x))
for e in exclude:
if e in results.index:
results.drop(index=e, axis=0, inplace=True)
results.index = pd.Series(results.index).apply(lambda x: rename_dict[x] if x in rename_dict else x)

x1 = results[results[coef] < 0.0][coef].to_list()
sig1 = results[results[coef] < 0.0][sig].to_list()
x2 = results[results[coef] > 0.0][coef].to_list()
sig2 = results[results[coef] > 0.0][sig].to_list()
if report == 'p':
l1 = results[results[coef] < 0.0][p_value].to_list()
l1 = ['p='+i if i[0]=='.' else 'p'+i for i in l1]
l2 = results[results[coef] > 0.0][p_value].to_list()
l2 = ['p='+i if i[0]=='.' else 'p'+i for i in l2]
elif report == 'ci':
ll1 = results[results[coef] < 0.0][CIL].to_list()
rl1 = results[results[coef] < 0.0][CIR].to_list()
l1 = [f'[{i}, {j}]' for i, j in zip(ll1, rl1)]
ll2 = results[results[coef] > 0.0][CIL].to_list()
rl2 = results[results[coef] > 0.0][CIR].to_list()
l2 = [f'[{i}, {j}]' for i, j in zip(ll2, rl2)]
else:
l1 = sig1
l2 = sig2
y1 = range(len(x1))
y2 = range(len(x1), (len(x1) + len(x2)))
y = range(len(x1) + len(x2))
yl1 = results[results[coef] < 0.0].index.to_list()
yl2 = results[results[coef] > 0.0].index.to_list()
fig, axes = plt.subplots(ncols=2, sharey=True, figsize=(5, 3.5))
bar1 = axes[0].barh(y1, x1, align='center', color='red')
bar2 = axes[1].barh(y2, x2, align='center', color='blue')
axes[0].set(yticks=y, yticklabels=yl1 + yl2)
axes[0].set_xlim([-1, -0.001])
axes[1].set_xlim([0.001, 1])
for ax in axes.flat:
ax.margins(0.03)
ax.grid(True)
for i, b in enumerate(bar1):
if len(sig1[i]):
b.set_color(forecolor)
else:
b.set_color(backcolor)
for i, b in enumerate(bar2):
if len(sig2[i]):
b.set_color(forecolor)
else:
b.set_color(backcolor)
x_left, x_right = plt.xlim()
y_bottom, y_top = plt.ylim()
addlabels(x1, y1, l1)
addlabels(x2, y2, l2)
axes[0].tick_params(axis='y', which='both', labelleft=True, labelright=False)
axes[1].tick_params(axis='y', which='both', labelleft=False, labelright=False)
if title==None:
title=model
fig.suptitle(title, y=len(title.split('\n'))*.01+1.0, x=(x_right+x_left)/2+.009)
fig.text((x_right+x_left)/2+.009, 0, x_label, ha='center')
fig.subplots_adjust(wspace=0, top=0.93)
plt.savefig(file_name, dpi=dpi, bbox_inches='tight')
plt.close()
return fig

0 comments on commit c0ccd4a

Please sign in to comment.