Skip to content

Commit

Permalink
Add more performance logging
Browse files Browse the repository at this point in the history
Summary:
Fix up some of the broken performance logging.

Also, when run from the command line, the logger didn't print the timestamp, so fixed that.

Reviewed By: lw

Differential Revision: D20464624

fbshipit-source-id: 59b0f28d43f8e3772d5f2ecebc4721eb3a6d182c
  • Loading branch information
adamlerer authored and facebook-github-bot committed Mar 18, 2020
1 parent 98081ea commit c7dea47
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 24 deletions.
8 changes: 4 additions & 4 deletions torchbiggraph/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
# FIXME This order assumes higher affinity on the left-hand side, as it's
# the one changing more slowly. Make this adaptive to the actual affinity.
for bucket in create_buckets_ordered_lexicographically(holder.nparts_lhs, holder.nparts_rhs):
tic = time.time()
tic = time.perf_counter()
# logger.info(f"{bucket}: Loading entities")

old_parts = set(holder.partitioned_embeddings.keys())
Expand All @@ -166,8 +166,8 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
num_edges = len(edges)

load_time = time.time() - tic
tic = time.time()
load_time = time.perf_counter() - tic
tic = time.perf_counter()
# logger.info(f"{bucket}: Launching and waiting for workers")
future_all_bucket_stats = pool.map_async(call, [
partial(
Expand All @@ -182,7 +182,7 @@ def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
all_bucket_stats = \
get_async_result(future_all_bucket_stats, pool)

compute_time = time.time() - tic
compute_time = time.perf_counter() - tic
logger.info(
f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
f"({num_edges / compute_time / 1e6:.2g}M/sec); "
Expand Down
12 changes: 6 additions & 6 deletions torchbiggraph/parameter_sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,21 +459,21 @@ def _client_thread_loop(
params = {}
clients = [GradientParameterClient(server_rank)
for server_rank in all_server_ranks]
log_time, log_rounds, log_bytes = time.time(), 0, 0
log_time, log_rounds, log_bytes = time.perf_counter(), 0, 0

# thread loop:
# 1. check for a command from the main process
# 2. update (push and pull) each parameter in my list of parameters
# 3. if we're going to fast, sleep for a while
while True:
tic = time.time()
tic = time.perf_counter()
bytes_transferred = 0
try:
data = q.get(timeout=0.01)
cmd, args = data
if cmd == "params":
params[args[0]] = args[1]
log_time, log_rounds, log_bytes = time.time(), 0, 0
log_time, log_rounds, log_bytes = time.perf_counter(), 0, 0
elif cmd == "join":
for client in clients:
client.join()
Expand All @@ -494,15 +494,15 @@ def _client_thread_loop(

log_bytes += bytes_transferred
log_rounds += 1
log_delta = time.time() - log_time
log_delta = time.perf_counter() - log_time
if params and log_delta > 60:
logger.info(
f"Parameter client synced {log_rounds} rounds {log_bytes / 1e9:g} "
f"GB in {log_delta:g} s ({log_delta / log_rounds:g} s/round, "
f"{log_bytes / log_delta / 1e9:g} GB/s)")
log_time, log_rounds, log_bytes = time.time(), 0, 0
log_time, log_rounds, log_bytes = time.perf_counter(), 0, 0

comm_time = time.time() - tic
comm_time = time.perf_counter() - tic
sleep_time = max(bytes_transferred / max_bandwidth - comm_time,
min_sleep_time)
time.sleep(sleep_time)
Expand Down
36 changes: 22 additions & 14 deletions torchbiggraph/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,27 +590,29 @@ def train_and_report_stats(
while remaining > 0:
old_b: Optional[Bucket] = cur_b
old_stats: Optional[BucketStats] = cur_stats
io_time = 0.
io_bytes = 0
cur_b, remaining = self.bucket_scheduler.acquire_bucket()
logger.info(f"still in queue: {remaining}")
if cur_b is None:
cur_stats = None
if old_b is not None:
# if you couldn't get a new pair, release the lock
# to prevent a deadlock!
tic = time.time()
io_bytes += self._swap_partitioned_embeddings(old_b, None, old_stats)
io_time += time.time() - tic
tic = time.perf_counter()
release_bytes = self._swap_partitioned_embeddings(old_b, None, old_stats)
release_time = time.perf_counter() - tic
logger.info(
f"Swapping old embeddings to release lock. io: {release_time:.2f} s for {release_bytes:,} bytes "
f"( {release_bytes / release_time / 1e6:.2f} MB/sec )"
)
time.sleep(1) # don't hammer td
continue

tic = time.time()
tic = time.perf_counter()
self.cur_b = cur_b
bucket_logger = BucketLogger(logger, bucket=cur_b)
self.bucket_logger = bucket_logger

io_bytes += self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
io_bytes = self._swap_partitioned_embeddings(old_b, cur_b, old_stats)
self.model.set_all_embeddings(holder, cur_b)

current_index = \
Expand All @@ -630,8 +632,8 @@ def train_and_report_stats(
io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
io_bytes += edges.rel.numel() * edges.rel.element_size()
io_time += time.time() - tic
tic = time.time()
io_time = time.perf_counter() - tic
tic = time.perf_counter()
bucket_logger.debug("Shuffling edges")
# Fix a seed to get the same permutation every time; have it
# depend on all and only what affects the set of edges.
Expand All @@ -650,6 +652,7 @@ def train_and_report_stats(
# uniformly sampled so it's still an unbiased estimator
# of the out-of-sample statistics
num_eval_edges = int(num_edges * config.eval_fraction)
num_train_edges = num_edges - num_eval_edges
if num_eval_edges > 0:
g = torch.Generator()
g.manual_seed(
Expand All @@ -659,29 +662,34 @@ def train_and_report_stats(
else:
eval_edge_idxs = None

logger.info(f"Shuffled edges in {time.time() - tic} seconds")

# HOGWILD evaluation before training
eval_stats_before = self._coordinate_eval(edges, eval_edge_idxs)
if eval_stats_before is not None:
bucket_logger.info(f"Stats before training: {eval_stats_before}")
eval_time = time.perf_counter() - tic
tic = time.perf_counter()

# HOGWILD training
bucket_logger.debug("Waiting for workers to perform training")
stats = self._coordinate_train(edges, eval_edge_idxs, epoch_idx)
train_time = time.perf_counter() - tic
tic = time.perf_counter()

# HOGWILD evaluation after training
eval_stats_after = self._coordinate_eval(edges, eval_edge_idxs)
if eval_stats_after is not None:
bucket_logger.info(f"Stats before training: {eval_stats_after}")

compute_time = time.time() - tic
eval_time += time.perf_counter() - tic

bucket_logger.info(
f"bucket {total_buckets - remaining} / {total_buckets} : "
f"Processed {num_edges} edges in {compute_time:.2f} s "
f"( {num_edges / compute_time / 1e6:.2g} M/sec ); "
f"io: {io_time:.2f} s ( {io_bytes / io_time / 1e6:.2f} MB/sec )")
f"Trained {num_train_edges} edges in {train_time:.2f} s "
f"( {num_train_edges / train_time / 1e6:.2g} M/sec ); "
f"Eval 2*{num_eval_edges} edges in {eval_time:.2f} s "
f"( {2 * num_eval_edges / eval_time / 1e6:.2g} M/sec ); "
f"io: {io_time:.2f} s for {io_bytes:,} bytes ( {io_bytes / io_time / 1e6:.2f} MB/sec )")
bucket_logger.info(f"{stats}")

self.model.clear_all_embeddings()
Expand Down

0 comments on commit c7dea47

Please sign in to comment.