Merge pull request pytorch#733 from osalpekar/dist_autograd_update
[Dist Autograd - API Change] Updated dist_autograd and dist_optim to be functional
Jessica Lin authored Mar 23, 2020
2 parents 8a5b379 + 55506b5 commit 234bcff
Showing 1 changed file with 3 additions and 3 deletions.
distributed/rpc/rnn/main.py
@@ -51,15 +51,15 @@ def get_next_batch():
     for epoch in range(10):
         # create distributed autograd context
         for data, target in get_next_batch():
-            with dist_autograd.context():
+            with dist_autograd.context() as context_id:
                 hidden[0].detach_()
                 hidden[1].detach_()
                 output, hidden = model(data, hidden)
                 loss = criterion(output, target)
                 # run distributed backward pass
-                dist_autograd.backward([loss])
+                dist_autograd.backward(context_id, [loss])
                 # run distributed optimizer
-                opt.step()
+                opt.step(context_id)
                 # not necessary to zero grads as each iteration creates a different
                 # distributed autograd context which hosts different grads
         print("Training epoch {}".format(epoch))
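For context, the functional API adopted by this diff can be exercised end to end in a single process. Below is a minimal sketch, not part of this commit: the single-worker RPC group, the worker name "worker0", the stand-in nn.Linear model, and the SGD hyperparameters are all illustrative assumptions. It shows the pattern the diff introduces: dist_autograd.context() yields a context_id that is then passed explicitly to dist_autograd.backward() and DistributedOptimizer.step().

    # Minimal sketch of the post-change functional API (assumptions noted above).
    import os

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc
    import torch.nn as nn
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer


    def main():
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        # a single-worker group is enough to exercise the API surface
        rpc.init_rpc("worker0", rank=0, world_size=1)

        model = nn.Linear(8, 1)
        # DistributedOptimizer takes RRefs to the parameters it should update
        opt = DistributedOptimizer(
            optim.SGD, [rpc.RRef(p) for p in model.parameters()], lr=0.05
        )

        for epoch in range(3):
            # each iteration opens a fresh autograd context; the yielded
            # context_id is passed explicitly to backward() and step()
            with dist_autograd.context() as context_id:
                loss = model(torch.randn(4, 8)).sum()
                dist_autograd.backward(context_id, [loss])
                opt.step(context_id)
            # no zero_grad() needed: gradients live in the per-iteration
            # distributed autograd context, not in param.grad
            print("Training epoch {}".format(epoch))

        rpc.shutdown()


    if __name__ == "__main__":
        main()

Because gradients are scoped to the context rather than accumulated on the parameters, each loop iteration starts from a clean slate, which is why the example (like the diff above) never calls zero_grad().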
