@@ -147,6 +147,7 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -177,9 +178,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    for _ in range(num_inputs):
-        inputs.append(
-            (
+    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
+        for _ in range(num_inputs):
+            inputs.append(
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -188,8 +189,26 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-                if generate == ModelInput.generate_variable_batch_input
-                else cast(ModelInputCallable, generate)(
+            )
+    elif generate == ModelInput.generate:
+        for _ in range(num_inputs):
+            inputs.append(
+                ModelInput.generate(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                    input_type=input_type,
+                )
+            )
+    else:
+        for _ in range(num_inputs):
+            inputs.append(
+                cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -200,7 +219,6 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
-        )
     return (model, inputs)


@@ -297,6 +315,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -319,6 +338,7 @@ def sharding_single_rank_test(
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
+            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
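The patch replaces the old single conditional expression with a three-way dispatch: variable-batch generation only for the "kjt" input path, `ModelInput.generate` when it is the chosen callable (so `input_type` can be forwarded), and a generic fallback otherwise. Below is a minimal, self-contained sketch of that dispatch shape, not torchrec code; the generator functions, their keyword arguments, and the return type are illustrative assumptions only.

```python
# Standalone sketch of the dispatch introduced by the patch (names are hypothetical).
from typing import Callable, List


def generate_variable_batch_input(**kwargs) -> str:
    # Stand-in for the variable-batch generator used on the "kjt" path.
    return f"variable-batch input {kwargs}"


def generate_input(**kwargs) -> str:
    # Stand-in for the plain generator; it receives input_type so it can
    # build the requested format ("kjt" or "td").
    return f"input {kwargs}"


def gen_inputs(
    generate: Callable[..., str],
    num_inputs: int = 1,
    input_type: str = "kjt",  # "kjt" or "td", mirroring the new parameter
) -> List[str]:
    inputs: List[str] = []
    if input_type == "kjt" and generate is generate_variable_batch_input:
        # Variable-batch generation is only exercised for the KJT input path.
        for _ in range(num_inputs):
            inputs.append(generate(average_batch_size=2))
    else:
        # All other combinations (including input_type="td") use the plain
        # generator and forward input_type.
        for _ in range(num_inputs):
            inputs.append(generate(batch_size=2, input_type=input_type))
    return inputs


print(gen_inputs(generate_variable_batch_input, num_inputs=2))
print(gen_inputs(generate_input, num_inputs=1, input_type="td"))
```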