Skip to content

Commit

Permalink
dataloader profiling format update (pytorch#509)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#509

as title

Reviewed By: ananthsub, wat3rBro

Differential Revision: D48254577

fbshipit-source-id: 2337c30d43c337bedc1779f6b4bd0303c626375c
  • Loading branch information
ninginthecloud authored and facebook-github-bot committed Aug 23, 2023
1 parent 43555dd commit 87c9f53
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pyre_extensions
typing_extensions
setuptools
tqdm
tabulate
17 changes: 17 additions & 0 deletions tests/utils/test_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
BoundedTimer,
FullSyncPeriodicTimer,
get_durations_histogram,
get_recorded_durations_table,
get_synced_durations_histogram,
get_timer_summary,
log_elapsed_time,
Expand Down Expand Up @@ -251,6 +252,22 @@ def test_timer_fn(self) -> None:
with log_elapsed_time("test"):
pass

def test_get_recorded_durations_table(self) -> None:
# empty input
empty_input = get_recorded_durations_table({})
assert empty_input == ""

# no recorded duration values
no_recorded_duration_input = get_recorded_durations_table({"op": {}})
assert no_recorded_duration_input == ""

# valid input
valid_input = get_recorded_durations_table({"op": {"p50": 1, "p90": 2}})
assert (
valid_input
== "\n| Name | p50 | p90 |\n|:-------|------:|------:|\n| op | 1 | 2 |"
)


class FullSyncPeriodicTimerTest(unittest.TestCase):
@classmethod
Expand Down
23 changes: 23 additions & 0 deletions torchtnt/utils/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import torch
import torch.distributed as dist
from tabulate import tabulate
from torchtnt.utils.distributed import PGWrapper

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -415,6 +416,28 @@ def _validate_percentiles(percentiles: Sequence[float]) -> None:
raise ValueError(f"Percentile must be between 0 and 100. Got {p}")


def get_recorded_durations_table(result: Dict[str, Dict[str, float]]) -> str:
r"""
Helper function to generate recorded duration time in tabular format
"""
if len(result) == 0:
return ""
sub_dict = next(iter(result.values()))
if len(sub_dict) == 0:
return ""
column_headers = ["Name"] + list(sub_dict.keys())
row_output = []
for key in result:
row = [key] + ["{:.3f}".format(x) for x in result[key].values()]
row_output.append(row)
tabulate_output = tabulate(
row_output,
tablefmt="pipe",
headers=column_headers,
)
return "\n" + tabulate_output


class FullSyncPeriodicTimer:
"""
Measures time (resets if given interval elapses) on rank 0
Expand Down

0 comments on commit 87c9f53

Please sign in to comment.