Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
vexilligera authored Nov 14, 2018
1 parent c90ef53 commit b0994bf
Showing 1 changed file with 83 additions and 3 deletions.
86 changes: 83 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
import torch
from PIL import Image
from NodeServer import Draw
from graphviz import Digraph

def tensor2Image(tensor, path='', norm=True):
t = tensor
if 'numpy' not in str(type(tensor)):
tensor = tensor.cpu().detach().numpy()
t = tensor.cpu().detach().numpy()
if norm:
tensor *= 255
img = Image.fromarray(tensor).convert('RGB')
t = t * 255
if len(tensor.shape) != 3:
img = Image.fromarray(t).convert('RGB')
else:
r = Image.fromarray(t[0]).convert('L')
g = Image.fromarray(t[1]).convert('L')
b = Image.fromarray(t[2]).convert('L')
img = Image.merge("RGB", (r, g, b))
if path == '':
img.show()
else:
Expand Down Expand Up @@ -114,3 +122,75 @@ def nextBatch(self):
images = np.expand_dims(np.array(images), 1)
self.iteration += 1
return images, data, trajectories

# neural network visualization and gradient debug
# https://gist.github.com/apaszke/f93a377244be9bfcb96d3547b9bc424d
def iter_graph(root, callback):
queue = [root]
seen = set()
while queue:
fn = queue.pop()
if fn in seen or type(fn) == type(None):
continue
seen.add(fn)
for next_fn, _ in fn.next_functions:
if next_fn is not None:
queue.append(next_fn)
callback(fn)

def register_hooks(var):
fn_dict = {}
def hook_cb(fn):
def register_grad(grad_input, grad_output):
fn_dict[fn] = grad_input
fn.register_hook(register_grad)
iter_graph(var.grad_fn, hook_cb)

def is_bad_grad(grad_output):
if type(grad_output) != type(None):
grad_output = grad_output.data
return grad_output.ne(grad_output).any() or grad_output.gt(1e6).any()
return True

def make_dot():
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

def size_to_str(size):
return '('+(', ').join(map(str, size))+')'

def build_graph(fn):
if hasattr(fn, 'variable'): # if GradAccumulator
u = fn.variable
node_name = 'Variable\n ' + size_to_str(u.size())
dot.node(str(id(u)), node_name, fillcolor='lightblue')
else:
assert fn in fn_dict, fn
fillcolor = 'white'
if any(is_bad_grad(gi) for gi in fn_dict[fn]):
fillcolor = 'red'
dot.node(str(id(fn)), str(type(fn).__name__), fillcolor=fillcolor)
for next_fn, _ in fn.next_functions:
if next_fn is not None:
next_id = id(getattr(next_fn, 'variable', next_fn))
dot.edge(str(next_id), str(id(fn)))
iter_graph(var.grad_fn, build_graph)

return dot

return make_dot

def saveData(data, n_steps, path, idx=0):
step_data = []
for i in range(n_steps):
color_radius = data[i][0][idx].cpu().detach().numpy().tolist()
action = data[i][1][idx].cpu().detach().numpy().tolist()
step_data.append([color_radius] + action)
f = open(path, 'w')
f.write(str(step_data))
f.close()

0 comments on commit b0994bf

Please sign in to comment.