
row-wise alltoall error when some embeddings use mean pooling and others use sum pooling #2808

Open
tiankongdeguiji opened this issue Mar 12, 2025 · 2 comments

@tiankongdeguiji (Contributor) commented on Mar 12, 2025

There is an alltoall error when using row-wise sharding if some embedding bags use mean pooling while others use sum pooling. It can be reproduced with the following command: torchrun --master_addr=localhost --master_port=49941 --nnodes=1 --nproc-per-node=2 test_row_wise_pooling.py, in an environment with torchrec==1.1.0+cu124, torch==2.6.0+cu124, and fbgemm-gpu==1.1.0+cu124.

test_row_wise_pooling.py

import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

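# Two large and two small tables; even-indexed tables use SUM pooling, odd-indexed use MEAN.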
large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM if i % 2 == 0 else torchrec.PoolingType.MEAN,
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM if i % 2 == 0 else torchrec.PoolingType.MEAN,
    )
    for i in range(small_table_cnt)
]


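# Constrain every table to the given sharding type (row-wise by default).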
def gen_constraints(
    sharding_type: ShardingType = ShardingType.ROW_WISE,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


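# Toy model: an EmbeddingBagCollection over all four tables followed by a linear layer.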
class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device="meta")
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


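# Initialize the process group: NCCL on GPU, Gloo on CPU.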
rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()

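# Build the model and fuse an Adagrad optimizer into the embedding tables' backward pass.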
model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

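# Generate a row-wise sharding plan under the constraints and wrap the model with DMP.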
topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.ROW_WISE)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, sharders, dist.GroupMember.WORLD)

sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    sharders=sharders,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])

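# Construct a random KeyedJaggedTensor batch covering all large and small features.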
batch_size = 64
lengths_large = torch.randint(0, 10, (batch_size * large_table_cnt,))
lengths_small = torch.randint(0, 10, (batch_size * small_table_cnt,))
kjt = KeyedJaggedTensor(
    keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
    + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
    values=torch.cat([
        torch.randint(0, 4096, (torch.sum(lengths_large),)),
        torch.randint(0, 1023, (torch.sum(lengths_small),)),
    ]),
    lengths=torch.cat([lengths_large, lengths_small]),
).to(device=device)
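# One forward/backward/step; the failure occurs during the input_dist alltoall in forward.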
losses = sharded_model.forward(kjt)
torch.sum(losses, dim=0).backward()
optimizer.step()

Error info:

[rank0]: Traceback (most recent call last):
[rank0]:   File "test_row_wise_pooling.py", line 124, in <module>
[rank0]:     losses = sharded_model.forward(kjt)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 308, in forward
[rank0]:     return self._dmp_wrapped_module(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "test_row_wise_pooling.py", line 73, in forward
[rank0]:     emb = self.ebc(kjt)
[rank0]:           ^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 997, in forward
[rank0]:     dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 334, in wait
[rank0]:     ret: W = self._wait_impl()
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/embedding_sharding.py", line 745, in _wait_impl
[rank0]:     tensors_awaitables.append(w.wait())
[rank0]:                               ^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/types.py", line 334, in wait
[rank0]:     ret: W = self._wait_impl()
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 530, in _wait_impl
[rank0]:     return KJTAllToAllTensorsAwaitable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torchrec/distributed/dist_data.py", line 398, in __init__
[rank0]:     awaitable = dist.all_to_all_single(
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/conda/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 4388, in all_to_all_single
[rank0]:     work = group.alltoall_base(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Split sizes doesn't match total dim 0 size
@tiankongdeguiji (Contributor, Author)

Hi @iamzainhuda @henrylhtsang @PaulZhang12 @joshuadeng, could you take a look?

KJT.lengths is modified in place by the mean-pooling callback, so when some embeddings use mean pooling and others use sum pooling, I think KJT.lengths ends up incorrect, which leads to the split-size mismatch above. We fix it in #2809.
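
For intuition, here is a minimal single-process sketch (not TorchRec's actual callback code; the feature names and length values are made up) of the invariant that breaks: the split sizes passed to the alltoall are derived from the KJT's lengths, so if a pooling callback rewrites lengths in place after the splits are computed, sum(splits) no longer matches the size of the values tensor being exchanged, which matches the "Split sizes doesn't match total dim 0 size" error above.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Hypothetical KJT with one SUM-pooled and one MEAN-pooled feature, batch size 2.
kjt = KeyedJaggedTensor(
    keys=["sum_feature", "mean_feature"],
    values=torch.arange(6),
    lengths=torch.tensor([1, 2, 2, 1]),
)

# Split sizes for the collective are taken from the current lengths (total = 6).
splits = [int(kjt.lengths().sum())]

# A callback that later mutates lengths in place (values here are hypothetical)
# breaks the invariant sum(lengths) == values.numel() that the splits relied on.
kjt.lengths().copy_(torch.tensor([1, 2, 0, 0]))
assert sum(splits) != int(kjt.lengths().sum())  # 6 vs. 3 -> split-size mismatch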

@iamzainhuda (Contributor)

Hey, thanks for this! Taking a look. I haven't seen the need for a mix of sum and mean previously, but your case is a very interesting one.
