Skip to content

Commit

Permalink
in group_table check string value of pooling mode (pytorch#1124)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1124

If the model is from torch package, checking pooling mode directly wouldn't result in a match because one would from <torch_package_x>.

Instead, we can check the string value.

Reviewed By: houseroad

Differential Revision: D45059135

fbshipit-source-id: 4b452bbbecc1c6b561ebb21906b017b3540205cd
  • Loading branch information
842974287 authored and facebook-github-bot committed Apr 18, 2023
1 parent 17248cb commit 13033eb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrec/distributed/embedding_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _group_tables_per_rank(
compute_kernel_type = EmbeddingComputeKernel.QUANT
if (
table.data_type == data_type
and table.pooling == pooling
and table.pooling.value == pooling.value
and table.has_feature_processor
== has_feature_processor
and compute_kernel_type == compute_kernel
Expand Down

0 comments on commit 13033eb

Please sign in to comment.