Skip to content

Commit

Permalink
lib.gui.stats - Read loss names from model config output rather than …
Browse files Browse the repository at this point in the history
…state file
  • Loading branch information
torzdf committed May 23, 2021
1 parent 3d914ee commit 6ee896d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 14 deletions.
56 changes: 45 additions & 11 deletions lib/gui/analysis/event_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,30 +589,64 @@ def _parse_outputs(self, event):
Loss names are added to :attr:`_loss_labels`
Notes
-----
The master model does not actually contain the specified output name, so we dig into the
sub-model to obtain the name of the output layers
Parameters
----------
event: :class:`tensorflow.core.util.event_pb2`
The event data containing the keras model structure to be parsed
"""
serializer = get_serializer("json")
struct = event.summary.value[0].tensor.string_val[0]
outputs = np.array(serializer.unmarshal(struct)["config"]["output_layers"])
logger.debug("Obtained model outputs: %s, shape: %s", outputs, outputs.shape)
if outputs.ndim == 2: # Insert extra dimension for non learn mask models
outputs = np.expand_dims(outputs, axis=1)
logger.debug("Expanded dimensions for non-learn_mask model. outputs: %s, shape: %s",
outputs, outputs.shape)
for side_outputs, side in zip(outputs, ("a", "b")):

config = serializer.unmarshal(struct)["config"]
model_outputs = self._get_outputs(config)
split_output = len(np.unique(model_outputs[..., 1])) == 1

for side_outputs, side in zip(model_outputs, ("a", "b")):
logger.debug("side: '%s', outputs: '%s'", side, side_outputs)
for idx in range(len(side_outputs)):
# First output is always face. Subsequent outputs are masks
loss_name = f"face_{side}" if idx == 0 else f"mask_{side}"
loss_name = loss_name if idx < 2 else f"{loss_name}_{idx}"
layer_name = side_outputs[0][0]

output_config = next(layer for layer in config["layers"]
if layer["name"] == layer_name)["config"]
layer_outputs = self._get_outputs(output_config)
for output in layer_outputs: # Drill into sub-model to get the actual output names
loss_name = output[0][0]
if not split_output: # Rename losses to reflect the side's output
loss_name = f"{loss_name.replace('_both', '')}_{side}"
if loss_name not in self._loss_labels:
logger.debug("Adding loss name: '%s'", loss_name)
self._loss_labels.append(loss_name)
logger.debug("Collated loss labels: %s", self._loss_labels)

@classmethod
def _get_outputs(cls, model_config):
""" Obtain the output names, instance index and output index for the given model.
If there is only a single output, the shape of the array is expanded to remain consistent
with multi model outputs
Parameters
----------
model_config: dict
The saved Keras model configuration dictionary
Returns
-------
:class:`numpy.ndarray`
The layer output names, their instance index and their output index
"""
outputs = np.array(model_config["output_layers"])
logger.debug("Obtained model outputs: %s, shape: %s", outputs, outputs.shape)
if outputs.ndim == 2: # Insert extra dimension for non learn mask models
outputs = np.expand_dims(outputs, axis=1)
logger.debug("Expanded dimensions for single output model. outputs: %s, shape: %s",
outputs, outputs.shape)
return outputs

@classmethod
def _process_event(cls, event, step):
""" Process a single Tensorflow event.
Expand Down
8 changes: 5 additions & 3 deletions lib/gui/analysis/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,14 @@ def get_loss_keys(self, session_id):
The loss keys for the given session. If ``None`` is passed as session_id then a unique
list of all loss keys for all sessions is returned
"""
loss_keys = {sess_id: list(logs.keys())
for sess_id, logs in self._tb_logs.get_loss(session_id=session_id).items()}
if session_id is None:
retval = list(set(loss_key
for session in self._state["sessions"].values()
for loss_key in session["loss_names"]))
for session in loss_keys.values()
for loss_key in session))
else:
retval = self._state["sessions"][str(session_id)]["loss_names"]
retval = loss_keys[session_id]
return retval


Expand Down

0 comments on commit 6ee896d

Please sign in to comment.