Skip to content

Commit

Permalink
Make sure that new-style viewers round-trip when saved to session files
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Apr 19, 2017
1 parent 6b0a5c8 commit 9ae7228
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 74 deletions.
11 changes: 7 additions & 4 deletions glue/viewers/common/mpl_layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,19 @@ def redraw(self):

@property
def zorder(self):
return self.layer_state.zorder
return self.state.zorder

@zorder.setter
def zorder(self, value):
self.layer_state.zorder = value
self.state.zorder = value

@property
def visible(self):
return self.layer_state.visible
return self.state.visible

@visible.setter
def visible(self, value):
self.layer_state.visible = value
self.state.visible = value

def __gluestate__(self, context):
return dict(state=context.id(self.state))
55 changes: 24 additions & 31 deletions glue/viewers/common/qt/mpl_data_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,24 @@ def __init__(self, session, parent=None):

# Set up the state which will contain everything needed to represent
# the current state of the viewer
self.viewer_state = self._state_cls()
self.viewer_state.data_collection = session.data_collection
self.state = self._state_cls()
self.state.data_collection = session.data_collection

# Set up the options widget, which will include options that control the
# viewer state
self.options = self._options_cls(viewer_state=self.viewer_state,
self.options = self._options_cls(viewer_state=self.state,
session=session)

add_callback(self.viewer_state, 'x_min', nonpartial(self.limits_to_mpl))
add_callback(self.viewer_state, 'x_max', nonpartial(self.limits_to_mpl))
add_callback(self.viewer_state, 'y_min', nonpartial(self.limits_to_mpl))
add_callback(self.viewer_state, 'y_max', nonpartial(self.limits_to_mpl))
add_callback(self.state, 'x_min', nonpartial(self.limits_to_mpl))
add_callback(self.state, 'x_max', nonpartial(self.limits_to_mpl))
add_callback(self.state, 'y_min', nonpartial(self.limits_to_mpl))
add_callback(self.state, 'y_max', nonpartial(self.limits_to_mpl))

self.axes.callbacks.connect('xlim_changed', nonpartial(self.limits_from_mpl))
self.axes.callbacks.connect('ylim_changed', nonpartial(self.limits_from_mpl))

self.viewer_state.add_callback('log_x', nonpartial(self.update_log_x))
self.viewer_state.add_callback('log_y', nonpartial(self.update_log_y))
self.state.add_callback('log_x', nonpartial(self.update_log_x))
self.state.add_callback('log_y', nonpartial(self.update_log_y))

self.axes.set_autoscale_on(False)

Expand All @@ -65,37 +65,37 @@ def __init__(self, session, parent=None):

# And vice-versa when layer states are removed from the viewer state, we
# need to keep the layer_artist_container in sync
self.viewer_state.add_callback('layers', nonpartial(self._sync_layer_artist_container))
self.state.add_callback('layers', nonpartial(self._sync_layer_artist_container))

def _sync_state_layers(self):
# Remove layer state objects that no longer have a matching layer
for layer_state in self.viewer_state.layers:
for layer_state in self.state.layers:
if layer_state.layer not in self._layer_artist_container:
self.viewer_state.layers.remove(layer_state)
self.state.layers.remove(layer_state)

def _sync_layer_artist_container(self):
# Remove layer artists that no longer have a matching layer state
layer_states = set(layer_state.layer for layer_state in self.viewer_state.layers)
layer_states = set(layer_state.layer for layer_state in self.state.layers)
for layer_artist in self._layer_artist_container:
if layer_artist.layer not in layer_states:
self._layer_artist_container.remove(layer_artist)

def update_log_x(self):
self.axes.set_xscale('log' if self.viewer_state.log_x else 'linear')
self.axes.set_xscale('log' if self.state.log_x else 'linear')

def update_log_y(self):
self.axes.set_yscale('log' if self.viewer_state.log_y else 'linear')
self.axes.set_yscale('log' if self.state.log_y else 'linear')

@avoid_circular
def limits_from_mpl(self):
# TODO: delay callbacks here
self.viewer_state.x_min, self.viewer_state.x_max = self.axes.get_xlim()
self.viewer_state.y_min, self.viewer_state.y_max = self.axes.get_ylim()
self.state.x_min, self.state.x_max = self.axes.get_xlim()
self.state.y_min, self.state.y_max = self.axes.get_ylim()

@avoid_circular
def limits_to_mpl(self):
self.axes.set_xlim(self.viewer_state.x_min, self.viewer_state.x_max)
self.axes.set_ylim(self.viewer_state.y_min, self.viewer_state.y_max)
self.axes.set_xlim(self.state.x_min, self.state.x_max)
self.axes.set_ylim(self.state.y_min, self.state.y_max)
self.axes.figure.canvas.draw()

# TODO: shouldn't need this!
Expand All @@ -113,7 +113,7 @@ def add_data(self, data):
raise IncompatibleDataException("Data not in DataCollection")

# Create layer artist and add to container
layer = self._data_artist_cls(data, self._axes, self.viewer_state)
layer = self._data_artist_cls(self._axes, self.state, layer=data)
self._layer_artist_container.append(layer)
layer.update()

Expand All @@ -126,13 +126,13 @@ def add_data(self, data):
@defer_draw
def remove_data(self, data):

for layer_artist in self.viewer_state.layers[::-1]:
for layer_artist in self.state.layers[::-1]:
if isinstance(layer_artist.layer, Data):
if layer_artist.layer is data:
self.viewer_state.layers.remove(layer_artist)
self.state.layers.remove(layer_artist)
else:
if layer_artist.layer.data is data:
self.viewer_state.layers.remove(layer_artist)
self.state.layers.remove(layer_artist)

@defer_draw
def add_subset(self, subset):
Expand All @@ -143,15 +143,8 @@ def add_subset(self, subset):
self.add_data(subset.data)
return

# Copy settings from data if present
if subset.data in self._layer_artist_container:
initial_layer_state = self._layer_artist_container[subset.data][0].layer_state
else:
initial_layer_state = None

# Create scatter layer artist and add to container
layer = self._subset_artist_cls(subset, self._axes, self.viewer_state,
initial_layer_state=initial_layer_state)
layer = self._subset_artist_cls(self._axes, self.state, layer=subset)
self._layer_artist_container.append(layer)
layer.update()

Expand Down
25 changes: 25 additions & 0 deletions glue/viewers/common/qt/tests/test_mpl_data_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from glue.core import Data
from glue.core.tests.util import simple_session
from glue.core.exceptions import IncompatibleDataException
from glue.app.qt.application import GlueApplication


class MatplotlibDrawCounter(object):
Expand Down Expand Up @@ -414,3 +415,27 @@ def test_subset_remove_message(self):
assert sub in self.viewer._layer_artist_container
sub.delete()
assert sub not in self.viewer._layer_artist_container

def test_session_round_trip(self, tmpdir):

self.init_subset()

ga = GlueApplication(self.data_collection)
ga.show()

viewer = ga.new_data_viewer(self.viewer_cls)
viewer.add_data(self.data)

session_file = tmpdir.join('test_session_round_trip.glu').strpath
ga.save_session(session_file)
ga.close()

ga2 = GlueApplication.restore_session(session_file)
ga2.show()

viewer2 = ga2.viewers[0][0]

data2 = ga2.data_collection[0]

assert viewer2.layers[0].layer is data2
assert viewer2.layers[1].layer is data2.subsets[0]
42 changes: 18 additions & 24 deletions glue/viewers/histogram_new/layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,23 @@

class HistogramLayerArtist(MatplotlibLayerArtist):

def __init__(self, layer, axes, viewer_state, initial_layer_state=None):
def __init__(self, axes, viewer_state, layer_state=None, layer=None):

super(HistogramLayerArtist, self).__init__(layer, axes, viewer_state)

# Set up a state object for the layer artist
if initial_layer_state is None:
initial = {}
else:
initial = initial_layer_state.as_dict()
if 'layer' in initial:
initial.pop('layer')
self.layer = layer or layer_state.layer

# Set up a state object for the layer artist
self.layer_state = HistogramLayerState(viewer_state=viewer_state, layer=layer, **initial)
self.viewer_state.layers.append(self.layer_state)
self.state = layer_state or HistogramLayerState(viewer_state=viewer_state, layer=self.layer)
self.viewer_state.layers.append(self.state)

# Watch for changes in the viewer state which would require the
# layers to be redrawn
self.viewer_state.add_callback('*', self._update_histogram, as_kwargs=True)
self.layer_state.add_callback('*', self._update_histogram, as_kwargs=True)
self.state.add_callback('*', self._update_histogram, as_kwargs=True)

# TODO: following is temporary
self.layer_state.data_collection = self.viewer_state.data_collection
self.state.data_collection = self.viewer_state.data_collection
self.data_collection = self.viewer_state.data_collection

self.reset_cache()
Expand Down Expand Up @@ -101,18 +95,18 @@ def _scale_histogram(self):
#
# because this would never allow y_max to get smaller.

self.layer_state._y_max = self.mpl_hist.max()
self.state._y_max = self.mpl_hist.max()

if self.viewer_state.log_y:
self.layer_state._y_max *= 2
self.state._y_max *= 2
else:
self.layer_state._y_max *= 1.2
self.state._y_max *= 1.2

for layer in self.viewer_state.layers:
if self.layer_state != layer and hasattr(layer, '_y_max') and self.layer_state._y_max < layer._y_max:
if self.state != layer and hasattr(layer, '_y_max') and self.state._y_max < layer._y_max:
break
else:
self.viewer_state.y_max = self.layer_state._y_max
self.viewer_state.y_max = self.state._y_max

if self.viewer_state.log_y:
self.viewer_state.y_min = self.mpl_hist[self.mpl_hist > 0].min() / 10
Expand All @@ -125,11 +119,11 @@ def _scale_histogram(self):
def _update_visual_attributes(self):

for mpl_artist in self.mpl_artists:
mpl_artist.set_visible(self.layer_state.visible)
mpl_artist.set_zorder(self.layer_state.zorder)
mpl_artist.set_visible(self.state.visible)
mpl_artist.set_zorder(self.state.zorder)
mpl_artist.set_edgecolor('none')
mpl_artist.set_facecolor(self.layer_state.color)
mpl_artist.set_alpha(self.layer_state.alpha)
mpl_artist.set_facecolor(self.state.color)
mpl_artist.set_alpha(self.state.alpha)

self.redraw()

Expand All @@ -139,7 +133,7 @@ def _update_histogram(self, force=False, **kwargs):
self.viewer_state.hist_x_max is None or
self.viewer_state.hist_n_bin is None or
self.viewer_state.xatt is None or
self.layer_state.layer is None):
self.state.layer is None):
return

# Figure out which attributes are different from before. Ideally we shouldn't
Expand All @@ -157,12 +151,12 @@ def _update_histogram(self, force=False, **kwargs):
if value != self._last_viewer_state.get(key, None):
changed.add(key)

for key, value in self.layer_state.as_dict().items():
for key, value in self.state.as_dict().items():
if value != self._last_layer_state.get(key, None):
changed.add(key)

self._last_viewer_state.update(self.viewer_state.as_dict())
self._last_layer_state.update(self.layer_state.as_dict())
self._last_layer_state.update(self.state.as_dict())

if force or any(prop in changed for prop in ('layer', 'xatt', 'hist_x_min', 'hist_x_max', 'hist_n_bin', 'log_x')):
self._calculate_histogram()
Expand Down
61 changes: 49 additions & 12 deletions glue/viewers/histogram_new/qt/data_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from glue.viewers.histogram_new.layer_artist import HistogramLayerArtist
from glue.viewers.histogram_new.qt.options_widget import HistogramOptionsWidget
from glue.viewers.histogram_new.state import HistogramViewerState
from glue.viewers.histogram_new.compat import update_viewer_state

from glue.core.state import lookup_class_with_patches

__all__ = ['HistogramViewer']

Expand All @@ -31,24 +34,24 @@ class HistogramViewer(MatplotlibDataViewer):

def __init__(self, session, parent=None):
super(HistogramViewer, self).__init__(session, parent)
self.viewer_state.add_callback('xatt', nonpartial(self._update_axes))
self.viewer_state.add_callback('log_x', nonpartial(self._update_axes))
self.viewer_state.add_callback('normalize', nonpartial(self._update_axes))
self.state.add_callback('xatt', nonpartial(self._update_axes))
self.state.add_callback('log_x', nonpartial(self._update_axes))
self.state.add_callback('normalize', nonpartial(self._update_axes))

@defer_draw
def _update_axes(self):

if self.viewer_state.xatt is not None:
if self.state.xatt is not None:

# Update ticks, which sets the labels to categories if components are categorical
update_ticks(self.axes, 'x', self.viewer_state._get_x_components(), False)
update_ticks(self.axes, 'x', self.state._get_x_components(), False)

if self.viewer_state.log_x:
self.axes.set_xlabel('Log ' + self.viewer_state.xatt.label)
if self.state.log_x:
self.axes.set_xlabel('Log ' + self.state.xatt.label)
else:
self.axes.set_xlabel(self.viewer_state.xatt.label)
self.axes.set_xlabel(self.state.xatt.label)

if self.viewer_state.normalize:
if self.state.normalize:
self.axes.set_ylabel('Normalized number')
else:
self.axes.set_ylabel('Number')
Expand All @@ -68,7 +71,7 @@ def apply_roi(self, roi):
# Expand roi to match bin edges
# TODO: make this an option

bins = self.viewer_state.bins
bins = self.state.bins

x = roi.to_polygon()[0]
lo, hi = min(x), max(x)
Expand All @@ -85,10 +88,44 @@ def apply_roi(self, roi):
if not isinstance(layer_artist.layer, Data):
continue

x_comp = layer_artist.layer.get_component(self.viewer_state.xatt)
x_comp = layer_artist.layer.get_component(self.state.xatt)

subset_state = x_comp.subset_from_roi(self.viewer_state.xatt, roi_new,
subset_state = x_comp.subset_from_roi(self.state.xatt, roi_new,
coord='x')

mode = EditSubsetMode()
mode.update(self._data, subset_state, focus_data=layer_artist.layer)

def __gluestate__(self, context):
return dict(state=self.state.__gluestate__(context),
session=context.id(self._session),
size=self.viewer_size,
pos=self.position,
layers=list(map(context.do, self.layers)),
_protocol=1)

@classmethod
def __setgluestate__(cls, rec, context):

if rec.get('_protocol', 0) < 1:
update_viewer_state(rec, context)

session = context.object(rec['session'])
viewer = cls(session)
viewer.register_to_hub(session.hub)
viewer.viewer_size = rec['size']
x, y = rec['pos']
viewer.move(x=x, y=y)

viewer_state = HistogramViewerState.__setgluestate__(rec['state'], context)
viewer.state.update_from_state(viewer_state)

# Restore layer artists
for l in rec['layers']:
cls = lookup_class_with_patches(l.pop('_type'))
layer_state = context.object(l['state'])
print(type(layer_state))
layer_artist = cls(axes=viewer.axes, viewer_state=viewer.state, layer_state=layer_state)
viewer._layer_artist_container.append(layer_artist)

return viewer
2 changes: 1 addition & 1 deletion glue/viewers/histogram_new/qt/layer_style_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, layer, parent=None):
directory=os.path.dirname(__file__))

# TODO: In future, should pass only state not layer?
self.layer_state = layer.layer_state
self.layer_state = layer.state

connect_kwargs = {'alpha': dict(value_range=(0, 1))}

Expand Down
Loading

0 comments on commit 9ae7228

Please sign in to comment.