Standardized logging code to use global variables; writers now accessible from all files (autonomousvision#62)

* added skeleton for terminal logger

* finished skeleton terminal writer logic; todo: actual calls to the writer classes

* moved lambda to top function

* added logic to support timing

* prettier printing

* prettier printing

* lint

* separated out timing and other stats variables. printing in dict format

* keeping logging history and wiping rest

* updated logic to calculate total times instead of average times

* made stats tracker configurable

* moved stats tracker and profiler to own class file; implemented basic functionality for profiler

* lint; removed duplicate function

* update readme; moved around config variables

* profiler prints on sig kill

* merge

* standardized all logging to global variables
evonneng authored May 7, 2022
1 parent 5cb00c2 commit a498a6c
Showing 9 changed files with 322 additions and 67 deletions.
4 changes: 2 additions & 2 deletions configs/logging/default_logging.yaml
@@ -6,7 +6,7 @@ stats_tracker:
   stats_to_track: [ITER_LOAD_TIME, ITER_TRAIN_TIME, RAYS_PER_SEC, CURR_TEST_PSNR] # see mattport/utils/stats_tracker.py for options
 # writer logs losses and images per iteration
 writer:
-  type: TensorboardWriter
-  save_dir: './'
+  TensorboardWriter:
+    save_dir: './'
 # profiler logs run times of functions and prints at end of training
 enable_profiler: True
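
The writer entry changes from a flat type/save_dir pair to a mapping keyed by writer class, so several writers can be configured side by side. The writer module itself is not among the diffs shown on this page; the following is a minimal sketch, with all names assumed, of how setup_event_writers might consume this block:

# Hypothetical sketch of mattport/utils/writer.py internals (not shown in this diff).
# Assumes every key under the `writer:` config block names a writer class,
# mapped to that writer's constructor kwargs.
EVENT_WRITERS = []

def setup_event_writers(writer_config):
    """Instantiate one event writer per entry in the writer config block."""
    for writer_type, kwargs in writer_config.items():
        writer_class = globals()[writer_type]         # e.g. TensorboardWriter
        EVENT_WRITERS.append(writer_class(**kwargs))  # e.g. save_dir='./'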
70 changes: 47 additions & 23 deletions mattport/nerf/trainer.py
@@ -13,15 +13,13 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader

-import mattport.utils.writer
 from mattport.nerf.dataset.collate import CollateIterDataset, collate_batch_size_one
 from mattport.nerf.dataset.image_dataset import ImageDataset, collate_batch
 from mattport.nerf.dataset.utils import DatasetInputs, get_dataset_inputs_dict
 from mattport.nerf.metrics import get_psnr
 from mattport.nerf.optimizers import Optimizers
-from mattport.utils import profiler
+from mattport.utils import profiler, stats_tracker, writer
 from mattport.utils.decorators import check_main_thread
-from mattport.utils.stats_tracker import Stats, StatsTracker
logging.getLogger("PIL").setLevel(logging.WARNING)

@@ -51,12 +49,9 @@ def __init__(self, config: DictConfig, local_rank: int = 0, world_size: int = 1)
         self.optimizers = None
         self.start_step = 0
         # logging variables
-        self.is_main_thread = world_size != 0 and local_rank % world_size == 0
-        writer = getattr(mattport.utils.writer, self.config.logging.writer.type)
-        self.writer = writer(self.is_main_thread, save_dir=self.config.logging.writer.save_dir)
-        self.stats = StatsTracker(config, self.is_main_thread)
-        if not profiler.PROFILER and self.config.logging.enable_profiler:
-            profiler.PROFILER = profiler.Profiler(config, self.is_main_thread)
+        stats_tracker.setup_stats_tracker(config)
+        writer.setup_event_writers(config.logging.writer)
+        profiler.setup_profiler(config.logging)
         self.device = "cpu" if self.world_size == 0 else f"cuda:{self.local_rank}"

     @profiler.time_function
@@ -187,27 +182,44 @@ def train(self) -> None:
         for i, step in enumerate(range(self.start_step, self.start_step + num_iterations)):
             data_start = time()
             batch = next(iter_dataset)
-            self.stats.update_time(Stats.ITER_LOAD_TIME, data_start, time(), step=step)
+            stats_tracker.update_stats(
+                {"name": stats_tracker.Stats.ITER_LOAD_TIME, "start_time": data_start, "end_time": time(), "step": step}
+            )

             iter_start = time()
             loss_dict = self.train_iteration(batch, step)
-            self.stats.update_time(
-                Stats.RAYS_PER_SEC, iter_start, time(), step=step, batch_size=batch["indices"].shape[0]
+            stats_tracker.update_stats(
+                {
+                    "name": stats_tracker.Stats.RAYS_PER_SEC,
+                    "start_time": iter_start,
+                    "end_time": time(),
+                    "step": step,
+                    "batch_size": batch["indices"].shape[0],
+                },
             )
+            stats_tracker.update_stats(
+                {
+                    "name": stats_tracker.Stats.ITER_TRAIN_TIME,
+                    "start_time": iter_start,
+                    "end_time": time(),
+                    "step": step,
+                }
+            )
-            self.stats.update_time(Stats.ITER_TRAIN_TIME, iter_start, time(), step=step)

             if step != 0 and step % self.config.logging.steps_per_log == 0:
-                self.writer.write_scalar_dict(loss_dict, step, group="Loss", prefix="train-")
+                writer.write_event({"scalar_dict": loss_dict, "step": step, "group": "Loss", "prefix": "train-"})
             # TODO: add the learning rates to tensorboard/logging
             if step != 0 and self.config.graph.steps_per_save and step % self.config.graph.steps_per_save == 0:
                 self.save_checkpoint(self.config.graph.model_dir, step)
             if step % self.config.graph.steps_per_test == 0:
                 for image_idx in self.config.data.validation_image_indices:
                     self.test_image(image_idx=image_idx, step=step)
-            self.stats.print_stats(i / num_iterations)
+            stats_tracker.print_stats(i / num_iterations)

-        self.stats.update_time(Stats.TOTAL_TRAIN_TIME, train_start, time(), step=-1)  # NOTE(ethan): why is step -1?
-        self.stats.print_stats(-1)
+        stats_tracker.update_stats(
+            {"name": stats_tracker.Stats.TOTAL_TRAIN_TIME, "start_time": train_start, "end_time": time()}
+        )
+        stats_tracker.print_stats(1.0)

     @profiler.time_function
     def train_iteration(self, batch: dict, step: int):
@@ -266,17 +278,29 @@ def test_image(self, image_idx, step):
         disparity_fine = torch.cat(disparity_fine).view(image_height, image_width, 1).detach().cpu()

         combined_image = torch.cat([image, rgb_coarse, rgb_fine], dim=1)
-        self.writer.write_image(f"image_idx_{image_idx}-rgb_coarse_fine", combined_image, step, group="val_img")
+        writer.write_event(
+            {"name": f"image_idx_{image_idx}-rgb_coarse_fine", "x": combined_image, "step": step, "group": "val_img"}
+        )

         combined_image = torch.cat([accumulation_coarse, accumulation_fine], dim=1)
-        self.writer.write_image(f"image_idx_{image_idx}", combined_image, step, group="val_accumulation")
+        writer.write_event(
+            {"name": f"image_idx_{image_idx}", "x": combined_image, "step": step, "group": "val_accumulation"}
+        )

         combined_image = torch.cat([disparity_coarse, disparity_fine], dim=1)
-        self.writer.write_image(f"image_idx_{image_idx}", combined_image, step, group="val_disparity")
+        writer.write_event(
+            {"name": f"image_idx_{image_idx}", "x": combined_image, "step": step, "group": "val_disparity"}
+        )

         coarse_psnr = get_psnr(image, rgb_coarse)
-        self.writer.write_scalar(f"image_idx_{image_idx}", float(coarse_psnr), step, group="val")
+        writer.write_event(
+            {"name": f"image_idx_{image_idx}", "scalar": float(coarse_psnr), "step": step, "group": "val"}
+        )

         fine_psnr = get_psnr(image, rgb_fine)
-        self.stats.update_value(Stats.CURR_TEST_PSNR, float(fine_psnr), step)
-        self.writer.write_scalar(f"image_idx_{image_idx}-fine_psnr", float(fine_psnr), step, group="val")
+        stats_tracker.update_stats(
+            {"name": stats_tracker.Stats.CURR_TEST_PSNR, "value": float(fine_psnr), "step": step}
+        )
+        writer.write_event(
+            {"name": f"image_idx_{image_idx}-fine_psnr", "scalar": float(fine_psnr), "step": step, "group": "val"}
+        )
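
All per-iteration logging now funnels through writer.write_event with a single event dict. The real write_event lives in mattport/utils/writer.py, which is not shown on this page; below is a plausible sketch of its dispatch, assuming it keys off the dict fields used above ("scalar", "scalar_dict", "x") and the hypothetical EVENT_WRITERS list from the earlier sketch:

# Hypothetical sketch -- assumes write_event dispatches on the event dict's keys
# and fans out to every configured writer in EVENT_WRITERS.
def write_event(event: dict):
    step = event["step"]
    group = event.get("group")
    for event_writer in EVENT_WRITERS:
        if "scalar" in event:
            event_writer.write_scalar(event["name"], event["scalar"], step, group=group)
        elif "scalar_dict" in event:
            event_writer.write_scalar_dict(event["scalar_dict"], step, group=group, prefix=event.get("prefix"))
        elif "x" in event:
            event_writer.write_image(event["name"], event["x"], step, group=group)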
72 changes: 72 additions & 0 deletions mattport/utils/comms.py
@@ -0,0 +1,72 @@
"""functionality to handle multiprocessing syncing and communicating"""
import torch
import torch.distributed as dist

_LOCAL_PROCESS_GROUP = None


def get_world_size() -> int:
    """Get total number of available gpus"""
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    """Get global rank of current thread"""
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """The rank of the current process within the local (per-machine) process group."""
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert (
        _LOCAL_PROCESS_GROUP is not None
    ), "Local process group is not created! Please use launch() to spawn processes!"
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    The size of the per-machine process group,
    i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    """check to see if you are currently on the main process"""
    return get_rank() == 0


def synchronize(world_size):
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    # note: the passed-in world_size is overwritten by the initialized value here
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    if dist.get_backend() == dist.Backend.NCCL:
        # This argument is needed to avoid warnings.
        # It's valid only for NCCL backend.
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()
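
comms.py centralizes the rank bookkeeping that trainer.py previously computed by hand (world_size != 0 and local_rank % world_size == 0). A small usage sketch; save_snapshot is a hypothetical function, not part of this commit:

from mattport.utils import comms

def save_snapshot(path: str):
    """Hypothetical example: write from rank 0 only, then realign all ranks."""
    if comms.is_main_process():  # true on rank 0, or when not running distributed
        print(f"saving to {path}")
    comms.synchronize(comms.get_world_size())  # barrier so other ranks wait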
6 changes: 4 additions & 2 deletions mattport/utils/decorators.py
@@ -3,6 +3,8 @@
 """
 from typing import Callable, List

+from mattport.utils import comms
+

 def decorate_all(decorators: List[Callable]) -> Callable:
     """A decorator to decorate all member functions of a class"""
@@ -34,7 +36,7 @@ def check_profiler_enabled(func: Callable) -> Callable:

     def wrapper(self, *args, **kwargs):
         ret = None
-        if self.config.logging.enable_stats:
+        if self.config.enable_stats:
             ret = func(self, *args, **kwargs)
         return ret

@@ -46,7 +48,7 @@ def check_main_thread(func: Callable) -> Callable:

     def wrapper(self, *args, **kwargs):
         ret = None
-        if self.is_main_thread:
+        if comms.is_main_process():
             ret = func(self, *args, **kwargs)
         return ret
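
With check_main_thread rewired to comms.is_main_process(), decorated methods no longer depend on a per-object is_main_thread flag, so any class can use the decorator. A sketch of the effect (the body of save_checkpoint here is illustrative; the real method is in trainer.py):

# Sketch: on non-main ranks the wrapper skips the call and returns None.
from mattport.utils.decorators import check_main_thread

class Trainer:
    @check_main_thread
    def save_checkpoint(self, model_dir: str, step: int):
        print(f"checkpoint at step {step} -> {model_dir}")  # runs on rank 0 only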
27 changes: 21 additions & 6 deletions mattport/utils/profiler.py
@@ -4,13 +4,14 @@
 import logging
 import sys
 import time
-from typing import Callable
+from typing import Callable, Dict

 from omegaconf import DictConfig

+from mattport.utils import comms
 from mattport.utils.decorators import check_main_thread, check_profiler_enabled, decorate_all

-PROFILER = None
+PROFILER = []


 def time_function(func: Callable) -> Callable:
@@ -25,20 +26,34 @@ def wrapper(*args, **kwargs):
         for attr in func.__qualname__.split(".")[:-1]:
             class_str += f"{vals[attr].__qualname__}_"
         class_str += func.__name__
-        PROFILER.update_time(class_str, start, time.time())
+        PROFILER[0].update_time(class_str, start, time.time())
         return ret

     return wrapper


+def flush_profiler(config: Dict):
+    """Method that checks if the profiler is enabled before flushing
+    Args:
+        config (Dict): config to check
+    """
+    if config.enable_profiler and PROFILER:
+        PROFILER[0].print_profile()
+
+
+def setup_profiler(config: DictConfig):
+    """Initialization of profilers"""
+    if comms.is_main_process():
+        PROFILER.append(Profiler(config))
+
+
 @decorate_all([check_profiler_enabled, check_main_thread])
 class Profiler:
     """Profiler class"""

-    def __init__(self, config: DictConfig, is_main_thread: bool):
+    def __init__(self, config: DictConfig):
         self.config = config
-        if self.config.logging.enable_profiler:
-            self.is_main_thread = is_main_thread
         self.profiler_dict = {}

     def update_time(self, func_name: str, start_time: float, end_time: float):
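
flush_profiler pairs with the "profiler prints on sig kill" entry in the commit log above: something must call it when training is interrupted. The registration site is not in the visible diff; a minimal sketch, assuming a SIGINT handler is how it is wired up:

# Hypothetical wiring -- the actual registration is not shown in this diff.
import signal
import sys

from mattport.utils import profiler

def install_profiler_flush(config):
    """Print accumulated timings before exiting on Ctrl-C / SIGINT."""
    def handler(signum, frame):
        profiler.flush_profiler(config.logging)
        sys.exit(0)
    signal.signal(signal.SIGINT, handler)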
46 changes: 38 additions & 8 deletions mattport/utils/stats_tracker.py
@@ -4,10 +4,38 @@
 import datetime
 import enum
 import logging
+from typing import Any, Dict

 from omegaconf import DictConfig

+from mattport.utils import comms
 from mattport.utils.decorators import check_main_thread, check_print_stats_step, check_stats_enabled, decorate_all

+STATS_TRACKER = []
+
+
+def update_stats(args: Dict[str, Any]):
+    """update global stats tracker based on arguments"""
+    assert comms.is_main_process(), "Writing out with process other than main"
+    if "value" in args:
+        STATS_TRACKER[0].update_value(**args)
+    elif "start_time" in args:
+        STATS_TRACKER[0].update_time(**args)
+    else:
+        raise NotImplementedError
+
+
+def print_stats(fraction_done: float):
+    """print current statistics information"""
+    assert comms.is_main_process(), "Writing out with process other than main"
+    STATS_TRACKER[0].print_stats(fraction_done)
+
+
+def setup_stats_tracker(config: DictConfig):
+    """Initialization of stats tracker"""
+    if comms.is_main_process():
+        STATS_TRACKER.append(StatsTracker(config))
+
+
 class Stats(enum.Enum):
     """Possible Stats values for StatsTracker
@@ -28,9 +56,8 @@ class Stats(enum.Enum):
 class StatsTracker:
     """Stats Tracker class"""

-    def __init__(self, config: DictConfig, is_main_thread: bool):
+    def __init__(self, config: DictConfig):
         self.config = config
-        self.is_main_thread = is_main_thread
         if self.config.logging.enable_stats:
             self.max_history = self.config.logging.stats_tracker.max_history
         self.step = 0
@@ -57,14 +84,17 @@ def update_value(self, name: enum.Enum, value: float, step: int):
         self.new_key = not name in self.stats_dict or self.new_key
         self.stats_dict[name] = value

-    def update_time(self, name: enum.Enum, start_time: float, end_time: float, step: int, batch_size: int = None):
+    def update_time(
+        self, name: enum.Enum, start_time: float, end_time: float, step: int = None, batch_size: int = None
+    ):
         """update the stats dictionary with running averages/cumulative durations
         Args:
             name (enum.Enum): Enum name of statistic we are logging
             start_time (float): start time for the call in seconds
             end_time (float): end time when the call finished executing in seconds
-            step (int): number of total iteration steps.
+            step (int): number of total iteration steps. Defaults to None.
+                if None, reports duration without averaging
             batch_size (int, optional): total number of rays in a batch;
                 if None, reports duration instead of batch per second. Defaults to None.
         """
@@ -76,12 +106,12 @@ def update_time(self, name: enum.Enum, start_time: float, end_time: float, step:
             # calculate the batch per second stat
             val = batch_size / val

-        if step == -1:
-            # logging total time instead of average
-            self.stats_dict[name] = val
-        else:
+        if step:
             # calculate updated average
             self.stats_dict[name] = (self.stats_dict.get(name, 0) * step + val) / (step + 1)
+        else:
+            # logging total time instead of average
+            self.stats_dict[name] = val

         if name == Stats.ITER_TRAIN_TIME and Stats.ETA in self.stats_to_track:
             # update ETA if logging iteration train time
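
The reworked branch replaces the old step == -1 sentinel: a truthy step folds the new sample into a running mean, while a falsy step (0 on the first iteration, or None when omitted, as with TOTAL_TRAIN_TIME in trainer.py) stores the raw duration. A quick standalone check of the update rule:

# Standalone check of the running-average rule used above:
# new_mean = (old_mean * step + val) / (step + 1)
stats_dict = {}
for step, val in enumerate([2.0, 4.0, 6.0]):
    if step:
        stats_dict["t"] = (stats_dict.get("t", 0) * step + val) / (step + 1)
    else:
        stats_dict["t"] = val  # step 0 (falsy): store the raw value
assert stats_dict["t"] == 4.0  # mean of 2.0, 4.0, 6.0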
(Diffs for the remaining three changed files, among them mattport/utils/writer.py, did not load on this page.)
