
Commit 1afbf08

TroyGarden authored and facebook-github-bot committed
re-land TD feature (#2696)
Summary:
Pull Request resolved: #2696

# context
* previous landing of D66521351 and D65103519 triggered multiple SEVs
* after fixing the SEVs, we are giving it another try

Reviewed By: dstaay-fb

Differential Revision: D68511145

fbshipit-source-id: 6092e64aadd0d88986d67cae368ec3909b949d6a
1 parent 26e0732 commit 1afbf08

File tree: 6 files changed, +100 -16 lines changed


torchrec/distributed/embedding.py

+11 -2

@@ -26,6 +26,7 @@
 )
 
 import torch
+from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -90,6 +91,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1198,8 +1200,15 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: KeyedJaggedTensor,
+        features: TypeUnion[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        need_permute: bool = True
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if self._features_order:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                need_permute = False
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
@@ -1209,7 +1218,7 @@ def input_dist(
             unpadded_features = features
             features = pad_vbe_kjt_lengths(unpadded_features)
 
-        if self._features_order:
+        if need_permute and self._features_order:
             features = features.permute(
                 self._features_order,
                 # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
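For illustration, here is a minimal sketch of the new dispatch logic above, factored into a standalone helper. The helper itself is hypothetical and not part of the commit; maybe_td_to_kjt and the key-reordering step come directly from the diff, and the exact semantics of the converter live in torchrec/sparse/tensor_dict.py, which is not shown here.

    from typing import List, Union

    from tensordict import TensorDict
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
    from torchrec.sparse.tensor_dict import maybe_td_to_kjt


    def normalize_features(
        features: Union[KeyedJaggedTensor, TensorDict],
        features_order: List[int],
    ) -> KeyedJaggedTensor:
        """Convert a TensorDict to a KJT, applying the feature order up front."""
        if isinstance(features, TensorDict):
            feature_keys = list(features.keys())
            if features_order:
                # Reorder the key list before conversion so the resulting KJT is
                # already in the expected order and the later permute can be skipped.
                feature_keys = [feature_keys[i] for i in features_order]
            return maybe_td_to_kjt(features, feature_keys)
        # KeyedJaggedTensor inputs pass through unchanged and are permuted later.
        return features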

torchrec/distributed/embeddingbag.py

+12 -4

@@ -27,6 +27,7 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -94,6 +95,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -656,9 +658,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,8 +1189,16 @@ def _create_inverse_indices_permute_indices(
 
     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if len(self._features_order) > 0:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                self._has_features_permute = False  # feature_keys are in order
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
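ShardedEmbeddingBagCollection follows the same pattern as the sequence module above, except it clears self._has_features_permute instead of a local need_permute flag. The permute can be dropped because applying _features_order to the key list before conversion yields the key ordering the later permute would have produced; a tiny self-contained check of that reordering step (feature names and order are made up for illustration):

    # Hypothetical feature names; the index list plays the role of self._features_order.
    keys = ["f_b", "f_a", "f_c"]
    features_order = [1, 0, 2]  # order the sharded module expects

    # Same list comprehension as in the diff: reorder keys before building the KJT.
    reordered = [keys[i] for i in features_order]
    assert reordered == ["f_a", "f_b", "f_c"]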

torchrec/distributed/test_utils/test_sharding.py

+26 -6

@@ -147,6 +147,7 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -177,9 +178,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    for _ in range(num_inputs):
-        inputs.append(
-            (
+    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
+        for _ in range(num_inputs):
+            inputs.append(
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -188,8 +189,26 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-                if generate == ModelInput.generate_variable_batch_input
-                else cast(ModelInputCallable, generate)(
+            )
+    elif generate == ModelInput.generate:
+        for _ in range(num_inputs):
+            inputs.append(
+                ModelInput.generate(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                    input_type=input_type,
+                )
+            )
+    else:
+        for _ in range(num_inputs):
+            inputs.append(
+                cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -200,7 +219,6 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
-        )
     return (model, inputs)
 
 
@@ -297,6 +315,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -319,6 +338,7 @@ def sharding_single_rank_test(
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
+            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)

torchrec/distributed/tests/test_sequence_model_parallel.py

+41

@@ -376,3 +376,44 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
+
+
+@skip_if_asan_class
+class TDSequenceModelParallelTest(SequenceModelParallelTest):
+
+    def test_sharding_variable_batch(self) -> None:
+        pass
+
+    def _test_sharding(
+        self,
+        sharders: List[TestEmbeddingCollectionSharder],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=sharding_single_rank_test,
+            world_size=world_size,
+            local_size=local_size,
+            model_class=model_class,
+            tables=self.tables,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            optim=EmbOptimType.EXACT_SGD,
+            backend=backend,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
+            input_type="td",
+        )

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

+2 -2

@@ -160,7 +160,7 @@ def main(
 
     tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@ def main(
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],

torchrec/modules/embedding_modules.py

+8 -2

@@ -19,6 +19,7 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 
 @torch.fx.wrap
@@ -218,7 +219,10 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()
 
-    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
+    def forward(
+        self,
+        features: KeyedJaggedTensor,  # can also take TensorDict as input
+    ) -> KeyedTensor:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
         and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -229,6 +233,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
+        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
@@ -448,7 +453,7 @@ def __init__(  # noqa C901
 
     def forward(
         self,
-        features: KeyedJaggedTensor,
+        features: KeyedJaggedTensor,  # can also take TensorDict as input
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -461,6 +466,7 @@ def forward(
             Dict[str, JaggedTensor]
         """
 
+        features = maybe_td_to_kjt(features, None)
         feature_embeddings: Dict[str, JaggedTensor] = {}
         jt_dict: Dict[str, JaggedTensor] = features.to_dict()
         for i, emb_module in enumerate(self.embeddings.values()):
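With maybe_td_to_kjt(features, None) at the top of forward, the unsharded modules keep their existing KeyedJaggedTensor behavior and simply convert a TensorDict input first. A hedged usage sketch of the unchanged KJT path (table sizes and values are made up; per the diff above, a TensorDict with the same feature keys would be converted to an equivalent KJT before the lookup):

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # One small table with a single feature, purely for illustration.
    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                num_embeddings=100,
                embedding_dim=8,
                name="table_0",
                feature_names=["feature_0"],
            )
        ]
    )

    # Batch of 2 examples: example 0 has ids [1, 2], example 1 has id [3].
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["feature_0"],
        values=torch.tensor([1, 2, 3]),
        lengths=torch.tensor([2, 1]),
    )

    pooled = ebc(kjt)  # KeyedTensor
    print(pooled.to_dict()["feature_0"].shape)  # torch.Size([2, 8])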
