Skip to content

Commit

Permalink
GPU eval during training (fast version)
Browse files Browse the repository at this point in the history
Summary:
I'm making a slight approximation here that I think is totally fine. I'm sampling the eval edges *with replacement*, but still swapping in their replacements; therefore there's still no overlap between train/eval, but there could be (rare) uniform duplicates in eval.
The upshot is that it's easy to make this very fast.

Reviewed By: lw

Differential Revision: D20464628

fbshipit-source-id: 93ec91c387424d24816df6fdc7f78dfefbac3009
  • Loading branch information
adamlerer authored and facebook-github-bot committed Mar 18, 2020
1 parent baaf5df commit 98081ea
Showing 1 changed file with 37 additions and 18 deletions.
55 changes: 37 additions & 18 deletions torchbiggraph/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
)



logger = logging.getLogger("torchbiggraph")
dist_logger = logging.LoggerAdapter(logger, {"distributed": True})

Expand Down Expand Up @@ -634,34 +635,43 @@ def train_and_report_stats(
bucket_logger.debug("Shuffling edges")
# Fix a seed to get the same permutation every time; have it
# depend on all and only what affects the set of edges.
g = torch.Generator()
g.manual_seed(
hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs))
)

# Note: for the sake of efficiency, we sample eval edge idxs
# from the edge set *with replacement*, meaning that there may
# be duplicates of the same edge in the eval set. When we swap
# edges into the eval set, if there are duplicates then all
# but one will be clobbered. These collisions are unlikely
# if eval_fraction is small.
#
# Importantly, this eval sampling strategy is theoretically
# sound:
# * Training and eval sets are (exactly) disjoint
# * Eval set may have (rare) duplicates, but they are
# uniformly sampled so it's still an unbiased estimator
# of the out-of-sample statistics
num_eval_edges = int(num_edges * config.eval_fraction)
if num_eval_edges > 0:
edge_perm = torch.randperm(num_edges, generator=g)
eval_edge_perm = edge_perm[-num_eval_edges:]
num_edges -= num_eval_edges
edge_perm = edge_perm[torch.randperm(num_edges)]
g = torch.Generator()
g.manual_seed(
hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs))
)
eval_edge_idxs = torch.randint(num_edges, (num_eval_edges,), dtype=torch.long, generator=g)
else:
eval_edge_perm = None
edge_perm = torch.randperm(num_edges)
eval_edge_idxs = None

logger.info(f"Shuffled edges in {time.time() - tic} seconds")

# HOGWILD evaluation before training
eval_stats_before = self._coordinate_eval(edges, eval_edge_perm)
eval_stats_before = self._coordinate_eval(edges, eval_edge_idxs)
if eval_stats_before is not None:
bucket_logger.info(f"Stats before training: {eval_stats_before}")

# HOGWILD training
bucket_logger.debug("Waiting for workers to perform training")
stats = self._coordinate_train(edges, edge_perm, epoch_idx)
stats = self._coordinate_train(edges, eval_edge_idxs, epoch_idx)

# HOGWILD evaluation after training
eval_stats_after = self._coordinate_eval(edges, eval_edge_perm)
eval_stats_after = self._coordinate_eval(edges, eval_edge_idxs)
if eval_stats_after is not None:
bucket_logger.info(f"Stats before training: {eval_stats_after}")

Expand Down Expand Up @@ -814,9 +824,18 @@ def _swap_partitioned_embeddings(

return io_bytes

def _coordinate_train(self, edges, edge_perm, epoch_idx) -> Stats:
def _coordinate_train(self, edges, eval_edge_idxs, epoch_idx) -> Stats:
assert self.config.num_gpus == 0, "GPU training not supported"

if eval_edge_idxs is not None:
num_train_edges = len(edges) - len(eval_edge_idxs)
train_edge_idxs = torch.arange(len(edges))
train_edge_idxs[eval_edge_idxs] = torch.arange(num_train_edges, len(edges))
train_edge_idxs = train_edge_idxs[:num_train_edges]
edge_perm = train_edge_idxs[torch.randperm(num_train_edges)]
else:
edge_perm = torch.randperm(len(edges))

future_all_stats = self.pool.map_async(call, [
partial(
process_in_batches,
Expand All @@ -836,12 +855,12 @@ def _coordinate_train(self, edges, edge_perm, epoch_idx) -> Stats:
all_stats = get_async_result(future_all_stats, self.pool)
return Stats.sum(all_stats).average()

def _coordinate_eval(self, edges, eval_edge_perm) -> Optional[Stats]:
def _coordinate_eval(self, edges, eval_edge_idxs) -> Optional[Stats]:
eval_batch_size = round_up_to_nearest_multiple(
self.config.batch_size,
self.config.eval_num_batch_negs
)
if eval_edge_perm is not None:
if eval_edge_idxs is not None:
self.bucket_logger.debug("Waiting for workers to perform evaluation")
future_all_eval_stats = self.pool.map_async(call, [
partial(
Expand All @@ -850,9 +869,9 @@ def _coordinate_eval(self, edges, eval_edge_perm) -> Optional[Stats]:
model=self.model,
batch_processor=self.evaluator,
edges=edges,
indices=eval_edge_perm[s],
indices=eval_edge_idxs[s],
)
for s in split_almost_equally(eval_edge_perm.size(0),
for s in split_almost_equally(eval_edge_idxs.size(0),
num_parts=self.num_workers)
])
all_eval_stats = \
Expand Down

0 comments on commit 98081ea

Please sign in to comment.