From a0c8a020037e1929d2c3452691e91f8f096ae0de Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 29 Mar 2017 16:59:16 +0800 Subject: [PATCH] add DISABLE_PLOT for test --- python/paddle/v2/{ => plot}/plot_curve.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) rename python/paddle/v2/{ => plot}/plot_curve.py (60%) diff --git a/python/paddle/v2/plot_curve.py b/python/paddle/v2/plot/plot_curve.py similarity index 60% rename from python/paddle/v2/plot_curve.py rename to python/paddle/v2/plot/plot_curve.py index 178506bbebdf3..0f62674cb2baa 100644 --- a/python/paddle/v2/plot_curve.py +++ b/python/paddle/v2/plot/plot_curve.py @@ -1,5 +1,5 @@ -import matplotlib.pyplot as plt from IPython import display +import os class PlotCost(object): @@ -11,18 +11,29 @@ def __init__(self): self.train_costs = ([], []) self.test_costs = ([], []) + self.__disable_plot__ = os.environ.get("DISABLE_PLOT") + if not self.__plot_is_disabled__(): + import matplotlib.pyplot as plt + self.plt = plt + + def __plot_is_disabled__(self): + return self.__disable_plot__ == "True" + def plot(self): - plt.plot(*self.train_costs) - plt.plot(*self.test_costs) + if self.__plot_is_disabled__(): + return + + self.plt.plot(*self.train_costs) + self.plt.plot(*self.test_costs) title = [] if len(self.train_costs[0]) > 0: title.append('Train Cost') if len(self.test_costs[0]) > 0: title.append('Test Cost') - plt.legend(title, loc='upper left') + self.plt.legend(title, loc='upper left') display.clear_output(wait=True) - display.display(plt.gcf()) - plt.gcf().clear() + display.display(self.plt.gcf()) + self.plt.gcf().clear() def append_train_cost(self, step, cost): self.train_costs[0].append(step)