From b0994bf81f18c3bf35d1ef20ba50d36b467bbed6 Mon Sep 17 00:00:00 2001 From: vexilligera Date: Wed, 14 Nov 2018 20:53:25 +0800 Subject: [PATCH] Update utils.py --- utils.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 3 deletions(-) diff --git a/utils.py b/utils.py index 4400972..328e0dc 100644 --- a/utils.py +++ b/utils.py @@ -5,13 +5,21 @@ import torch from PIL import Image from NodeServer import Draw +from graphviz import Digraph def tensor2Image(tensor, path='', norm=True): + t = tensor if 'numpy' not in str(type(tensor)): - tensor = tensor.cpu().detach().numpy() + t = tensor.cpu().detach().numpy() if norm: - tensor *= 255 - img = Image.fromarray(tensor).convert('RGB') + t = t * 255 + if len(tensor.shape) != 3: + img = Image.fromarray(t).convert('RGB') + else: + r = Image.fromarray(t[0]).convert('L') + g = Image.fromarray(t[1]).convert('L') + b = Image.fromarray(t[2]).convert('L') + img = Image.merge("RGB", (r, g, b)) if path == '': img.show() else: @@ -114,3 +122,75 @@ def nextBatch(self): images = np.expand_dims(np.array(images), 1) self.iteration += 1 return images, data, trajectories + +# neural network visualization and gradient debug +# https://gist.github.com/apaszke/f93a377244be9bfcb96d3547b9bc424d +def iter_graph(root, callback): + queue = [root] + seen = set() + while queue: + fn = queue.pop() + if fn in seen or type(fn) == type(None): + continue + seen.add(fn) + for next_fn, _ in fn.next_functions: + if next_fn is not None: + queue.append(next_fn) + callback(fn) + +def register_hooks(var): + fn_dict = {} + def hook_cb(fn): + def register_grad(grad_input, grad_output): + fn_dict[fn] = grad_input + fn.register_hook(register_grad) + iter_graph(var.grad_fn, hook_cb) + + def is_bad_grad(grad_output): + if type(grad_output) != type(None): + grad_output = grad_output.data + return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any() + return True + + def make_dot(): + node_attr = dict(style='filled', + shape='box', + align='left', + fontsize='12', + ranksep='0.1', + height='0.2') + dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) + + def size_to_str(size): + return '('+(', ').join(map(str, size))+')' + + def build_graph(fn): + if hasattr(fn, 'variable'): # if GradAccumulator + u = fn.variable + node_name = 'Variable\n ' + size_to_str(u.size()) + dot.node(str(id(u)), node_name, fillcolor='lightblue') + else: + assert fn in fn_dict, fn + fillcolor = 'white' + if any(is_bad_grad(gi) for gi in fn_dict[fn]): + fillcolor = 'red' + dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor) + for next_fn, _ in fn.next_functions: + if next_fn is not None: + next_id = id(getattr(next_fn, 'variable', next_fn)) + dot.edge(str(next_id), str(id(fn))) + iter_graph(var.grad_fn, build_graph) + + return dot + + return make_dot + +def saveData(data, n_steps, path, idx=0): + step_data = [] + for i in range(n_steps): + color_radius = data[i][0][idx].cpu().detach().numpy().tolist() + action = data[i][1][idx].cpu().detach().numpy().tolist() + step_data.append([color_radius] + action) + f = open(path, 'w') + f.write(str(step_data)) + f.close()