Move instance of params from TBE fused kernels to named_parameters (pytorch#893)

Summary:
Pull Request resolved: pytorch#893

Fixes a long-standing nit bug where TorchRec's named_parameters did not return fused parameters.

Model code needs to change to account for this; we do so by adding filter_optimizer_in_backward_named_parameters.
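
The helper itself is not part of this diff. As a rough illustration only, such a filter could skip the parameters that the fused kernels below tag with _overlapped_optimizer; the function body here is an assumption, not the actual torchrec implementation:

# Hypothetical sketch: yield only parameters that still need an external
# optimizer, skipping those whose update is fused into the backward pass.
from typing import Iterator, Tuple

import torch.nn as nn

def filter_optimizer_in_backward_named_parameters(
    module: nn.Module,
) -> Iterator[Tuple[str, nn.Parameter]]:
    for name, param in module.named_parameters():
        # Fused TBE kernels mark their parameters with _overlapped_optimizer
        # (see the diff below); those must not receive a second update from
        # a dense optimizer.
        if getattr(param, "_overlapped_optimizer", False):
            continue
        yield name, param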

This sets the stage for the composability change.

The important parts to change (MVAI/O3/PyPer) should have this change; I tried to get everything else as well.

The most important files to look at, IMO, are:

* full_sync_optimizer.py
* fbcode/torchrec/optim/optimizers.py
* torchrec/fb/optim/module_optimizer.py

Reviewed By: YLGH

Differential Revision: D41964643

fbshipit-source-id: 28f82a56a2147d10389ad4fd0e8dbaef78f12e3c
colin2328 authored and facebook-github-bot committed Dec 16, 2022
1 parent d9cc3e0 commit e8ab2de
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -413,12 +413,20 @@ def named_buffers(
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)
+        # TODO can delete this override once SEA is removed
+        yield from ()
 
     def named_parameters(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
-        yield from ()
+        for name, tensor in self.named_split_embedding_weights(
+            prefix, recurse, remove_duplicate
+        ):
+            # hack before we support optimizer on sharded parameter level
+            param = nn.Parameter(tensor)
+            # pyre-ignore
+            param._overlapped_optimizer = True
+            yield name, param
 
     def flush(self) -> None:
         self._emb_module.flush()
@@ -570,12 +578,12 @@ def named_split_embedding_weights(
         assert (
             remove_duplicate
         ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
-        for config, param in zip(
+        for config, tensor in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
         ):
             key = append_prefix(prefix, f"{config.name}.weight")
-            yield key, param
+            yield key, tensor
 
 
 class BatchedFusedEmbeddingBag(BaseBatchedEmbeddingBag, FusedOptimizerModule):
@@ -645,12 +653,20 @@ def named_buffers(
         By convention, fused parameters are designated as buffers because they no longer
         have gradients available to external optimizers.
         """
-        return self.named_split_embedding_weights(prefix, recurse, remove_duplicate)
+        # TODO can delete this override once SEA is removed
+        yield from ()
 
     def named_parameters(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
-        yield from ()
+        for name, tensor in self.named_split_embedding_weights(
+            prefix, recurse, remove_duplicate
+        ):
+            # hack before we support optimizer on sharded parameter level
+            param = nn.Parameter(tensor)
+            # pyre-ignore
+            param._overlapped_optimizer = True
+            yield name, param
 
     def flush(self) -> None:
         self._emb_module.flush()
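
For context, a hedged usage sketch of what this enables on the training side; the function and the SGD choice are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn

def build_dense_optimizer(model: nn.Module) -> torch.optim.SGD:
    # Fused embedding weights now show up in named_parameters but carry the
    # _overlapped_optimizer flag, because their update already happens inside
    # the TBE backward pass; only the remaining parameters go to the external
    # dense optimizer.
    dense_params = [
        param
        for _, param in model.named_parameters()
        if not getattr(param, "_overlapped_optimizer", False)
    ]
    return torch.optim.SGD(dense_params, lr=0.1)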
