Add variable batch per feature support to EBC (tw/cw) (pytorch#1986)
Summary:
X-link: pytorch/torchrec#1358

Pull Request resolved: pytorch#1986

To enable variable batch sizes per feature, provide stride_per_key_per_rank and inverse_indices to the KJT.

Reindexing is done internally by EBC using the user-provided inverse indices.
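
For illustration, a rough sketch of such a KJT on a single rank (the exact KeyedJaggedTensor constructor arguments and types are assumptions based on this summary, not taken from this diff):

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Two features with different batch sizes on one rank: "f1" has 2 samples,
    # "f2" has 3. stride_per_key_per_rank records the per-feature batch size
    # per rank, and inverse_indices maps each sample of the global batch back
    # into the per-feature batches so EBC can reindex the pooled output.
    kjt = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.tensor([1, 2, 3, 4, 5, 6, 7]),
        lengths=torch.tensor([1, 2, 1, 1, 2]),  # 2 lengths for f1, 3 for f2
        stride_per_key_per_rank=[[2], [3]],
        inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 1, 2]])),
    )
    # The KJT is then fed to EmbeddingBagCollection as usual; EBC expands the
    # pooled embeddings back to the global batch using inverse_indices.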

Reviewed By: bigning

Differential Revision: D48805440

fbshipit-source-id: e490baa41b9c91b02e6c246940d4a3a2c4551d65
joshuadeng authored and facebook-github-bot committed Nov 20, 2023
1 parent f65d7e2 commit 54340d4
Showing 2 changed files with 12 additions and 24 deletions.
34 changes: 11 additions & 23 deletions fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -10,7 +10,6 @@
 from typing import List, Optional
 
 import torch
-from torch import nn
 
 try:
     # pyre-ignore[21]
@@ -24,48 +23,37 @@
 )
 
 
-class PermutePooledEmbeddings(nn.Module):
+class PermutePooledEmbeddings:
     def __init__(
         self,
         embs_dims: List[int],
         permute: List[int],
         device: Optional[torch.device] = None,
     ) -> None:
-        super(PermutePooledEmbeddings, self).__init__()
         logging.info("Using Permute Pooled Embeddings")
 
-        self.register_buffer(
-            "_offset_dim_list",
-            torch.tensor(
-                [0] + list(accumulate(embs_dims)), device=device, dtype=torch.int64
-            ),
+        self._offset_dim_list: torch.Tensor = torch.tensor(
+            [0] + list(accumulate(embs_dims)), device=device, dtype=torch.int64
         )
-        self.register_buffer(
-            "_permute", torch.tensor(permute, device=device, dtype=torch.int64)
+
+        self._permute: torch.Tensor = torch.tensor(
+            permute, device=device, dtype=torch.int64
         )
 
         inv_permute: List[int] = [0] * len(permute)
         for i, p in enumerate(permute):
             inv_permute[p] = i
 
-        self.register_buffer(
-            "_inv_permute", torch.tensor(inv_permute, device=device, dtype=torch.int64)
+        self._inv_permute: torch.Tensor = torch.tensor(
+            inv_permute, device=device, dtype=torch.int64
         )
 
-        # `Union[BoundMethod[typing.Callable(torch.Tensor.tolist)[[Named(self,
-        # torch.Tensor)], List[typing.Any]], torch.Tensor], nn.Module, torch.Tensor]`
-        # is not a function.
-
         inv_embs_dims = [embs_dims[i] for i in permute]
 
-        self.register_buffer(
-            "_inv_offset_dim_list",
-            torch.tensor(
-                [0] + list(accumulate(inv_embs_dims)), device=device, dtype=torch.int64
-            ),
+        self._inv_offset_dim_list: torch.Tensor = torch.tensor(
+            [0] + list(accumulate(inv_embs_dims)), device=device, dtype=torch.int64
        )
 
-    def forward(self, pooled_embs: torch.Tensor) -> torch.Tensor:
+    def __call__(self, pooled_embs: torch.Tensor) -> torch.Tensor:
         result = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
             pooled_embs,
             self._offset_dim_list.to(device=pooled_embs.device),
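For context, a minimal usage sketch of the reworked class (the dimensions and values are illustrative, not from this diff): the object is constructed with per-group embedding dims and a permutation, and is now invoked directly via __call__ rather than forward.

    import torch
    from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings

    # Three pooled embedding groups of widths 2, 3 and 1, reordered to
    # [group 2, group 0, group 1].
    permute_pooled_embs = PermutePooledEmbeddings(
        embs_dims=[2, 3, 1],
        permute=[2, 0, 1],
    )

    pooled = torch.randn(4, 6)               # batch of 4, total pooled dim 2 + 3 + 1
    permuted = permute_pooled_embs(pooled)   # __call__ replaces forward; shape stays (4, 6)
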
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/permute_pooled_embedding_test.py
@@ -46,7 +46,7 @@
 FIXED_EXTERN_API = {
     "PermutePooledEmbeddings": {
         "__init__": ["self", "embs_dims", "permute", "device"],
-        "forward": ["self", "pooled_embs"],
+        "__call__": ["self", "pooled_embs"],
     },
 }
 
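The test pins the externally visible signatures of the class. A rough sketch (not the actual test code) of how such a fixed-API table can be enforced with the standard inspect module:

    import inspect

    from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings

    FIXED_EXTERN_API = {
        "PermutePooledEmbeddings": {
            "__init__": ["self", "embs_dims", "permute", "device"],
            "__call__": ["self", "pooled_embs"],
        },
    }

    # Compare the declared parameter names against the live signatures.
    for method_name, expected in FIXED_EXTERN_API["PermutePooledEmbeddings"].items():
        actual = list(inspect.signature(getattr(PermutePooledEmbeddings, method_name)).parameters)
        assert actual == expected, f"{method_name}: {actual} != {expected}"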
