Skip to content

Commit

Permalink
[ranking] Passing experimental_prefetch_to_device=False option for distributed dataset of Ranking models.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 384792951
  • Loading branch information
tensorflower-gardener committed Jul 14, 2021
1 parent 0650ea2 commit 2a185c1
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions official/recommendation/ranking/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from absl import flags
from absl import logging

import orbit
import tensorflow as tf

from official.common import distribute_utils
Expand Down Expand Up @@ -95,6 +94,21 @@ def main(_) -> None:
with strategy.scope():
model = task.build_model()

def get_dataset_fn(params):
  """Builds a dataset_fn bound to `params`, suitable for strategy distribution.

  Args:
    params: The data config (e.g. `params.task.train_data`) to feed into
      `task.build_inputs`.

  Returns:
    A callable taking a `tf.distribute.InputContext` and returning the
    dataset produced by `task.build_inputs(params, input_context)`.
  """
  def dataset_fn(input_context):
    return task.build_inputs(params, input_context)
  return dataset_fn

train_dataset = None
if 'train' in mode:
train_dataset = strategy.distribute_datasets_from_function(
get_dataset_fn(params.task.train_data),
options=tf.distribute.InputOptions(experimental_fetch_to_device=False))

validation_dataset = None
if 'eval' in mode:
validation_dataset = strategy.distribute_datasets_from_function(
get_dataset_fn(params.task.validation_data),
options=tf.distribute.InputOptions(experimental_fetch_to_device=False))

if params.trainer.use_orbit:
with strategy.scope():
checkpoint_exporter = train_utils.maybe_create_best_ckpt_exporter(
Expand All @@ -106,6 +120,8 @@ def main(_) -> None:
optimizer=model.optimizer,
train='train' in mode,
evaluate='eval' in mode,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
checkpoint_exporter=checkpoint_exporter)

train_lib.run_experiment(
Expand All @@ -117,16 +133,6 @@ def main(_) -> None:
trainer=trainer)

else: # Compile/fit
train_dataset = None
if 'train' in mode:
train_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.train_data)

eval_dataset = None
if 'eval' in mode:
eval_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.validation_data)

checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)

latest_checkpoint = tf.train.latest_checkpoint(model_dir)
Expand Down Expand Up @@ -169,15 +175,15 @@ def main(_) -> None:
initial_epoch=initial_epoch,
epochs=num_epochs,
steps_per_epoch=params.trainer.validation_interval,
validation_data=eval_dataset,
validation_data=validation_dataset,
validation_steps=eval_steps,
callbacks=callbacks,
)
model.summary()
logging.info('Train history: %s', history.history)
elif mode == 'eval':
logging.info('Evaluation started')
validation_output = model.evaluate(eval_dataset, steps=eval_steps)
validation_output = model.evaluate(validation_dataset, steps=eval_steps)
logging.info('Evaluation output: %s', validation_output)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
Expand Down

0 comments on commit 2a185c1

Please sign in to comment.