Allow for custom sharding specs to register their own ops.
Pull Request resolved: pytorch#76360

Custom ShardingSpecs can be entirely arbitrary, and since they might not fit the
patterns supported by the built-in ShardingSpecs, it is not possible to handle ops
for them generically. To address this, we introduce a framework that allows a
ShardingSpec to override ops as follows:

1) In the dispatch system, if a ShardingSpec has a customized op, the op registered
for that ShardingSpec is invoked.
2) Accordingly, all ChunkShardingSpec-specific ops have been moved under that
ShardingSpec.
3) There is a set of ShardingSpec-agnostic ops (e.g. elementwise ops): common ops
supported across any ShardingSpec.
4) If an op is not found for a particular ShardingSpec, the default set of ops
is searched for that op (a dispatch sketch follows this list).
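
To make the lookup order concrete, here is a minimal sketch of the dispatch flow
described above; the decorator name, registry layout, and handler signature are
assumptions for illustration, not the exact API introduced by this PR.

```python
# Minimal sketch of the dispatch flow described above (illustrative only:
# the decorator name, registry layout, and handler signature are assumptions,
# not the exact API added by this PR).

_CUSTOM_SHARDED_OPS = {}   # {sharding_spec_type: {op: handler}}
_DEFAULT_SHARDED_OPS = {}  # ShardingSpec-agnostic ops, e.g. elementwise ops

def custom_sharding_spec_op(spec_type, op):
    """Decorator: register the decorated handler as `op`'s implementation for `spec_type`."""
    def decorator(handler):
        _CUSTOM_SHARDED_OPS.setdefault(spec_type, {})[op] = handler
        return handler
    return decorator

def dispatch(sharding_spec, op, types, args=(), kwargs=None):
    # 1) If the ShardingSpec registered its own op, invoke that handler.
    spec_ops = _CUSTOM_SHARDED_OPS.get(type(sharding_spec), {})
    if op in spec_ops:
        return spec_ops[op](types, args, kwargs)
    # 4) Otherwise, fall back to the default (ShardingSpec-agnostic) set of ops.
    if op in _DEFAULT_SHARDED_OPS:
        return _DEFAULT_SHARDED_OPS[op](types, args, kwargs)
    raise RuntimeError(f"Op {getattr(op, '__name__', op)} not supported for this ShardingSpec")

# Example registration (hypothetical names, for illustration only):
# @custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
# def sharded_linear(types, args, kwargs):
#     ...
```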

Differential Revision: [D35917912](https://our.internmc.facebook.com/intern/diff/D35917912/)

Approved by: https://github.com/wanchaol
pritamdamania87 authored and pytorchmergebot committed May 6, 2022
1 parent fd6991e commit cc685bc
Showing 17 changed files with 1,019 additions and 857 deletions.
@@ -129,7 +129,7 @@ def test_math_ops_errors(self):

st = sharded_tensor.rand(spec, 10, 10)

- with self.assertRaisesRegex(TypeError, 'with ChunkShardingSpec supports'):
+ with self.assertRaisesRegex(RuntimeError, 'not supported'):
torch.add(st, sharded_rhs)


11 changes: 7 additions & 4 deletions torch/distributed/_shard/sharded_tensor/_ops/__init__.py
@@ -1,10 +1,13 @@
import torch.distributed._shard.sharded_tensor._ops.chunk
import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
import torch.distributed._shard.sharded_tensor._ops.math_ops
import torch.distributed._shard.sharded_tensor._ops.matrix_ops

from .binary_cmp import equal, allclose
from .embedding import sharded_embedding
from .embedding_bag import sharded_embedding_bag
from .init import kaiming_uniform_, normal_, uniform_, constant_
from .linear import sharded_linear

# Import all ChunkShardingSpec ops
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag
import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.math_ops
import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.matrix_ops
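
Note that the bare `import` statements above are intentional: the ChunkShardingSpec op modules are imported for their side effects, since importing them runs the op registrations that make those handlers visible to the dispatch framework described in the commit message.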