Skip to content

Commit

Permalink
Added a way to post-process subplots.
Browse files Browse the repository at this point in the history
  • Loading branch information
gbeced committed Jul 27, 2016
1 parent 2cc626b commit 5989252
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions pyalgotrade/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def includeDateTime(self, dateTime):
return filter(lambda x: dateTimeFilter.includeDateTime(x), dateTimes)


def _post_plot_fun(subPlot, mplSubplot):
# Legend
mplSubplot.legend(subPlot.getAllSeries().keys(), shadow=True, loc="best")
# Don't scale the Y axis
mplSubplot.yaxis.set_major_formatter(ticker.ScalarFormatter(useOffset=False))


class Series(object):
def __init__(self):
self.__values = {}
Expand Down Expand Up @@ -215,6 +222,9 @@ def __getColor(self, series):
def isEmpty(self):
return len(self.__series) == 0

def getAllSeries(self):
return self.__series

def addDataSeries(self, label, dataSeries, defaultClass=LineMarker):
"""Add a DataSeries to the subplot.
Expand Down Expand Up @@ -261,20 +271,14 @@ def getSeries(self, name, defaultClass=LineMarker):
def getCustomMarksSeries(self, name):
return self.getSeries(name, CustomMarker)

def customizeSubplot(self, mplSubplot):
# Don't scale the Y axis
mplSubplot.yaxis.set_major_formatter(ticker.ScalarFormatter(useOffset=False))

def plot(self, mplSubplot, dateTimes):
def plot(self, mplSubplot, dateTimes, postPlotFun=_post_plot_fun):
for series in self.__series.values():
color = None
if series.needColor():
color = self.__getColor(series)
series.plot(mplSubplot, dateTimes, color)

# Legend
mplSubplot.legend(self.__series.keys(), shadow=True, loc="best")
self.customizeSubplot(mplSubplot)
postPlotFun(self, mplSubplot)


class InstrumentSubplot(Subplot):
Expand Down Expand Up @@ -397,7 +401,7 @@ def getPortfolioSubplot(self):
"""
return self.__portfolioSubplot

def __buildFigureImpl(self, fromDateTime=None, toDateTime=None):
def __buildFigureImpl(self, fromDateTime=None, toDateTime=None, postPlotFun=_post_plot_fun):
dateTimes = _filter_datetimes(self.__dateTimes, fromDateTime, toDateTime)
dateTimes.sort()

Expand All @@ -414,7 +418,7 @@ def __buildFigureImpl(self, fromDateTime=None, toDateTime=None):
axesSubplot = axes[i][0]
if not subplot.isEmpty():
mplSubplots.append(axesSubplot)
subplot.plot(axesSubplot, dateTimes)
subplot.plot(axesSubplot, dateTimes, postPlotFun=postPlotFun)
axesSubplot.grid(True)

return (fig, mplSubplots)
Expand All @@ -426,7 +430,7 @@ def buildFigure(self, fromDateTime=None, toDateTime=None):
fig, _ = self.buildFigureAndSubplots(fromDateTime, toDateTime)
return fig

def buildFigureAndSubplots(self, fromDateTime=None, toDateTime=None):
def buildFigureAndSubplots(self, fromDateTime=None, toDateTime=None, postPlotFun=_post_plot_fun):
"""Builds a matplotlib.figure.Figure with the subplots. Must be called after running the strategy.
:param fromDateTime: An optional starting datetime.datetime. Everything before it won't get plotted.
Expand All @@ -435,11 +439,11 @@ def buildFigureAndSubplots(self, fromDateTime=None, toDateTime=None):
:type toDateTime: datetime.datetime
:rtype: A 2 element tuple with matplotlib.figure.Figure and subplots.
"""
fig, mplSubplots = self.__buildFigureImpl(fromDateTime, toDateTime)
fig, mplSubplots = self.__buildFigureImpl(fromDateTime, toDateTime, postPlotFun=postPlotFun)
fig.autofmt_xdate()
return fig, mplSubplots

def plot(self, fromDateTime=None, toDateTime=None):
def plot(self, fromDateTime=None, toDateTime=None, postPlotFun=_post_plot_fun):
"""Plots the strategy execution. Must be called after running the strategy.
:param fromDateTime: An optional starting datetime.datetime. Everything before it won't get plotted.
Expand All @@ -448,6 +452,6 @@ def plot(self, fromDateTime=None, toDateTime=None):
:type toDateTime: datetime.datetime
"""

fig, mplSubplots = self.__buildFigureImpl(fromDateTime, toDateTime)
fig, mplSubplots = self.__buildFigureImpl(fromDateTime, toDateTime, postPlotFun=postPlotFun)
fig.autofmt_xdate()
plt.show()

0 comments on commit 5989252

Please sign in to comment.