Skip to content

Commit

Permalink
minor figure plotting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
TomGeorge1234 committed Aug 6, 2022
1 parent 068cefc commit c30618a
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def plot_rate_map(
fig=None,
ax=None,
shape=None,
colorbar=True,
t_start=0,
t_end=None,
**kwargs,
Expand All @@ -219,7 +220,7 @@ def plot_rate_map(
• fig, ax (the fig and ax to draw on top of, optional)
• shape is the shape of the multiplanlle figure, must be compatible with chosen neurons
• colorbar: whether to show a colorbar
• t_start, t_end: i nthe case where you are plotting spike, or using historical data to get rate map, this restricts the timerange of data you are using
• kwargs are sent to get_state and can be ignore if you don't need to use them
Expand Down Expand Up @@ -272,26 +273,33 @@ def plot_rate_map(
Nx, Ny = 1, len(chosen_neurons)
else:
Nx, Ny = shape[0], shape[1]
fig = plt.figure()
fig = plt.figure(figsize=(2 * Ny, 2 * Nx))
if colorbar == True:
cbar_mode = "single"
else:
cbar_mode = None
axes = ImageGrid(
fig,
# (0, 0, 3, 3),
111,
nrows_ncols=(Nx, Ny),
axes_pad=0.05,
cbar_location="right",
cbar_mode="single",
cbar_mode=cbar_mode,
cbar_size="5%",
cbar_pad=0.05,
)
cax = axes.cbar_axes[0]
if colorbar == True:
cax = axes.cbar_axes[0]
axes = np.array(axes)
else:
axes = np.array([ax]).reshape(-1)
if method in ["groundtruth", "history"]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
if colorbar == True:
from mpl_toolkits.axes_grid1 import make_axes_locatable

divider = make_axes_locatable(axes[-1])
cax = divider.append_axes("right", size="5%", pad=0.05)
divider = make_axes_locatable(axes[-1])
cax = divider.append_axes("right", size="5%", pad=0.05)
for (i, ax_) in enumerate(axes):
self.Agent.Environment.plot_environment(fig, ax_)
if len(chosen_neurons) != axes.size:
Expand Down Expand Up @@ -326,11 +334,12 @@ def plot_rate_map(
)
for im in ims:
im.set_clim((vmin, vmax))
cbar = plt.colorbar(ims[-1], cax=cax)
lim_v = vmax if vmax > -vmin else vmin
cbar.set_ticks([0, lim_v])
cbar.set_ticklabels([0, round(lim_v, 0)])
cbar.outline.set_visible(False)
if colorbar == True:
cbar = plt.colorbar(ims[-1], cax=cax)
lim_v = vmax if vmax > -vmin else vmin
cbar.set_ticks([0, lim_v])
cbar.set_ticklabels([0, round(lim_v, 1)])
cbar.outline.set_visible(False)

if spikes is True:
for (i, ax_) in enumerate(axes):
Expand Down

0 comments on commit c30618a

Please sign in to comment.