From 2a5a29ac896f309592365cf5a44057f8173fcbab Mon Sep 17 00:00:00 2001
From: Colin Taylor
Date: Mon, 19 Dec 2022 22:15:23 -0800
Subject: [PATCH] register modules in the same order as the original EBC
 configs (#905)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/905

PTD contract() wants all params to be in the same order. Make
embeddingbag.py adhere to this by registering modules in the same order
as the EBC configs.

Reviewed By: YLGH

Differential Revision: D42158944

fbshipit-source-id: 34219ffdf1457c299d24a938b6de10b126f9dc05
---
 torchrec/distributed/embeddingbag.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 41d135b4e..fcd5638db 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -423,8 +423,9 @@ def _initialize_torch_state(self) -> None:  # noqa
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
         """
-
         self.embedding_bags: nn.ModuleDict = nn.ModuleDict()
+        for table_name in self._table_names:
+            self.embedding_bags[table_name] = nn.Module()
         self._model_parallel_name_to_local_shards = OrderedDict()
         model_parallel_name_to_compute_kernel: Dict[str, str] = {}
         for (
@@ -462,18 +463,16 @@ def _initialize_torch_state(self) -> None:  # noqa
                 table_name,
                 tbe_slice,
             ) in lookup.named_parameters_by_table():
-                self.embedding_bags[table_name] = torch.nn.Module()
                 self.embedding_bags[table_name].register_parameter("weight", tbe_slice)
         for table_name in self._model_parallel_name_to_local_shards.keys():
-            if table_name not in self.embedding_bags:
-                # for shards that don't exist on this rank, register with empty tensor
-                self.embedding_bags[table_name] = torch.nn.Module()
+            # for shards that don't exist on this rank, register with empty tensor
+            if not hasattr(self.embedding_bags[table_name], "weight"):
                 self.embedding_bags[table_name].register_parameter(
                     "weight", nn.Parameter(torch.empty(0))
                 )
             if (
                 model_parallel_name_to_compute_kernel[table_name]
-                == EmbeddingComputeKernel.FUSED.value
+                != EmbeddingComputeKernel.DENSE.value
             ):
                 self.embedding_bags[table_name].weight._overlapped_optimizer = True
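
For context, a minimal standalone sketch of the ordering issue the patch addresses,
assuming hypothetical table names t0/t1/t2 (not taken from the patch): nn.ModuleDict
preserves insertion order, so registering a placeholder nn.Module per table in EBC
config order up front keeps named_parameters()/state_dict() ordering identical on
every rank, even for tables whose shards a given rank does not hold locally.

# Sketch only; table names and shapes are illustrative, not from the patch.
import torch
import torch.nn as nn

embedding_bags = nn.ModuleDict()
table_names = ["t0", "t1", "t2"]  # order taken from the EBC configs

# Register one placeholder module per table first, in config order.
for table_name in table_names:
    embedding_bags[table_name] = nn.Module()

# Later, attach weights; tables without a local shard get an empty tensor.
local_tables = {"t1"}  # e.g. this rank only holds a shard of "t1"
for table_name in table_names:
    if table_name in local_tables:
        embedding_bags[table_name].register_parameter(
            "weight", nn.Parameter(torch.randn(4, 8))
        )
    else:
        embedding_bags[table_name].register_parameter(
            "weight", nn.Parameter(torch.empty(0))
        )

# Parameter order follows registration order on every rank:
# ['t0.weight', 't1.weight', 't2.weight']
print([name for name, _ in embedding_bags.named_parameters()])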