Skip to content

Commit

Permalink
refactor net_plotter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ljk628 committed Sep 18, 2018
1 parent d5519ea commit 733137f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 51 deletions.
78 changes: 39 additions & 39 deletions net_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def set_weights(net, weights, directions=None, step=None):
for (p, w) in zip(net.parameters(), weights):
p.data.copy_(w.type(type(p.data)))
else:
assert step is not None, \
'If a direction is specified then the step size must be specified as well'
assert step is not None, 'If a direction is specified then step must be specified as well'

if len(directions) == 2:
dx = directions[0]
Expand All @@ -44,13 +43,12 @@ def set_weights(net, weights, directions=None, step=None):

def set_states(net, states, directions=None, step=None):
"""
Overwrite the network's state_dict or change it along directions
with a step size.
Overwrite the network's state_dict or change it along directions with a step size.
"""
if directions is None:
net.load_state_dict(states)
else:
assert step is not None, 'If direction is specified then step must be specified as well'
assert step is not None, 'If direction is provided then the step must be specified as well'
if len(directions) == 2:
dx = directions[0]
dy = directions[1]
Expand All @@ -59,9 +57,11 @@ def set_states(net, states, directions=None, step=None):
changes = [d*step for d in directions[0]]

new_states = copy.deepcopy(states)
assert (len(new_states) == len(changes))
for (k, v), d in zip(new_states.items(), changes):
d = torch.tensor(d)
new_states[k].add_(d.type(v.type()))
v.add_(d.type(v.type()))

net.load_state_dict(new_states)


Expand Down Expand Up @@ -131,115 +131,116 @@ def normalize_direction(direction, weights, norm='filter'):

def normalize_directions_for_weights(direction, weights, norm='filter', ignore='biasbn'):
"""
The normalization scales the direction entries according to the
entries of weights.
The normalization scales the direction entries according to the entries of weights.
"""
assert(len(direction) == len(weights))
for d, w in zip(direction, weights):
if d.dim() <= 1:
if ignore == 'biasbn':
d.fill_(0) # ignore directions for weights with 1 dimension
else:
# keep directions for weights/bias that are only 1 per node
d.copy_(w)
d.copy_(w) # keep directions for weights/bias that are only 1 per node
else:
normalize_direction(d, w, norm)


def normalize_directions_for_states(direction, states, norm='filter', ignore='ignore'):
assert(len(direction) == len(states.items()))
assert(len(direction) == len(states))
for d, (k, w) in zip(direction, states.items()):
if d.dim() <= 1:
if ignore == 'biasbn':
d.fill_(0) # ignore directions for weights with 1 dimension
else:
# keep directions for weights/bias that are only 1 per node
d.copy_(w)
d.copy_(w) # keep directions for weights/bias that are only 1 per node
else:
normalize_direction(d, w, norm)


################################################################################
# Create directions
################################################################################
def create_target_direction(w, s, net2, dir_type='states'):
def create_target_direction(net, net2, dir_type='states'):
"""
Setup a target direction from one model to the other
Args:
w: a list of parameters (variables).
s: a list of parameters (variables) including BN's running mean/var.
net2: the target model with the same architecture.
dir_type: "weights" or "states", type of directions.
net: the source model
net2: the target model with the same architecture as net.
dir_type: 'weights' or 'states', type of directions.
Returns:
the target direction with the same dimension as weights or states.
direction: the target direction from net to net2 with the same dimension
as weights or states.
"""

assert (net2 is not None)
# direction between net2 and net
if dir_type == 'weights':
# direction between w2 and w
w = get_weights(net)
w2 = get_weights(net2)
direction = get_diff_weights(w, w2)
elif dir_type == 'states':
# direction between s2 and s, including BN's statistics (running mean/var)
s = net.state_dict()
s2 = net2.state_dict()
direction = get_diff_states(s, s2)

return direction


def create_random_direction(w, s, dir_type='weights', ignore='biasbn', norm='filter'):
def create_random_direction(net, dir_type='weights', ignore='biasbn', norm='filter'):
"""
Setup a random (normalized) direction with the same dimension as
the weights or states.
Args:
w: a list of parameters (variables).
s: a list of parameters (variables), including BN's running mean/var.
dir_type: "weights" or "states", type of directions.
ignore: "biasbn", ignore biases and BN parameters.
net: the given trained model
dir_type: 'weights' or 'states', type of directions.
ignore: 'biasbn', ignore biases and BN parameters.
norm: direction normalization method, including
'filter' | 'layer' | 'weight' | 'dlayer' | 'dfilter'
Returns:
a random direction with the same dimension as weights or states.
direction: a random direction with the same dimension as weights or states.
"""

# random direction
if dir_type == 'weights':
direction = get_random_weights(w)
normalize_directions_for_weights(direction, w, norm, ignore)
weights = get_weights(net) # a list of parameters.
direction = get_random_weights(weights)
normalize_directions_for_weights(direction, weights, norm, ignore)
elif dir_type == 'states':
direction = get_random_states(s)
normalize_directions_for_states(direction, s, norm, ignore)
states = net.state_dict() # a dict of parameters, including BN's running mean/var.
direction = get_random_states(states)
normalize_directions_for_states(direction, states, norm, ignore)

return direction


def setup_direction(args, dir_file, w, s):
def setup_direction(args, dir_file, net):
"""
Setup the h5 file to store the directions.
- xdirection, ydirection: The perturbation direction added to the model.
The direction is a list of tensors.
- xcoordinates, ycoordinates: the coordinates at which to calculate values.
"""
# skip if the direction file already exists
print('-------------------------------------------------------------------')
print('setup_direction')
print('-------------------------------------------------------------------')
# Skip if the direction file already exists
if exists(dir_file):
f = h5py.File(dir_file, 'r')
if (args.y and 'ydirection' in f.keys()) or 'xdirection' in f.keys():
f.close()
print ("%s is already setted up" % dir_file)
return
f.close()

# Create the plotting directions
f = h5py.File(dir_file,'w-') # create file, fail if exists
f = h5py.File(dir_file,'w') # create file, truncate if exists
if not args.dir_file:
print("Setting up the plotting directions...")
if args.model_file2:
net2 = model_loader.load(args.dataset, args.model, args.model_file2)
xdirection = create_target_direction(w, s, net2, args.dir_type)
xdirection = create_target_direction(net, net2, args.dir_type)
else:
xdirection = create_random_direction(w, s, args.dir_type, args.xignore, args.xnorm)
h5_util.write_list(f, 'xdirection', xdirection)
Expand All @@ -249,12 +250,13 @@ def setup_direction(args, dir_file, w, s):
ydirection = xdirection
elif args.model_file3:
net3 = model_loader.load(args.dataset, args.model, args.model_file3)
ydirection = create_target_direction(w, s, net3, args.dir_type)
ydirection = create_target_direction(net, net3, args.dir_type)
else:
ydirection = create_random_direction(w, s, args.dir_type, args.yignore, args.ynorm)
h5_util.write_list(f, 'ydirection', ydirection)

f.close()
print ("direction file created: %s" % dir_file)


def name_direction_file(args):
Expand Down Expand Up @@ -312,8 +314,6 @@ def name_direction_file(args):

dir_file += ".h5"

print ("direction file created: %s" % dir_file)

return dir_file


Expand Down
25 changes: 13 additions & 12 deletions plot_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,10 @@ def setup_surface_file(args, surf_file, dir_file):

# Create the coordinates(resolutions) at which the function is evaluated
xcoordinates = np.linspace(args.xmin, args.xmax, num=args.xnum)
shape = xcoordinates.shape
if args.y:
ycoordinates = np.linspace(args.ymin, args.ymax, num=args.ynum)
shape = (len(xcoordinates), len(ycoordinates))

f['xcoordinates'] = xcoordinates

if args.y:
ycoordinates = np.linspace(args.ymin, args.ymax, num=args.ynum)
f['ycoordinates'] = ycoordinates
f.close()

Expand Down Expand Up @@ -100,8 +97,7 @@ def crunch(surf_file, net, w, s, d, dataloader, loss_key, acc_key, comm, rank, a
# Generate a list of indices of 'losses' that need to be filled in.
# The coordinates of each unfilled index (with respect to the direction vectors
# stored in 'd') are stored in 'coords'.
inds, coords, inds_nums = scheduler.get_job_indices(losses, xcoordinates,
ycoordinates, comm)
inds, coords, inds_nums = scheduler.get_job_indices(losses, xcoordinates, ycoordinates, comm)

print('Computing %d values for rank %d'% (len(inds), rank))
start_time = time.time()
Expand Down Expand Up @@ -207,9 +203,11 @@ def crunch(surf_file, net, w, s, d, dataloader, loss_key, acc_key, comm, rank, a
parser.add_argument('--vlevel', default=0.5, type=float, help='plot contours every vlevel')
parser.add_argument('--show', action='store_true', default=False, help='show plotted figures')
parser.add_argument('--log', action='store_true', default=False, help='use log scale for loss values')
parser.add_argument('--plot', action='store_true', default=False, help='plot figures after computation')

args = parser.parse_args()

torch.manual_seed(123)
#--------------------------------------------------------------------------
# Environment setup
#--------------------------------------------------------------------------
Expand Down Expand Up @@ -245,19 +243,21 @@ def crunch(surf_file, net, w, s, d, dataloader, loss_key, acc_key, comm, rank, a
# Load models and extract parameters
#--------------------------------------------------------------------------
net = model_loader.load(args.dataset, args.model, args.model_file)
w = net_plotter.get_weights(net) # extract weights
w = net_plotter.get_weights(net) # initial parameters
s = copy.deepcopy(net.state_dict()) # deepcopy since state_dict are references
if args.ngpu > 1:
# data parallel with multiple GPUs on a single node
net = nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))

#--------------------------------------------------------------------------
# Setup the direction and surface file
# Setup the direction file and the surface file
#--------------------------------------------------------------------------
dir_file = net_plotter.name_direction_file(args) # name the direction file
if rank == 0:
net_plotter.setup_direction(args, dir_file, net)

surf_file = name_surface_file(args, dir_file)
if rank == 0:
net_plotter.setup_direction(args, dir_file, w, s)
setup_surface_file(args, surf_file, dir_file)

# wait until master has setup the direction file and surface file
Expand All @@ -277,6 +277,8 @@ def crunch(surf_file, net, w, s, d, dataloader, loss_key, acc_key, comm, rank, a
if rank == 0 and args.dataset == 'cifar10':
torchvision.datasets.CIFAR10(root=args.dataset + '/data', train=True, download=True)

mpi4pytorch.barrier(comm)

trainloader, testloader = dataloader.load_dataset(args.dataset, args.datapath,
args.batch_size, args.threads, args.raw_data,
args.data_split, args.split_idx,
Expand All @@ -285,14 +287,13 @@ def crunch(surf_file, net, w, s, d, dataloader, loss_key, acc_key, comm, rank, a
#--------------------------------------------------------------------------
# Start the computation
#--------------------------------------------------------------------------

crunch(surf_file, net, w, s, d, trainloader, 'train_loss', 'train_acc', comm, rank, args)
# crunch(surf_file, net, w, s, d, testloader, 'test_loss', 'test_acc', comm, rank, args)

#--------------------------------------------------------------------------
# Plot figures
#--------------------------------------------------------------------------
if rank == 0:
if args.plot and rank == 0:
if args.y and args.proj_file:
plot_2D.plot_contour_trajectory(surf_file, dir_file, args.proj_file, 'train_loss', args.show)
elif args.y:
Expand Down

0 comments on commit 733137f

Please sign in to comment.