Skip to content

Commit

Permalink
Fix passing sequence length and batch size to build the model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 423714416
  • Loading branch information
renjie-liu authored and tensorflower-gardener committed Jan 24, 2022
1 parent ab8b801 commit 45848c4
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion official/projects/edgetpu/nlp/serving/export_tflite_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def main(argv: Sequence[str]) -> None:
checkpoint = tf.train.Checkpoint(**checkpoint_dict)
checkpoint.restore(FLAGS.model_checkpoint).assert_existing_objects_matched()

model_for_serving = build_model_for_serving(model)
model_for_serving = build_model_for_serving(model, FLAGS.sequence_length,
FLAGS.batch_size)
model_for_serving.summary()

# TODO(b/194449109): Need to save the model to file and then convert tflite
Expand Down

0 comments on commit 45848c4

Please sign in to comment.