Remove call to .contiguous() for local_shard_t.
The call to contiguous was probably left over from a previous
implementation and is no longer needed.

Had to adjust atol for one of the tests to accommodate this.

Differential Revision: [D36797942](https://our.internmc.facebook.com/intern/diff/D36797942/)

Pull Request resolved: pytorch#78598

Approved by: https://github.com/kumpera
pritamdamania87 authored and pytorchmergebot committed Jun 1, 2022
1 parent 497ae27 commit 5aa2ed1
Showing 2 changed files with 2 additions and 2 deletions.
@@ -145,7 +145,7 @@ def _shard_parameter(module, spec):
)

# Test backward gradient calculation.
-self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1, atol=1e-4, rtol=1e-6)
+self.assertEqual(sharded_weight_fc1.grad, local_grad_narrowed_fc1, atol=1e-3, rtol=1e-6)
self.assertEqual(sharded_weight_fc2.grad, local_grad_narrowed_fc2, atol=1e-4, rtol=1e-6)
self.assertEqual(bias_grad_fc1, local_bias_grad_fc1, atol=1e-4, rtol=1e-6)
self.assertEqual(bias_grad_fc2, local_bias_grad_fc2, atol=1e-4, rtol=1e-6)
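For reference, these assertions pass when |actual - expected| <= atol + rtol * |expected| (the rule used by torch.allclose). A minimal standalone sketch, with made-up values, of what the loosened atol permits:

    import torch

    expected = torch.tensor([1.0000])
    actual = torch.tensor([1.0005])  # differs by 5e-4, e.g. from a changed reduction order

    # Fails under the old tolerance: 5e-4 > 1e-4 + 1e-6 * |expected|
    print(torch.allclose(actual, expected, atol=1e-4, rtol=1e-6))  # False
    # Passes under the new tolerance: 5e-4 <= 1e-3 + 1e-6 * |expected|
    print(torch.allclose(actual, expected, atol=1e-3, rtol=1e-6))  # True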
@@ -100,7 +100,7 @@ def sharded_linear(types, args, kwargs, pg):
bias = args[2]

local_shard = weight.local_tensor()
-local_shard_t = local_shard.t().contiguous()
+local_shard_t = local_shard.t()
sharding_dim = weight._sharding_spec.dim
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
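A minimal standalone sketch (not the library code itself; shapes are made up) of why the extra copy can be dropped: Tensor.t() returns a strided view, and matrix-multiply ops accept non-contiguous inputs, so materializing a contiguous copy only adds memory traffic.

    import torch

    local_shard = torch.randn(16, 32)   # stand-in for weight.local_tensor()
    local_shard_t = local_shard.t()     # transposed view, no data copy

    print(local_shard_t.is_contiguous())  # False -- just different strides
    inp = torch.randn(5, 32)
    out = inp.matmul(local_shard_t)       # works without .contiguous()
    print(out.shape)                      # torch.Size([5, 16])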
