
The state_dict_gather function encounters a precision issue when using mc-ebc #2827

Open
tiankongdeguiji opened this issue Mar 17, 2025 · 1 comment



tiankongdeguiji commented Mar 17, 2025

The state_dict_gather function encounters a precision issue when using mc-ebc. We can reproduce it with the following command: torchrun --master_addr=localhost --master_port=49941 --nnodes=1 --nproc-per-node=2 test_mc_ebc_export.py, using the environment torchrec==1.1.0+cu124, torch==2.6.0+cu124, fbgemm-gpu==1.1.0+cu124.

test_mc_ebc_export.py

import os
from typing import Dict, cast

import torch
import torch.distributed as dist
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingType
from torchrec.optim import optimizers
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
    LFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.inference.state_dict_transform import state_dict_gather, state_dict_to_device
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist
from torchrec.distributed.sharding_plan import get_default_sharders

rank = int(os.environ["RANK"])
if torch.cuda.is_available():
    device = torch.device(f"cuda:{rank}")
    backend = "nccl"
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    backend = "gloo"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(large_table_cnt)
]
large_mc_modules = {
    t.name: MCHManagedCollisionModule(
        zch_size=4096,
        device=device, #'meta',
        eviction_interval=10,
        eviction_policy=LFU_EvictionPolicy()
    ) 
    for t in large_tables
}
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=64,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(small_table_cnt)
]
small_mc_modules = {
    t.name: MCHManagedCollisionModule(
        zch_size=64,
        device=device, #'meta',
        eviction_interval=10,
        eviction_policy=LFU_EvictionPolicy()
    ) 
    for t in small_tables
}


def gen_constraints(
    sharding_type: ShardingType = ShardingType.ROW_WISE,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        )
        for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.ebc = ManagedCollisionEmbeddingBagCollection(
            EmbeddingBagCollection(tables=large_tables + small_tables, device="meta"),
            ManagedCollisionCollection(
                dict(large_mc_modules, **small_mc_modules),
                large_tables + small_tables,
            )
        )
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1)

    def forward(self, kjt: KeyedJaggedTensor):
        emb, _ = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))


model = DebugModel()
apply_optimizer_in_backward(optimizers.Adagrad, model.ebc.parameters(), {"lr": 0.001})

topology = Topology(world_size=world_size, compute_device=device.type)
constraints = gen_constraints(ShardingType.ROW_WISE)
planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
)
# sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan = planner.collective_plan(model, get_default_sharders(), dist.GroupMember.WORLD)

sharded_model = DistributedModelParallel(
    model,
    plan=plan,
    device=device,
)
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(sharded_model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([sharded_model.fused_optimizer, dense_optimizer])

pipeline = TrainPipelineSparseDist(
    sharded_model, optimizer, sharded_model.device, execute_all_batches=True
)


for _ in range(20):
    batch_size = 64
    lengths_large = torch.randint(0, 10, (batch_size * large_table_cnt,))
    lengths_small = torch.randint(0, 10, (batch_size * small_table_cnt,))
    kjt = KeyedJaggedTensor(
        keys=["large_table_feature_" + str(i) for i in range(large_table_cnt)]
        + ["small_table_feature_" + str(i) for i in range(small_table_cnt)],
        values=torch.cat([
            torch.randint(0, 4096, (torch.sum(lengths_large),)),
            torch.randint(0, 64, (torch.sum(lengths_small),)),
        ]),
        lengths=torch.cat([lengths_large, lengths_small]),
    ).to(device=device)
    losses = sharded_model.forward(kjt)
    torch.sum(losses, dim=0).backward()
    optimizer.step()

checkpoint_pg = dist.new_group(backend="gloo")
cpu_state_dict = state_dict_to_device(
    sharded_model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu")
)
cpu_model = DebugModel().to_empty(device="cpu")
state_dict_gather(cpu_state_dict, cpu_model.state_dict())

if rank == 0:
    # print(sharded_model.state_dict().keys())
    print("sharded_model:", list(sharded_model.state_dict()["ebc._managed_collision_collection._managed_collision_modules.small_table_0._mch_sorted_raw_ids"].local_tensor().detach().cpu().numpy()))
    print("cpu_model:", list(cpu_model.state_dict()["ebc._managed_collision_collection._managed_collision_modules.small_table_0._mch_sorted_raw_ids"].detach().numpy()))

It prints

sharded_model: [0, 2, 3, 9, 10, 11, 13, 14, 18, 19, 23, 25, 29, 30, 31, 33, 34, 37, 38, 39, 43, 44, 50, 53, 54, 56, 57, 59, 61, 62, 63, 9223372036854775807]
cpu_model: [0, 2, 3, 9, 10, 11, 13, 14, 18, 19, 23, 25, 29, 30, 31, 33, 34, 37, 38, 39, 43, 44, 50, 53, 54, 56, 57, 59, 61, 62, 63, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808]

This shows that 9223372036854775807 was changed to -9223372036854775808 after state_dict_gather.
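For reference, a minimal standalone snippet (not torchrec code) reproduces the same corruption by round-tripping the int64 sentinel through float32; on a typical CPU build it prints the same -9223372036854775808:

import torch

# 9223372036854775807 is the padding sentinel stored in _mch_sorted_raw_ids.
sentinel = torch.tensor([2**63 - 1], dtype=torch.int64)
# float32 cannot represent 2**63 - 1 exactly; it rounds up to 2**63, and
# converting that back to int64 overflows to the minimum int64 value.
roundtrip = sentinel.to(torch.float32).to(torch.int64)
print(roundtrip)  # tensor([-9223372036854775808]) on typical builds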

@tiankongdeguiji (Contributor, Author) commented:

Hi @iamzainhuda @henrylhtsang @PaulZhang12 @joshuadeng, could you take a look?

It appears that state_dict_gather uses the float32 data type for all tensors, whereas mc-ebc employs int64 tensors. We fix this in #2826.
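As a rough sketch of the idea behind the fix (a hypothetical helper, not the actual torchrec implementation), the gather buffer should be allocated with the source tensor's dtype rather than a hard-coded torch.float32, so int64 state such as _mch_sorted_raw_ids survives the copy:

import torch
from typing import List

def gather_shards_preserving_dtype(shards: List[torch.Tensor]) -> torch.Tensor:
    # Allocate the destination with the shards' own dtype (int64 here) rather
    # than torch.float32, so large sentinel values are copied bit-exactly.
    total = sum(s.numel() for s in shards)
    out = torch.empty(total, dtype=shards[0].dtype)
    offset = 0
    for s in shards:
        out[offset : offset + s.numel()] = s.flatten()
        offset += s.numel()
    return out

# The int64 sentinel is preserved; with a float32 buffer it would be corrupted.
print(gather_shards_preserving_dtype([torch.tensor([2**63 - 1], dtype=torch.int64)]))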
