Skip to content

Commit

Permalink
batch one
Browse files Browse the repository at this point in the history
  • Loading branch information
heilrahc committed Jan 31, 2023
1 parent 99fed99 commit 0aa7f81
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions Pruners/dp_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ def dp_sgd_backward(params, loss, device, clip_norm, noise_factor):
if not isinstance(params, list):
params = [p for p in params]

with backpack(BatchGrad(), BatchL2Grad()):
loss.backward()
#with backpack(BatchGrad(), BatchL2Grad()):
loss.backward()

squared_param_norms = [p.batch_l2 for p in params] # first we get all the squared parameter norms...
squared_param_norms = [pt.norm(p.grad) for p in params] # first we get all the squared parameter norms...
global_norms = pt.sqrt(pt.sum(pt.stack(squared_param_norms), dim=0)) # ...then compute the global norms...
global_clips = pt.clamp_max(clip_norm / global_norms, 1.) # ...and finally get a vector of clipping factors

for idx, param in enumerate(params):
clipped_sample_grads = param.grad_batch * expand_vector(global_clips, param.grad_batch)
clipped_grad = pt.sum(clipped_sample_grads, dim=0) # after clipping we sum over the batch
#clipped_sample_grads = param.grad_batch * expand_vector(global_clips, param.grad_batch)
clipped_sample_grads = param.grad * expand_vector(global_clips, param.grad)
#clipped_grad = pt.sum(clipped_sample_grads, dim=0) # after clipping we sum over the batch
clipped_grad = clipped_sample_grads

noise_sdev = noise_factor * 2 * clip_norm # gaussian noise standard dev is computed (sensitivity is 2*clip)...
perturbed_grad = clipped_grad + pt.randn_like(clipped_grad, device=device) * noise_sdev # ...and applied
Expand All @@ -36,7 +38,7 @@ def dp_sgd_backward(params, loss, device, clip_norm, noise_factor):


def expand_vector(vec, tgt_tensor):
tgt_shape = [vec.shape[0]] + [1] * (len(tgt_tensor.shape) - 1)
tgt_shape = [1] + [1] * (len(tgt_tensor.shape) - 1)
return vec.view(*tgt_shape)


Expand Down

0 comments on commit 0aa7f81

Please sign in to comment.