Skip to content

Commit

Permalink
Merge pull request upb-lea#69 from upb-lea/issue_#55
Browse files Browse the repository at this point in the history
  • Loading branch information
wallscheid authored Jul 22, 2020
2 parents 690e64a + 530b33e commit 3fb1891
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 90 deletions.
2 changes: 1 addition & 1 deletion examples/ddpg_pmsm_dq_current_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def reset(self, **kwargs):
env = gem.make(
'emotor-pmsm-cont-v1',
# Pass a class with extra parameters
visualization=MotorDashboard(plots=['i_sq', 'i_sd', 'action_1', 'action_0']), visu_period=1,
visualization=MotorDashboard(plots=['i_sq', 'i_sd', 'action_0', 'action_1', 'mean_reward']), visu_period=1,
load=ConstantSpeedLoad(omega_fixed=1000 * np.pi / 30),
control_space='dq',
# Pass a string (with extra parameters)
Expand Down
2 changes: 1 addition & 1 deletion examples/pi_series_omega_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
'DcSeriesCont-v1', # replace with 'DcSeriesDisc-v1' for discrete controllers
# Pass an instance
#visualization=MotorDashboard(plotted_variables='all', visu_period=1),
visualization=MotorDashboard(plots=['omega','reward', 'i'], dark_mode=True),
visualization=MotorDashboard(plots=['i', 'reward', 'action_0', 'mean_reward'], dark_mode=False),
motor_parameter=dict(r_a=15e-3, r_e=15e-3, l_a=1e-3, l_e=1e-3),
# Take standard class and pass parameters (Load)
load_parameter=dict(a=0.01, b=.1, c=0.1, j_load=.06),
Expand Down
70 changes: 48 additions & 22 deletions gym_electric_motor/visualization/motor_dashboard.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from gym_electric_motor.core import ElectricMotorVisualization
from . import motor_dashboard_plots as mdp
import matplotlib
import matplotlib.pyplot as plt
import collections


class MotorDashboard(ElectricMotorVisualization):
"""Dashboard to plot the GEM states into graphs.
Expand All @@ -22,21 +21,26 @@ def __init__(self, plots, update_cycle=1000, dark_mode=False, **__):
- {state_name}: The corresponding state is plotted
- reward: The reward per step is plotted
- action_{i}: The i-th action is plotted. 'action_0' for discrete action space
- mean_reward: The mean episode reward
update_cycle(int): Number after how many steps the plot shall be updated. (default 1000)
dark_mode(Bool): Select a dark background for visualization by setting it to True
"""
plt.ion()
self._update_cycle = update_cycle
self._figure = None
self._figure_ep = None
self._plots = []
self._episode_plots = []

self._dark_mode = dark_mode
for plot in plots:
if type(plot) is str:
if plot == 'reward':
self._plots.append(mdp.RewardPlot())
elif plot.startswith('action_'):
self._plots.append(mdp.ActionPlot(plot))
elif plot.startswith('mean_reward'):
self._episode_plots.append(mdp.MeanEpisodeRewardPlot())
else:
self._plots.append(mdp.StatePlot(plot))
else:
Expand All @@ -46,8 +50,16 @@ def __init__(self, plots, update_cycle=1000, dark_mode=False, **__):
def reset(self, **__):
"""Called when the environment is reset. All subplots are reset.
"""
for plot in self._plots:

for plot in self._plots: # for plot in self._plots + self._episode_plots throws warning
plot.reset()
for plot in self._episode_plots:
plot.reset()

# since end of an episode can only be identified by a reset call. Episode based plot canvas updated here
if self._figure_ep:
self._figure_ep.canvas.draw()
self._figure_ep.canvas.flush_events()

def step(self, k, state, reference, action, reward, done):
""" Called within a render() call of an environment.
Expand All @@ -62,10 +74,13 @@ def step(self, k, state, reference, action, reward, done):
reward(ndarray(float)): Last received reward. (None after reset)
done(bool): Flag if the current state is terminal
"""
if not self._figure:
if not (self._figure or self._figure_ep):
self._initialize()
for plot in self._plots:
for plot in self._plots : # for plot in self._plots + self._episode_plots throws warning
plot.step(k, state, reference, action, reward, done)
for plot in self._episode_plots:
plot.step(k, state, reference, action, reward, done)

if (k + 1) % self._update_cycle == 0:
self._update()

Expand All @@ -78,37 +93,48 @@ def set_modules(self, ps, rg, rf):
rg(ReferenceGenerator): ReferenceGenerator of the environment
rf(RewardFunction): RewardFunction of the environment
"""
for plot in self._plots:
for plot in self._plots: # for plot in self._plots + self._episode_plots throws warning
plot.set_modules(ps, rg, rf)
for plot in self._episode_plots:
plot.set_modules(ps, rg, rf)

def _initialize(self):
"""Called with first render() call to setup the figures and plots.
"""
axis_ep = []
plt.close()
assert len(self._plots)>0, "no plot variable selected"
# For the dark background lovers
assert len(self._plots) > 0, "no plot variable selected"
# Use dark-mode, if selected
if self._dark_mode:
plt.style.use('dark_background')
# create seperate figures for time based and episode based plots
self._figure, axes = plt.subplots(len(self._plots), sharex=True)
self._figure.subplots_adjust(wspace=0.0, hspace=0.2)
#plt.style.use("dark_background")
plt.xlabel('t/s') # adding a common x-label to all the subplots
if self._episode_plots:
self._figure_ep, axes_ep = plt.subplots(len(self._episode_plots))
self._figure_ep.subplots_adjust(wspace=0.0, hspace=0.02)
self._figure.subplots_adjust(wspace=0.0, hspace=0.2)
self._figure_ep.text(0.5, 0.04, 'episode', va='center', ha='center')

# adding a common x-label to all the subplots in each figure
self._figure.text(0.5, 0.04, 't/s', va='center', ha='center')

#plt.subplot() does not return an iterable var when the number of subplots==1
if len(self._plots) < 2:
self._plots[0].initialize(axes)
plt.pause(0.1)
else:
# plt.subplot() does not return an iterable var when the number of subplots==1
if len(self._plots) == 1:
axes = [axes]
if len(self._episode_plots) == 1:
axis_ep = [axes_ep]
for plot, axis in zip(self._plots, axes):
plot.initialize(axis)

for plot, axis in zip(self._plots, axes):
plot.initialize(axis)
plt.pause(0.1)
for plot, axis in zip(self._episode_plots, axis_ep):
plot.initialize(axis)
plt.pause(0.1)

def _update(self):
"""Called every {update_cycle} steps to refresh the figure.
"""
for plot in self._plots:
plot.update()
if matplotlib.get_backend() == 'NbAgg':
self._figure.canvas.draw()

self._figure.canvas.draw()
self._figure.canvas.flush_events()
Loading

0 comments on commit 3fb1891

Please sign in to comment.