register modules in the same order as the original as EBC configs (pytorch#905)

Summary:
Pull Request resolved: pytorch#905

PTD contract() wants all params to be in the same order.
Make embeddingbag.py adhere to this by registering modules in the same order as the EBC configs.

Reviewed By: YLGH

Differential Revision: D42158944

fbshipit-source-id: 34219ffdf1457c299d24a938b6de10b126f9dc05
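
For context on why registration order matters, here is a minimal sketch (the table names and toy modules below are hypothetical, not taken from this diff): nn.Module records submodules and parameters in registration order, so two instances that register the same tables in different orders expose differently ordered named_parameters(), and anything that pairs parameters positionally across instances, as the contract() requirement above implies, would mis-match them.

# Minimal sketch with hypothetical table names: registration order
# determines the order returned by named_parameters()/state_dict().
import torch
import torch.nn as nn

table_names = ["t0", "t1", "t2"]  # order taken from the EBC configs

def build_bags(names):
    bags = nn.ModuleDict()
    for name in names:
        bags[name] = nn.Module()
        bags[name].register_parameter("weight", nn.Parameter(torch.empty(0)))
    return bags

in_config_order = build_bags(table_names)
in_other_order = build_bags(reversed(table_names))

print([n for n, _ in in_config_order.named_parameters()])
# ['t0.weight', 't1.weight', 't2.weight']
print([n for n, _ in in_other_order.named_parameters()])
# ['t2.weight', 't1.weight', 't0.weight']
# Positional pairing across the two would associate t0 with t2, etc.

Registering every table up front in the EBC config order, as the diff below does, keeps the ordering deterministic regardless of which shards happen to live on the local rank.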
colin2328 authored and facebook-github-bot committed Dec 20, 2022
1 parent bf5a1cc commit 2a5a29a
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions torchrec/distributed/embeddingbag.py
@@ -423,8 +423,9 @@ def _initialize_torch_state(self) -> None:  # noqa
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
         """
-
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
+        for table_name in self._table_names:
+            self.embedding_bags[table_name] = nn.Module()
         self._model_parallel_name_to_local_shards = OrderedDict()
         model_parallel_name_to_compute_kernel: Dict[str, str] = {}
         for (
@@ -462,18 +463,16 @@ def _initialize_torch_state(self) -> None:  # noqa
                 table_name,
                 tbe_slice,
             ) in lookup.named_parameters_by_table():
-                self.embedding_bags[table_name] = torch.nn.Module()
                 self.embedding_bags[table_name].register_parameter("weight", tbe_slice)
         for table_name in self._model_parallel_name_to_local_shards.keys():
-            if table_name not in self.embedding_bags:
-                # for shards that don't exist on this rank, register with empty tensor
-                self.embedding_bags[table_name] = torch.nn.Module()
+            # for shards that don't exist on this rank, register with empty tensor
+            if not hasattr(self.embedding_bags[table_name], "weight"):
                 self.embedding_bags[table_name].register_parameter(
                     "weight", nn.Parameter(torch.empty(0))
                 )
             if (
                 model_parallel_name_to_compute_kernel[table_name]
-                == EmbeddingComputeKernel.FUSED.value
+                != EmbeddingComputeKernel.DENSE.value
             ):
                 self.embedding_bags[table_name].weight._overlapped_optimizer = True
