Skip to content

Commit

Permalink
Merge branch 'hessian'
Browse files Browse the repository at this point in the history
  • Loading branch information
ljk628 committed Nov 7, 2018
2 parents aa6bac7 + 05c9824 commit 97a5d99
Show file tree
Hide file tree
Showing 4 changed files with 459 additions and 0 deletions.
148 changes: 148 additions & 0 deletions hess_vec_prod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import torch
import time
import numpy as np
from torch import nn
from torch.autograd import Variable
from scipy.sparse.linalg import LinearOperator, eigsh

################################################################################
# Supporting Functions
################################################################################
def npvec_to_tensorlist(vec, params):
    """ Split a flat numpy vector into tensors shaped like the entries of params.

        Consecutive slices of vec are consumed in the order of params; each
        slice is reshaped to the matching parameter's shape and cast to float32.

        Args:
            vec: a 1D numpy vector
            params: a list of parameters from net

        Returns:
            tensors: a list of float tensors with the same shapes as params
    """
    tensors = []
    offset = 0
    for param in params:
        shape = param.data.shape
        count = param.data.numel()
        chunk = vec[offset:offset + count]
        tensors.append(torch.from_numpy(chunk).view(shape).float())
        offset += count

    # Every element of vec must have been consumed by some parameter.
    assert offset == vec.size, 'The vector has more elements than the net has parameters'
    return tensors


def gradtensor_to_npvec(net, include_bn=False):
    """ Flatten the gradients stored on net's parameters into one numpy vector.

        Args:
            net: trained model
            include_bn: when true, gradients of every parameter are collected
                (including 1D ones such as biases and BN parameters); otherwise
                only parameters with more than one dimension are used.

        Returns:
            a 1D numpy vector with the selected gradients concatenated in the
            order given by net.parameters()
    """
    pieces = []
    for param in net.parameters():
        if include_bn or len(param.data.size()) > 1:
            pieces.append(param.grad.data.cpu().numpy().ravel())
    return np.concatenate(pieces)


################################################################################
# For computing Hessian-vector products
################################################################################
def eval_hess_vec_prod(vec, params, net, criterion, dataloader, use_cuda=False):
    """
    Evaluate product of the Hessian of the loss function with a direction vector "vec".
    The product result is saved in the grad of net.

    Side effects: the net is switched to eval mode, its gradients are cleared
    before the pass, and (when use_cuda is set) it is moved to the GPU.

    NOTE: one backward pass is run per batch and the results accumulate in the
    .grad attributes, so the stored product is the SUM of the per-batch
    Hessian-vector products; it is not averaged over batches here.

    Args:
        vec: a list of tensor with the same dimensions as "params".
        params: the parameter list of the net (ignoring biases and BN parameters).
        net: model with trained parameters.
        criterion: loss function.
        dataloader: dataloader for the dataset; iterated as (inputs, targets) pairs.
        use_cuda: use GPU.
    """

    if use_cuda:
        net.cuda()
        vec = [v.cuda() for v in vec]

    net.eval()
    net.zero_grad()  # clears grad for every parameter in the net

    for inputs, targets in dataloader:
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # create_graph=True keeps the graph of the gradient itself so that it
        # can be differentiated a second time below.
        grad_f = torch.autograd.grad(loss, inputs=params, create_graph=True)

        # Inner product of the gradient with the direction vector. Each term is
        # reduced to a scalar on its own device, so no full-size tensor has to
        # be copied to the host.
        prod = sum((g * v).sum() for g, v in zip(grad_f, vec))

        # Compute the Hessian-vector product, H*v
        # prod.backward() computes dprod/dparams for every parameter in params and
        # accumulate the gradients into the params.grad attributes
        prod.backward()


################################################################################
# For computing Eigenvalues of Hessian
################################################################################
def min_max_hessian_eigs(net, dataloader, criterion, rank=0, use_cuda=False, verbose=False):
    """
    Estimate the largest and the smallest eigenvalues of the Hessian matrix.

    The Hessian is never formed explicitly: it is exposed to scipy's eigsh as
    a LinearOperator whose matvec is a Hessian-vector product over dataloader.
    Only parameters with more than one dimension take part.

    Args:
        net: the trained model.
        dataloader: dataloader for the dataset, may use a subset of it.
        criterion: loss function.
        rank: rank of the working node; progress is printed only on rank 0.
        use_cuda: use GPU
        verbose: print more information

    Returns:
        maxeig: max eigenvalue
        mineig: min eigenvalue
        count: number of Hessian-vector products evaluated while computing
            the max and min eigenvalues
    """

    log = verbose and rank == 0

    params = [p for p in net.parameters() if len(p.size()) > 1]
    N = sum(p.numel() for p in params)

    # Single-element list used as a mutable call counter shared with the closure.
    calls = [0]

    def hess_vec_prod(vec):
        calls[0] += 1
        direction = npvec_to_tensorlist(vec, params)
        tic = time.time()
        eval_hess_vec_prod(direction, params, net, criterion, dataloader, use_cuda)
        elapsed = time.time() - tic
        if log:
            print(" Iter: %d time: %f" % (calls[0], elapsed))
        return gradtensor_to_npvec(net)

    if log:
        print("Rank %d: computing max eigenvalue" % rank)

    hessian_op = LinearOperator((N, N), matvec=hess_vec_prod)
    top_vals, _ = eigsh(hessian_op, k=1, tol=1e-2)
    maxeig = top_vals[0]
    if log:
        print('max eigenvalue = %f' % maxeig)

    # If the largest eigenvalue is positive, shift matrix so that any negative eigenvalue is now the largest
    # We assume the smallest eigenvalue is zero or less, and so this shift is more than what we need
    shift = maxeig*.51

    def shifted_hess_vec_prod(vec):
        return hess_vec_prod(vec) - shift*vec

    if log:
        print("Rank %d: Computing shifted eigenvalue" % rank)

    shifted_op = LinearOperator((N, N), matvec=shifted_hess_vec_prod)
    shifted_vals, _ = eigsh(shifted_op, k=1, tol=1e-2)
    # Undo the shift to recover an eigenvalue of the original Hessian.
    mineig = (shifted_vals + shift)[0]
    if log:
        print('min eigenvalue = ' + str(mineig))

    # The first solve may have landed on a non-positive eigenvalue; make sure
    # the pair is returned in (max, min) order.
    if maxeig <= 0 and mineig > 0:
        maxeig, mineig = mineig, maxeig

    return maxeig, mineig, calls[0]
28 changes: 28 additions & 0 deletions plot_1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,34 @@ def plot_1d_loss_err_repeat(prefix, idx_min=1, idx_max=10, xmin=-1.0, xmax=1.0,
if show: pp.show()


def plot_1d_eig_ratio(surf_file, xmin=-1.0, xmax=1.0, val_1='min_eig', val_2='max_eig', ymax=1, show=False):
    """ Plot the 1D curves of |val_1/val_2| and val_1/val_2 stored in a surface file.

        Two PDFs are written next to surf_file: '<surf_file>_1d_eig_abs_ratio.pdf'
        and '<surf_file>_1d_eig_ratio.pdf'. Each curve is drawn on its own
        figure so that the second file does not also contain the first curve.

        Args:
            surf_file: path of the h5 file; must contain 'xcoordinates' and the
                datasets named by val_1 and val_2.
            xmin, xmax: limits of the x axis.
            val_1: name of the numerator dataset.
            val_2: name of the denominator dataset.
            ymax: upper limit of the y axis (the lower limit is 0).
            show: if true, display the figures after saving them.
    """
    print('------------------------------------------------------------------')
    print('plot_1d_eig_ratio')
    print('------------------------------------------------------------------')

    # Read everything up front; the context manager closes the file even if a
    # dataset is missing or plotting later fails.
    with h5py.File(surf_file, 'r') as f:
        x = f['xcoordinates'][:]
        Z1 = np.array(f[val_1][:])
        Z2 = np.array(f[val_2][:])

    ratio = np.divide(Z1, Z2)
    abs_ratio = np.absolute(ratio)

    pp.figure()
    pp.plot(x, abs_ratio)
    pp.xlim(xmin, xmax)
    pp.ylim(0, ymax)
    pp.savefig(surf_file + '_1d_eig_abs_ratio.pdf', dpi=300, bbox_inches='tight', format='pdf')

    pp.figure()
    pp.plot(x, ratio)
    pp.xlim(xmin, xmax)
    pp.ylim(0, ymax)
    pp.savefig(surf_file + '_1d_eig_ratio.pdf', dpi=300, bbox_inches='tight', format='pdf')

    if show: pp.show()



if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Plott 1D loss and error curves')
parser.add_argument('--surf_file', '-f', default='', help='The h5 file contains loss values')
Expand Down
38 changes: 38 additions & 0 deletions plot_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,44 @@ def plot_contour_trajectory(surf_file, dir_file, proj_file, surf_name='loss_vals
if show: plt.show()


def plot_2d_eig_ratio(surf_file, val_1='min_eig', val_2='max_eig', show=False):
    """ Plot the heatmap of eigenvalue ratios, i.e., |min_eig/max_eig| of hessian

        Two PDFs are written next to surf_file:
        '<surf_file>_<val_1>_<val_2>_abs_ratio_heat_sns.pdf' (absolute ratio,
        colour range fixed to [0, 0.5]) and
        '<surf_file>_<val_1>_<val_2>_ratio_heat_sns.pdf' (signed ratio).
        Both ratio arrays are also printed.

        Args:
            surf_file: path of the h5 file holding the datasets val_1 and val_2.
            val_1: name of the numerator dataset.
            val_2: name of the denominator dataset.
            show: if true, display the figures after saving them.
    """

    print('------------------------------------------------------------------')
    print('plot_2d_eig_ratio')
    print('------------------------------------------------------------------')
    print("loading surface file: " + surf_file)

    # The context manager closes the file even if a dataset is missing.
    with h5py.File(surf_file, 'r') as f:
        Z1 = np.array(f[val_1][:])
        Z2 = np.array(f[val_2][:])

    out_prefix = surf_file + '_' + val_1 + '_' + val_2

    # Plot 2D heatmaps with color bar using seaborn
    ratio = np.divide(Z1, Z2)
    abs_ratio = np.absolute(ratio)
    print(abs_ratio)

    plt.figure()
    sns_plot = sns.heatmap(abs_ratio, cmap='viridis', vmin=0, vmax=.5, cbar=True,
                           xticklabels=False, yticklabels=False)
    sns_plot.invert_yaxis()
    sns_plot.get_figure().savefig(out_prefix + '_abs_ratio_heat_sns.pdf',
                                  dpi=300, bbox_inches='tight', format='pdf')

    # Plot 2D heatmaps with color bar using seaborn
    print(ratio)
    plt.figure()
    sns_plot = sns.heatmap(ratio, cmap='viridis', cbar=True, xticklabels=False, yticklabels=False)
    sns_plot.invert_yaxis()
    sns_plot.get_figure().savefig(out_prefix + '_ratio_heat_sns.pdf',
                                  dpi=300, bbox_inches='tight', format='pdf')

    if show: plt.show()



if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Plot 2D loss surface')
Expand Down
Loading

0 comments on commit 97a5d99

Please sign in to comment.