Skip to content

Commit

Permalink
MPLVisualization subclasses dict
Browse files Browse the repository at this point in the history
  • Loading branch information
charwick committed Jun 1, 2023
1 parent 5a3c870 commit 84f0924
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 125 deletions.
3 changes: 2 additions & 1 deletion helipad/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from io import BufferedWriter

#Using _ is a disaster. Can't install to global scope because it conflicts with readline;
#Can't name it _ here because `import *` skips it (???)
#Can't name it _ here because `import *` skips it
def ï(text) -> str:
"""Internationalization. Named so as to avoid a conflict with `_` in the REPL console."""
return helipad_gettext(text)
Expand All @@ -23,6 +23,7 @@ def isIpy() -> bool:
def isNotebook() -> bool:
"""Check whether Helipad is running in an interactive notebook."""
#get_ipython() comes back undefined inside callbacks. So cache the value once, the first time it runs.
#Can try @functools.cache when Python 3.9 is required
if not '__helipad_ipy' in globals():
try:
globals()['__helipad_ipy'] = 'InteractiveShell' in get_ipython().__class__.__name__
Expand Down
16 changes: 8 additions & 8 deletions helipad/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def reporter(model): return param.get(item)

if self.goods.money is not None:
self.data.addReporter('M0', self.data.agentReporter('stocks', 'all', good=self.goods.money, stat='sum'))
if self.visual is not None and isinstance(self.visual, TimeSeries) and 'money' in self.visual.plots:
self.visual.plots['money'].addSeries('M0', ï('Monetary Base'), self.goods[self.goods.money].color)
if self.visual is not None and isinstance(self.visual, TimeSeries) and 'money' in self.visual:
self.visual['money'].addSeries('M0', ï('Monetary Base'), self.goods[self.goods.money].color)

#Unconditional variables to report
# self.data.addReporter('utility', self.data.agentReporter('utils', defPrim))
Expand All @@ -186,13 +186,13 @@ def reporter(model): return param.get(item)
for breed, b in self.agents[defPrim].breeds.items():
self.data.addReporter('utility-'+breed, self.data.agentReporter('utils', defPrim, breed=breed))
if self.visual is not None and self.visual.__class__.__name__=='TimeSeries':
self.visual.plots['utility'].addSeries('utility-'+breed, breed.title()+' '+ï('Utility'), b.color)
self.visual['utility'].addSeries('utility-'+breed, breed.title()+' '+ï('Utility'), b.color)

if len(self.goods) >= 2:
for good, g in self.goods.nonmonetary.items():
self.data.addReporter('demand-'+good, self.data.agentReporter('currentDemand', 'all', good=good, stat='sum'))
if self.visual is not None:
if 'demand' in self.visual.plots: self.visual.plots['demand'].addSeries('demand-'+good, good.title()+' '+ï('Demand'), g.color)
if 'demand' in self.visual: self.visual['demand'].addSeries('demand-'+good, good.title()+' '+ï('Demand'), g.color)

#Initialize agents
for prim, ags in self.agents.items():
Expand Down Expand Up @@ -316,12 +316,12 @@ async def run(self):
if self.timer: t2 = time.time()
if self.cpanel and st and isinstance(st, int): self.cpanel.progress.update(t/st)

#Update graph
#Update visualizations
if self.visual is not None and not self.visual.isNull:
await asyncio.sleep(0.001) #Listen for keyboard input
data = self.data.getLast(t - self.visual.lastUpdate)

self.visual.update(data)
self.visual.refresh(data)
self.visual.lastUpdate = t

self.doHooks('visualRefresh', [self, self.visual])
Expand Down Expand Up @@ -649,7 +649,7 @@ def add(self, name: str, color, endowment=None, money: bool=False, props=None):
#Add the M0 plot once we have a money good, only if we haven't done it before
elif (self.model.visual is None or self.model.visual.isNull) and hasattr(self.model.visual, 'plots'):
try:
if not 'money' in self.model.visual.plots: self.model.visual.addPlot('money', ï('Money'), selected=False)
if not 'money' in self.model.visual: self.model.visual.addPlot('money', ï('Money'), selected=False)
except: pass #Can't add plot if re-drawing the cpanel

props['quantity'] = endowment
Expand All @@ -658,7 +658,7 @@ def add(self, name: str, color, endowment=None, money: bool=False, props=None):
#Add demand plot once we have at least 2 goods
if len(self) == 2 and (self.model.visual is None or self.model.visual.isNull) and hasattr(self.model.visual, 'plots'):
try:
if not 'demand' in self.model.visual.plots: self.model.visual.addPlot('demand', ï('Demand'), selected=False)
if not 'demand' in self.model.visual: self.model.visual.addPlot('demand', ï('Demand'), selected=False)
except: pass

return item
Expand Down
103 changes: 54 additions & 49 deletions helipad/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,27 @@ def launch(self, title: str):
"""Launch the visualization window (if in Tkinter) or cell (if in Jupyter). https://helipad.dev/functions/basevisualization/launch/"""

@abstractmethod
def update(self, data: dict):
"""Update the visualization with new data and refresh the display. `data` is only data since the last visualization refresh. https://helipad.dev/functions/basevisualization/update/"""
def refresh(self, data: dict):
"""Update the visualization with new data and refresh the display. `data` is only data since the last visualization refresh. https://helipad.dev/functions/basevisualization/refresh/"""

@abstractmethod
def event(self, t: int, color, **kwargs):
"""Called when an event is triggered, in order to be reflected in the visualization. https://helipad.dev/functions/basevisualization/update/"""
"""Called when an event is triggered, in order to be reflected in the visualization. https://helipad.dev/functions/basevisualization/event/"""

def terminate(self, model):
"""Cleanup on model termination. Called automatically from `model.terminate()`. https://helipad.dev/functions/basevisualization/terminate/"""

class MPLVisualization(BaseVisualization):
class MPLVisualization(BaseVisualization, dict):
"""Base class for visualizations using Matplotlib. https://helipad.dev/functions/mplvisualization/"""
keys = {}

def __init__(self, model):
self.model = model #Unhappy with this
self.plots = {}
self.selector = model.params.add('plots', ï('Plots'), 'checkgrid', [], opts={}, runtime=False, config=True)
self.dim = None
self.pos = (400, 0)
self.fig = None
self.lastUpdate = None
self.keyListeners = {}

def pause(model, event):
if model.hasModel and event.canvas is self.fig.canvas:
Expand All @@ -60,19 +59,32 @@ def pause(model, event):
get_ipython().magic('matplotlib widget')
else: matplotlib.use('TkAgg') #macosx would be preferable (Retina support), but it blocks the cpanel while running

def __repr__(self): return f'<{self.__class__.__name__} with {len(self.plots)} plots>'
def __repr__(self): return f'<{self.__class__.__name__} with {len(self)} plots>'

# Subclasses should call super().launch **after** the figure is created.
@abstractmethod
def launch(self, title: str, dim=None, pos=None):
if not isNotebook():
self.fig.canvas.manager.set_window_title(title)
if self.model.cpanel: self.model.cpanel.setAppIcon()

#MPL can't take a method bound to an unhashable class (i.e. `dict`)
def sendEvent(event):
axes = event.artist.axes if hasattr(event, 'artist') else event.inaxes
if axes is not None:
for p in self.activePlots.values():
if axes is p.axes:
p.MPLEvent(event)
break

if event.name=='key_press_event' and event.key in self.keyListeners:
for f in self.keyListeners[event.key]: f(self.model, event)

self.fig.tight_layout()
self.fig.canvas.mpl_connect('close_event', self.model.terminate)
self.fig.canvas.mpl_connect('key_press_event', self.sendEvent)
self.fig.canvas.mpl_connect('pick_event', self.sendEvent)
self.fig.canvas.mpl_connect('button_press_event', self.sendEvent)
self.fig.canvas.mpl_connect('key_press_event', sendEvent)
self.fig.canvas.mpl_connect('pick_event', sendEvent)
self.fig.canvas.mpl_connect('button_press_event', sendEvent)
self.lastUpdate = 0

#Resize and position graph window if applicable
Expand Down Expand Up @@ -100,32 +112,26 @@ def terminate(self, model):
pos = fm.window.wm_geometry().split('+')
self.pos = (pos[1], pos[2])

def sendEvent(self, event):
"""Execute functions registered with `MPLVisualization.addKeypress()` and route other events to the appropriate `ChartPlot` object depending on the current mouse position. https://helipad.dev/functions/mplvisualization/sendevent/"""
axes = event.artist.axes if hasattr(event, 'artist') else event.inaxes
if axes is not None:
for p in self.activePlots.values():
if axes is p.axes:
p.MPLEvent(event)
break

if event.name=='key_press_event' and event.key in self.keys:
for f in self.keys[event.key]: f(self.model, event)

def addKeypress(self, key: str, fn):
"""Register a function to be run when `key` is pressed in a Matplotlib visualizer. `fn` will run if `key` is pressed at any time when the plot window is in focus. To narrow the focus to a particular plot, define `catchKeypress()` in a subclass of `ChartPlot`. https://helipad.dev/functions/mplvisualization/addkeypress/"""
if not key in self.keys: self.keys[key] = []
self.keys[key].append(fn)
if not key in self.keyListeners: self.keyListeners[key] = []
self.keyListeners[key].append(fn)

@property
def activePlots(self) -> dict:
"""The subset of the plots containing the `Plot`s that are currently active. https://helipad.dev/functions/mplvisualization/#activePlots"""
return {k:plot for k,plot in self.plots.items() if plot.selected}
return {k:plot for k,plot in self.items() if plot.selected}

@property
def isNull(self) -> bool:
"""`True` when the model should run as-if with no visualization, for example if all plots are unselected. `False` indicates the window can be launched. https://helipad.dev/functions/mplvisualization/#isNull"""
return not [plot for plot in self.plots.values() if plot.selected]
return not [plot for plot in self.values() if plot.selected]

@property
def plots(self):
"""A `dict` of plots. This property is deprecated; the visualization object can be indexed directly."""
warnings.warn(ï('model.visual.plots is deprecated. Plots can be accessed by indexing the visualization object directly.'), FutureWarning, 2)
return self

class TimeSeries(MPLVisualization):
"""A Matplotlib-based visualizer for displaying time series data on plots so the whole history of any given variable over the model's runtime can be seen at once. https://helipad.dev/functions/timeseries/"""
Expand All @@ -151,23 +157,22 @@ def toggle(model, event):
#Delete the corresponding series when a reporter is removed
@model.hook('removeReporter')
def deleteSeries(data, key):
for p in self.plots.values():
for p in self.values():
for s in p.series:
if s.reporter==key:
# Remove subseries
for ss in s.subseries:
for sss in self.model.plots[s.plot].series:
for sss in self[s.plot].series:
if sss.reporter == ss:
self.plots[s.plot].series.remove(sss)
self[s.plot].series.remove(sss)
continue
self.plots[s.plot].series.remove(s)
self[s.plot].series.remove(s)

#Move the plots parameter to the end when the cpanel launches
@model.hook('CpanelPreLaunch')
def movePlotParam(model):
model.params['plots'] = model.params.pop('plots')

#listOfPlots is the trimmed model.plots list
def launch(self, title: str):
if not self.activePlots: return #Windowless mode

Expand Down Expand Up @@ -198,7 +203,7 @@ def terminate(self, model):
model.params['csv'].enable()
if not isinstance(model.param('stopafter'), str): model.params['stopafter'].enable()

def update(self, data: dict):
def refresh(self, data: dict):
newlen = len(next(data[x] for x in data))
if self.resolution > 1: data = {k: keepEvery(v, self.resolution) for k,v in data.items()}
time = newlen + len(next(iter(self.activePlots.values())).series[0].fdata)*self.resolution
Expand All @@ -221,15 +226,16 @@ def addPlot(self, name: str, label: str, position=None, selected: bool=True, log
plot = TimeSeriesPlot(viz=self, name=name, label=label, logscale=logscale, stack=stack)

self.selector.addItem(name, label, position, selected)
if position is None or position > len(self.plots): self.plots[name] = plot
if position is None or position > len(self): self[name] = plot
else: #Reconstruct the dicts because there's no insert method…
newplots, i = ({}, 1)
for k,v in self.plots.items():
for k,v in self.items():
if position==i:
newplots[name] = plot
newplots[k] = v
i+=1
self.plots = newplots
self.clear()
self.update(newplots)

plot.selected = selected #Do this after CheckgridParam.addItem
return plot
Expand All @@ -241,12 +247,12 @@ def removePlot(self, name, reassign=None):
for p in name: self.removePlot(p, reassign)
return

if not name in self.plots:
if not name in self:
warnings.warn(ï('No plot \'{}\' to remove.').format(name), None, 2)
return False

if reassign is not None: self.plots[reassign].series += self.plots[name].series
del self.plots[name]
if reassign is not None: self[reassign].series += self[name].series
del self[name]
del self.selector.opts[name]
del self.selector.vars[name]
if name in self.selector.default: self.selector.default.remove(name)
Expand All @@ -257,7 +263,7 @@ def event(self, t: int, color='#CC0000', linestyle: str='--', linewidth=1, **kwa
self.verticals.append([p.axes.axvline(x=t, color=color, linestyle=linestyle, linewidth=linewidth) for p in self.activePlots.values()])

# Problem: Need x to be in plot coordinates but y to be absolute w.r.t the figure
# next(iter(self.plots.values())).axes.text(t, 0, label, horizontalalignment='center')
# next(iter(self.values())).axes.text(t, 0, label, horizontalalignment='center')

class Charts(MPLVisualization):
"""A Matplotlib-based visualizer for a variety of visualizations that display data that reflects a single point in time. https://helipad.dev/functions/charts/"""
Expand All @@ -268,7 +274,6 @@ def __init__(self, model):

for p in [BarChart, AgentsPlot, TimeSeriesPlot]: self.addPlotType(p)
model.params['refresh'].runtime=False
self.refresh = model.params['refresh']
self.model = model # :(

def launch(self, title: str):
Expand All @@ -291,7 +296,7 @@ def launch(self, title: str):
super().launch(title)

#Time slider
ref = self.refresh.get()
ref = self.model.params['refresh'].get()
self.fig.subplots_adjust(bottom=0.12) #Make room for the slider
sax = self.fig.add_axes([0.1,0.01,.75,0.03], facecolor='#EEF')
self.timeslider = Slider(sax, 't=', 0, ref, ref, valstep=ref, closedmin=False)
Expand All @@ -300,7 +305,7 @@ def launch(self, title: str):
self.fig.canvas.draw_idle()
plt.show(block=False)

def update(self, data: dict):
def refresh(self, data: dict):
data = {k:v[-1] for k,v in data.items()}
t = self.model.t #cheating?
for c in self.activePlots.values(): c.update(data, t)
Expand Down Expand Up @@ -328,10 +333,10 @@ def addPlot(self, name: str, label: str, type=None, position=None, selected=True
type = 'agents'
self.type = type if type is not None else 'bar'
if self.type not in self.plotTypes: raise KeyError(ï('\'{}\' is not a registered plot visualizer.').format(self.type))
self.plots[name] = self.plotTypes[self.type](name=name, label=label, viz=self, selected=True, **kwargs)
self[name] = self.plotTypes[self.type](name=name, label=label, viz=self, selected=True, **kwargs)

self.plots[name].selected = selected #Do this after CheckgridParam.addItem
return self.plots[name]
self[name].selected = selected #Do this after CheckgridParam.addItem
return self[name]

def addPlotType(self, clss):
"""Registers a new plot type for the Charts visualizer. Registered plot types can then be added to the visualization area with `Charts.addPlot()`. https://helipad.dev/functions/charts/addplottype/"""
Expand All @@ -345,15 +350,15 @@ def removePlot(self, name):
for p in name: self.removePlot(p)
return

if not name in self.plots:
if not name in self:
warnings.warn(ï('No plot \'{}\' to remove.').format(name), None, 2)
return False

del self.plots[name]
del self[name]
return True

def event(self, t: int, color='#FDC', **kwargs):
ref = self.refresh.get()
ref = self.model.params['refresh'].get()
self.events[ceil(t/ref)*ref] = color

#======================
Expand Down Expand Up @@ -407,7 +412,7 @@ def MPLEvent(self, event):
"""Catch `pick_event`, `key_press_event`, and `button_press_event` events that occur inside a particular plot. https://helipad.dev/functions/chartplot/mplevent/"""

class TimeSeriesPlot(ChartPlot):
"""Visualizes time series data on one or more variables. Can be used in either the `TimeSeries` or `Charts` visualizer. https://helipad.dev/functions/timeseriesplot/"""
"""Visualize time series data on one or more variables. Can be used in either the `TimeSeries` or `Charts` visualizer. https://helipad.dev/functions/timeseriesplot/"""
type = 'timeseries'
def __init__(self, **kwargs):
self.series = []
Expand Down
18 changes: 9 additions & 9 deletions sample-models/Cities.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def tmp(model):
viz.addPlot('wealth', 'Wealth', 4)
viz.addPlot('rates', 'Rates', 5, logscale=True)
heli.data.addReporter('theta', heli.data.modelReporter('deathrate'), smooth=0.99)
viz.plots['rates'].addSeries('theta', 'Death Rate', '#CCCCCC')
viz['rates'].addSeries('theta', 'Death Rate', '#CCCCCC')

for breed, d in heli.agents['agent'].breeds.items():
heli.data.addReporter(breed+'Pop', heli.land[breed].pop)
Expand All @@ -218,21 +218,21 @@ def tmp(model):
heli.data.addReporter(breed+'Wealth', heli.data.agentReporter('wealth', 'agent', breed=breed, percentiles=[25,75]))
heli.data.addReporter(breed+'moveRate', heli.data.modelReporter('moverate'+breed), smooth=0.99)
heli.data.addReporter(breed+'birthrate', heli.data.modelReporter('birthrate'+breed), smooth=0.99)
viz.plots['pop'].addSeries(breed+'Pop', breed.title()+' Population', d.color)
viz.plots['hcap'].addSeries(breed+'H', breed.title()+' Human Capital', d.color)
viz.plots['wage'].addSeries(breed+'Wage', breed.title()+' Wage', d.color)
viz.plots['wage'].addSeries(breed+'ExpWage', breed.title()+' Expected Wage', d.color2, visible=False)
viz.plots['wealth'].addSeries(breed+'Wealth', breed.title()+' Wealth', d.color)
viz.plots['rates'].addSeries(breed+'moveRate', breed.title()+' Moveaway Rate', d.color2)
viz.plots['rates'].addSeries(breed+'birthrate', breed.title()+' Birthrate', d.color)
viz['pop'].addSeries(breed+'Pop', breed.title()+' Population', d.color)
viz['hcap'].addSeries(breed+'H', breed.title()+' Human Capital', d.color)
viz['wage'].addSeries(breed+'Wage', breed.title()+' Wage', d.color)
viz['wage'].addSeries(breed+'ExpWage', breed.title()+' Expected Wage', d.color2, visible=False)
viz['wealth'].addSeries(breed+'Wealth', breed.title()+' Wealth', d.color)
viz['rates'].addSeries(breed+'moveRate', breed.title()+' Moveaway Rate', d.color2)
viz['rates'].addSeries(breed+'birthrate', breed.title()+' Birthrate', d.color)

heli.launchCpanel()

#================
# PARAM SWEEP & ANALYSIS
#================

# for p in viz.plots.values(): p.active(False) #Disable plots so Helipad doesn't try to update the visuals during param sweep
# for p in viz.values(): p.active(False) #Disable plots so Helipad doesn't try to update the visuals during param sweep
# #Remove superfluous columns
# @heli.hook
# def saveCSV(data, model):
Expand Down
Loading

0 comments on commit 84f0924

Please sign in to comment.