Skip to content

Commit

Permalink
Add FLOPs computation into run_experiment.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 385674527
  • Loading branch information
fyangf authored and tensorflower-gardener committed Jul 20, 2021
1 parent f93bea8 commit 4bd2888
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
5 changes: 5 additions & 0 deletions official/core/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ def timeout_fn():
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)

flops = train_utils.try_count_flops(trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)

if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
Expand Down
50 changes: 49 additions & 1 deletion official/core/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
import json
import os
import pprint
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

from absl import logging
import dataclasses
import gin
import orbit
import tensorflow as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
Expand Down Expand Up @@ -393,3 +396,48 @@ def try_count_params(model: tf.keras.Model):
'train step already reached before this run.')
return None
return None


def try_count_flops(model: Union[tf.Module, tf.keras.Model],
                    inputs_kwargs: Optional[Dict[str, Any]] = None):
  """Counts and returns model FLOPs.

  Builds a concrete function for the model (batch size fixed to 1), freezes
  its variables to constants, and runs the TF1 profiler over the frozen graph
  to count floating-point operations.

  Args:
    model: A model instance.
    inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
      shape specifications to getting corresponding concrete function. Only
      used for subclassed models whose `inputs` attribute is empty.

  Returns:
    The model's FLOPs as an int, or None if the model exposes no `inputs`
    attribute or FLOPs counting fails for any reason (best-effort logging
    helper; it never raises).
  """
  if hasattr(model, 'inputs'):
    try:
      # Get input shape and set batch size to 1.
      if model.inputs:
        # NOTE: renamed loop variable so it no longer shadows the `input`
        # builtin.
        inputs = [
            tf.TensorSpec([1] + input_tensor.shape[1:], input_tensor.dtype)
            for input_tensor in model.inputs
        ]
        concrete_func = tf.function(model).get_concrete_function(inputs)
      # If model.inputs is invalid, try to use the input to get concrete
      # function for model.call (subclass model).
      else:
        # Guard against the default `None`: `**None` would raise TypeError
        # and be swallowed by the broad except with a misleading message.
        concrete_func = tf.function(model.call).get_concrete_function(
            **(inputs_kwargs or {}))
      frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)

      # Calculate FLOPs with the TF1 profiler over the frozen graph;
      # 'output': 'none' suppresses the profiler's stdout report.
      run_meta = tf.compat.v1.RunMetadata()
      opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
      opts['output'] = 'none'
      flops = tf.compat.v1.profiler.profile(
          graph=frozen_func.graph, run_meta=run_meta, options=opts)
      return flops.total_float_ops
    except Exception as e:  # pylint: disable=broad-except
      # Deliberate best-effort: FLOPs counting must never break training.
      logging.info(
          'Failed to count model FLOPs with error %s, because the build() '
          'methods in keras layers were not called. This is probably because '
          'the model was not feed any input, e.g., the max train step already '
          'reached before this run.', e)
      return None
  return None

0 comments on commit 4bd2888

Please sign in to comment.